[tf.data] Modify the optimization autotune_buffer_sizes
so that it will inject autotuned PrefetchDatasets after non-prefetched asynchronous datasets. The optimization will also rewrite those existing non-autotuned PrefetchDatasets into autotuned with fixed start value and minimal value for the tunable parameter buffer_size
.
PiperOrigin-RevId: 342347611 Change-Id: I1d4fe8f00944595a6fb9b5b99a1679a493b32edf
This commit is contained in:
parent
5f1da7ab8c
commit
6c838fcbf8
tensorflow
core
grappler/optimizers/data
kernels/data
python/data/experimental/kernel_tests/optimization
@ -31,6 +31,7 @@ namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kLegacyAutotune[] = "legacy_autotune";
|
||||
constexpr char kBufferSizeMin[] = "buffer_size_min";
|
||||
constexpr char kPrefetchDataset[] = "PrefetchDataset";
|
||||
|
||||
constexpr std::array<const char*, 7> kAsyncDatasetOps = {
|
||||
@ -54,8 +55,47 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
|
||||
}
|
||||
MutableGraphView graph(output);
|
||||
|
||||
// Add a const node with value kAutotune
|
||||
NodeDef* autotune_value =
|
||||
graph_utils::AddScalarConstNode(data::model::kAutotune, &graph);
|
||||
|
||||
absl::flat_hash_set<string> already_prefetched;
|
||||
|
||||
// 1) Collect about all existing `PrefetchDataset` nodes, replacing
|
||||
// `prefetch(N)` with `prefetch(AUTOTUNE, buffer_size_min=N)` for all N !=-1.
|
||||
for (NodeDef& node : *output->mutable_node()) {
|
||||
if (node.op() == kPrefetchDataset) {
|
||||
NodeDef* buffer_size_node = graph.GetNode(node.input(1));
|
||||
// We only consider to rewrite if `buffer_size` is constant.
|
||||
if (buffer_size_node->op() == "Const") {
|
||||
int64 initial_buffer_size =
|
||||
buffer_size_node->attr().at("value").tensor().int64_val(0);
|
||||
if (initial_buffer_size != data::model::kAutotune) {
|
||||
TF_RETURN_IF_ERROR(graph.UpdateFanin(node.name(),
|
||||
{buffer_size_node->name(), 0},
|
||||
{autotune_value->name(), 0}));
|
||||
node.mutable_attr()->at(kBufferSizeMin).set_i(initial_buffer_size);
|
||||
stats->num_changes++;
|
||||
}
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"The autotune_buffer_sizes rewrite does not currently support "
|
||||
"non-constant buffer_size input.");
|
||||
}
|
||||
NodeDef* prefetched_node = graph_utils::GetInputNode(node, graph);
|
||||
if (prefetched_node) {
|
||||
already_prefetched.insert(prefetched_node->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> async_datasets;
|
||||
// 2) Insert `prefetch(AUTOTUNE)` after all asynchronous transformations that
|
||||
// are not followed by a `prefetch` yet.
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
if (already_prefetched.find(node.name()) != already_prefetched.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& async_dataset_op : kAsyncDatasetOps) {
|
||||
if (node.op() == async_dataset_op) {
|
||||
async_datasets.push_back(&node);
|
||||
@ -67,10 +107,6 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
|
||||
|
||||
if (async_datasets.empty()) return Status::OK();
|
||||
|
||||
// Add a const node with value kAutotune
|
||||
NodeDef* autotune_value =
|
||||
graph_utils::AddScalarConstNode(data::model::kAutotune, &graph);
|
||||
|
||||
for (const NodeDef* async_dataset_node : async_datasets) {
|
||||
NodeDef prefetch_node;
|
||||
graph_utils::SetUniqueGraphNodeName(
|
||||
|
@ -24,8 +24,15 @@ namespace grappler {
|
||||
|
||||
constexpr char kAutotune[] = "autotune";
|
||||
|
||||
// This optimization adds `prefetch(AUTOTUNE)` after all asynchronous tf.data
|
||||
// transformations (e.g. parallel map, parallel interleave, and map + batch).
|
||||
// This optimization does the following:
|
||||
//
|
||||
// 1. Adds `prefetch(AUTOTUNE)` after all asynchronous tf.data transformations
|
||||
// (e.g. parallel map, parallel interleave, and map + batch) if they are not
|
||||
// followed by a `prefetch` yet.
|
||||
//
|
||||
// 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace
|
||||
// the transformation with autotunable version of `prefetch` which uses N as
|
||||
// the minimum size of the buffer.
|
||||
class AutotuneBufferSizes : public TFDataOptimizerBase {
|
||||
public:
|
||||
AutotuneBufferSizes() = default;
|
||||
|
@ -148,6 +148,116 @@ TEST_P(AutotuneSetting, AutotuneBufferSizesTest) {
|
||||
autotune);
|
||||
}
|
||||
|
||||
class MultipleNodes : public ::testing::TestWithParam<std::tuple<bool, int64>> {
|
||||
};
|
||||
|
||||
TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
|
||||
const bool legacy_autotune = std::get<0>(GetParam());
|
||||
const int64 initial_buffer_size = std::get<1>(GetParam());
|
||||
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
NodeDef *start_val = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_val = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_val->name();
|
||||
range_inputs[1] = stop_val->name();
|
||||
range_inputs[2] = step_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> range_attrs;
|
||||
NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset",
|
||||
range_inputs, range_attrs, &graph);
|
||||
|
||||
NodeDef *parallelism_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
std::vector<string> map_inputs1(2);
|
||||
map_inputs1[0] = range_node->name();
|
||||
map_inputs1[1] = parallelism_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> map_attrs(4);
|
||||
AttrValue attr_val;
|
||||
SetAttrValue("value", &attr_val);
|
||||
map_attrs[0] = std::make_pair("f", attr_val);
|
||||
map_attrs[1] = std::make_pair("Targuments", attr_val);
|
||||
map_attrs[2] = std::make_pair("output_types", attr_val);
|
||||
map_attrs[3] = std::make_pair("output_shapes", attr_val);
|
||||
NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2",
|
||||
map_inputs1, map_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_val =
|
||||
graph_utils::AddScalarConstNode<int64>(initial_buffer_size, &graph);
|
||||
std::vector<string> prefetch_inputs(2);
|
||||
prefetch_inputs[0] = map_node1->name();
|
||||
prefetch_inputs[1] = buffer_size_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> prefetch_attrs(4);
|
||||
AttrValue legacy_autotune_attr;
|
||||
SetAttrValue(legacy_autotune, &legacy_autotune_attr);
|
||||
AttrValue buffer_size_min_attr;
|
||||
SetAttrValue(0, &buffer_size_min_attr);
|
||||
prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr);
|
||||
prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr);
|
||||
prefetch_attrs[2] = std::make_pair("output_types", attr_val);
|
||||
prefetch_attrs[3] = std::make_pair("output_shapes", attr_val);
|
||||
NodeDef *prefetch_node = graph_utils::AddNode(
|
||||
"prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph);
|
||||
|
||||
std::vector<string> map_inputs2(2);
|
||||
map_inputs2[0] = prefetch_node->name();
|
||||
map_inputs2[1] = parallelism_val->name();
|
||||
NodeDef *map_node2 = graph_utils::AddNode("map2", "ParallelMapDatasetV2",
|
||||
map_inputs2, map_attrs, &graph);
|
||||
|
||||
std::vector<string> map_inputs3(1);
|
||||
map_inputs3[0] = map_node2->name();
|
||||
graph_utils::AddNode("map3", "MapDataset", map_inputs3, map_attrs, &graph);
|
||||
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true));
|
||||
|
||||
std::vector<int> prefetch_indices =
|
||||
graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output);
|
||||
EXPECT_EQ(prefetch_indices.size(), 2);
|
||||
|
||||
NodeDef new_map_node3 =
|
||||
output.node(graph_utils::FindGraphNodeWithName("map3", output));
|
||||
|
||||
NodeDef new_prefetch_node2 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_map_node3.input(0), output));
|
||||
EXPECT_EQ(new_prefetch_node2.op(), "PrefetchDataset");
|
||||
EXPECT_EQ(new_prefetch_node2.input_size(), 2);
|
||||
EXPECT_TRUE(new_prefetch_node2.attr().find("legacy_autotune") ==
|
||||
new_prefetch_node2.attr().end());
|
||||
EXPECT_TRUE(new_prefetch_node2.attr().find("buffer_size_min") ==
|
||||
new_prefetch_node2.attr().end());
|
||||
NodeDef new_buffer_size_val2 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output));
|
||||
EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1);
|
||||
|
||||
NodeDef new_map_node2 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output));
|
||||
EXPECT_EQ(new_map_node2.name(), "map2");
|
||||
|
||||
NodeDef new_prefetch_node1 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_map_node2.input(0), output));
|
||||
EXPECT_EQ(new_prefetch_node1.op(), "PrefetchDataset");
|
||||
EXPECT_EQ(new_prefetch_node1.input_size(), 2);
|
||||
EXPECT_EQ(new_prefetch_node1.attr().at("legacy_autotune").b(),
|
||||
legacy_autotune);
|
||||
EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
|
||||
(initial_buffer_size == -1 ? 0 : initial_buffer_size));
|
||||
NodeDef new_buffer_size_val1 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output));
|
||||
EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1);
|
||||
|
||||
NodeDef new_map_node1 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output));
|
||||
EXPECT_EQ(new_map_node1.name(), "map1");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes,
|
||||
::testing::Combine(::testing::Values(true, false),
|
||||
::testing::Values(-1, 3)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true));
|
||||
|
||||
} // namespace
|
||||
|
@ -20,11 +20,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
|
||||
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size,
|
||||
int64 buffer_size_min)
|
||||
: buffer_limit_(initial_buffer_size) {
|
||||
if (initial_buffer_size == model::kAutotune) {
|
||||
mode_ = Mode::kUpswing;
|
||||
buffer_limit_ = 1;
|
||||
buffer_limit_ = std::max(int64{1}, buffer_size_min);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ namespace data {
|
||||
// PrefetchAutotuner is NOT thread safe.
|
||||
class PrefetchAutotuner {
|
||||
public:
|
||||
explicit PrefetchAutotuner(int64 initial_buffer_size);
|
||||
explicit PrefetchAutotuner(int64 initial_buffer_size, int64 buffer_size_min);
|
||||
|
||||
int64 buffer_limit() const { return buffer_limit_; }
|
||||
|
||||
|
@ -23,7 +23,7 @@ namespace data {
|
||||
namespace {
|
||||
|
||||
TEST(PrefetchAutotuner, Disabled) {
|
||||
PrefetchAutotuner t(2);
|
||||
PrefetchAutotuner t(2, 0);
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0);
|
||||
t.RecordConsumption(2);
|
||||
@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) {
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, Enabled) {
|
||||
PrefetchAutotuner t(model::kAutotune);
|
||||
PrefetchAutotuner t(model::kAutotune, 0);
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same.
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
@ -58,7 +58,7 @@ TEST(PrefetchAutotuner, Enabled) {
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, EnabledSteady) {
|
||||
PrefetchAutotuner t(model::kAutotune);
|
||||
PrefetchAutotuner t(model::kAutotune, 0);
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
@ -80,6 +80,29 @@ TEST(PrefetchAutotuner, EnabledSteady) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, StartWithMin) {
|
||||
PrefetchAutotuner t(model::kAutotune, 2);
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(2); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to increase.
|
||||
EXPECT_EQ(4, t.buffer_limit());
|
||||
t.RecordConsumption(4); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(4, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to increase.
|
||||
EXPECT_EQ(8, t.buffer_limit());
|
||||
|
||||
// Never reach zero again.
|
||||
std::vector<size_t> consumption_values = {3, 5, 7, 1, 4, 6, 8, 3, 5, 1, 2, 4};
|
||||
for (int i = 0; i < consumption_values.size(); ++i) {
|
||||
t.RecordConsumption(consumption_values[i]);
|
||||
EXPECT_EQ(8, t.buffer_limit())
|
||||
<< "Failed at index " << i << " with value: " << consumption_values[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -46,6 +46,7 @@ namespace data {
|
||||
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
|
||||
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
|
||||
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
|
||||
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -62,12 +63,13 @@ constexpr char kErrorMessageSuffix[] = ".error_message";
|
||||
class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
|
||||
int64 slack_period, bool legacy_autotune)
|
||||
int64 slack_period, bool legacy_autotune, int64 buffer_size_min)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
buffer_size_(buffer_size),
|
||||
slack_period_(slack_period),
|
||||
legacy_autotune_(legacy_autotune) {
|
||||
legacy_autotune_(legacy_autotune),
|
||||
buffer_size_min_(buffer_size_min) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -114,10 +116,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
b->BuildAttrValue(slack_period_, &slack_period_attr);
|
||||
AttrValue legacy_autotune_attr;
|
||||
b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
|
||||
AttrValue buffer_size_min_attr;
|
||||
b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node, buffer_size},
|
||||
{std::make_pair(kSlackPeriod, slack_period_attr),
|
||||
std::make_pair(kLegacyAutotune, legacy_autotune_attr)},
|
||||
std::make_pair(kLegacyAutotune, legacy_autotune_attr),
|
||||
std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -129,8 +135,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
: DatasetIterator<Dataset>(params),
|
||||
mu_(std::make_shared<mutex>()),
|
||||
cond_var_(std::make_shared<condition_variable>()),
|
||||
auto_tuner_(params.dataset->buffer_size_),
|
||||
buffer_size_min_(params.dataset->buffer_size_min_),
|
||||
auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
|
||||
legacy_autotune_(params.dataset->legacy_autotune_),
|
||||
// If `legacy_autotune_`, initialize the `buffer_size_` value to be 0
|
||||
// to avoid the created node to be collected as tunable nodes in the
|
||||
// autotuning optimization.
|
||||
buffer_size_(std::make_shared<model::SharedState>(
|
||||
legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
|
||||
cond_var_)) {
|
||||
@ -145,7 +155,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(*mu_);
|
||||
if (buffer_size_->value == model::kAutotune) {
|
||||
buffer_size_->value = 0;
|
||||
buffer_size_->value = buffer_size_min_;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(), [this]() { CancelThreads(); },
|
||||
@ -218,7 +228,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeAsyncKnownRatioNode(
|
||||
std::move(args),
|
||||
/*ratio=*/1,
|
||||
{model::MakeParameter(kBufferSize, buffer_size_, /*min=*/0,
|
||||
{model::MakeParameter(kBufferSize, buffer_size_,
|
||||
/*min=*/buffer_size_min_,
|
||||
/*max=*/std::numeric_limits<int64>::max())});
|
||||
}
|
||||
|
||||
@ -536,6 +547,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
|
||||
const std::shared_ptr<condition_variable> cond_var_;
|
||||
const int64 buffer_size_min_;
|
||||
PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
|
||||
std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
|
||||
@ -561,6 +573,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
// Determines whether legacy autotuning should be used.
|
||||
const bool legacy_autotune_ = true;
|
||||
|
||||
// If autotune is enabled, determines the minimal value of `buffer_size`
|
||||
// parameter.
|
||||
const int64 buffer_size_min_ = 0;
|
||||
|
||||
TraceMeMetadata traceme_metadata_;
|
||||
};
|
||||
|
||||
@ -572,6 +588,9 @@ PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
|
||||
if (ctx->HasAttr(kLegacyAutotune)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
|
||||
}
|
||||
if (ctx->HasAttr(kBufferSizeMin)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
|
||||
}
|
||||
}
|
||||
|
||||
void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
@ -588,8 +607,8 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
metrics::RecordTFDataAutotune(kDatasetType);
|
||||
}
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_);
|
||||
*output = new Dataset(ctx, input, buffer_size, slack_period_,
|
||||
legacy_autotune_, buffer_size_min_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -31,6 +31,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kSlackPeriod = "slack_period";
|
||||
static constexpr const char* const kLegacyAutotune = "legacy_autotune";
|
||||
static constexpr const char* const kBufferSizeMin = "buffer_size_min";
|
||||
|
||||
explicit PrefetchDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
@ -42,6 +43,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
|
||||
class Dataset;
|
||||
int64 slack_period_ = 0;
|
||||
bool legacy_autotune_ = true;
|
||||
int64 buffer_size_min_ = 0;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -60,7 +60,8 @@ class PrefetchDatasetParams : public DatasetParams {
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kSlackPeriod, slack_period_);
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kLegacyAutotune,
|
||||
legacy_autotune_);
|
||||
attr_vector->emplace_back("buffer_size_min", buffer_size_min_);
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kBufferSizeMin,
|
||||
buffer_size_min_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -81,10 +81,12 @@ class AutotuneBufferSizesTest(test_base.DatasetTestBase,
|
||||
]))
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.prefetch(buffer_size=3)
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.interleave(
|
||||
lambda x: dataset_ops.Dataset.from_tensors(x + 1),
|
||||
num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
|
Loading…
Reference in New Issue
Block a user