[tf.data] Modify the optimization inject_prefetch
into antotune_buffer_sizes
, which will inject autotuned PrefetchDatasets after non-prefetched asynchronous datasets. The optimization will also rewrite those existing non-autotuned PrefetchDatasets into autotuned with fixed start value and minimal value for the tunable parameter buffer_size
.
PiperOrigin-RevId: 333155780 Change-Id: I5bdd061100ed7bd564f5ad6f7d9cc8c9a723a9ed
This commit is contained in:
parent
ff8cec00c2
commit
1ae23a67ce
@ -15,11 +15,11 @@ cc_library(
|
||||
name = "data",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":autotune_buffer_sizes",
|
||||
":disable_intra_op_parallelism",
|
||||
":filter_fusion",
|
||||
":filter_with_random_uniform_fusion",
|
||||
":hoist_random_uniform",
|
||||
":inject_prefetch",
|
||||
":latency_all_edges",
|
||||
":make_sloppy",
|
||||
":map_and_batch_fusion",
|
||||
@ -59,6 +59,41 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "autotune_buffer_sizes",
|
||||
srcs = ["autotune_buffer_sizes.cc"],
|
||||
hdrs = ["autotune_buffer_sizes.h"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
":optimizer_base",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + tf_protos_all(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "autotune_buffer_sizes_test",
|
||||
srcs = ["autotune_buffer_sizes_test.cc"],
|
||||
deps = [
|
||||
":autotune_buffer_sizes",
|
||||
":graph_test_utils",
|
||||
":graph_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "disable_intra_op_parallelism",
|
||||
srcs = ["disable_intra_op_parallelism.cc"],
|
||||
@ -329,41 +364,6 @@ tf_cc_test(
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inject_prefetch",
|
||||
srcs = ["inject_prefetch.cc"],
|
||||
hdrs = ["inject_prefetch.h"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
":optimizer_base",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + tf_protos_all(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "inject_prefetch_test",
|
||||
srcs = ["inject_prefetch_test.cc"],
|
||||
deps = [
|
||||
":graph_test_utils",
|
||||
":graph_utils",
|
||||
":inject_prefetch",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "latency_all_edges",
|
||||
srcs = ["latency_all_edges.cc"],
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/inject_prefetch.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h"
|
||||
|
||||
#include "tensorflow/core/framework/model.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -31,6 +31,7 @@ namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kLegacyAutotune[] = "legacy_autotune";
|
||||
constexpr char kBufferSizeMin[] = "buffer_size_min";
|
||||
constexpr char kPrefetchDataset[] = "PrefetchDataset";
|
||||
|
||||
constexpr std::array<const char*, 7> kAsyncDatasetOps = {
|
||||
@ -42,15 +43,49 @@ constexpr std::array<const char*, 7> kAsyncDatasetOps = {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
OptimizationStats* stats) {
|
||||
Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output,
|
||||
OptimizationStats* stats) {
|
||||
*output = item.graph;
|
||||
MutableGraphView graph(output);
|
||||
|
||||
absl::flat_hash_set<string> already_prefetched;
|
||||
// 1) Collect about all existing `PrefetchDataset` nodes, replacing
|
||||
// `prefetch(N)` with `prefetch(AUTOTUNE, buffer_size_min=N)` for all N !=-1.
|
||||
for (NodeDef& node : *(output->mutable_node())) {
|
||||
if (node.op() == kPrefetchDataset) {
|
||||
NodeDef* buffer_size_node = graph.GetNode(node.input(1));
|
||||
// We only consider to rewrite if `buffer_size` is constant.
|
||||
if (buffer_size_node->op() == "Const") {
|
||||
int64 initial_buffer_size =
|
||||
buffer_size_node->attr().at("value").tensor().int64_val(0);
|
||||
if (initial_buffer_size != data::model::kAutotune) {
|
||||
buffer_size_node->mutable_attr()
|
||||
->at("value")
|
||||
.mutable_tensor()
|
||||
->set_int64_val(0, data::model::kAutotune);
|
||||
node.mutable_attr()->at(kBufferSizeMin).set_i(initial_buffer_size);
|
||||
}
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"The autotune_buffer_sizes rewrite does not currently support "
|
||||
"non-constant buffer_size input.");
|
||||
}
|
||||
NodeDef* prefetched_node = graph_utils::GetInputNode(node, graph);
|
||||
if (prefetched_node) {
|
||||
already_prefetched.insert(prefetched_node->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> async_datasets;
|
||||
// 2) Insert `prefetch(AUTOTUNE)` after all asynchronous transformations that
|
||||
// are not followed by a `prefetch` yet.
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
if (already_prefetched.find(node.name()) != already_prefetched.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& async_dataset_op : kAsyncDatasetOps) {
|
||||
if (node.op() == async_dataset_op) {
|
||||
async_datasets.push_back(&node);
|
||||
@ -75,7 +110,6 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster,
|
||||
*prefetch_node.mutable_input()->Add() = async_dataset_node->name();
|
||||
// `buffer_size` input
|
||||
*prefetch_node.mutable_input()->Add() = autotune_value->name();
|
||||
|
||||
for (const auto& attr_name : {"output_types", "output_shapes"}) {
|
||||
graph_utils::CopyAttribute(attr_name, *async_dataset_node,
|
||||
&prefetch_node);
|
||||
@ -87,6 +121,8 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster,
|
||||
}
|
||||
|
||||
for (NodeDef& node : *output->mutable_node()) {
|
||||
// 3) Switch from using legacy algorithm to using performance model
|
||||
// based algorithm for autotuning of all `prefetch` nodes.
|
||||
if (node.op() == kPrefetchDataset) {
|
||||
(*node.mutable_attr())[kLegacyAutotune].set_b(false);
|
||||
stats->num_changes++;
|
||||
@ -96,12 +132,13 @@ Status InjectPrefetch::OptimizeAndCollectStats(Cluster* cluster,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void InjectPrefetch::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) {
|
||||
void AutotuneBufferSizes::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output,
|
||||
double result) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(InjectPrefetch, "inject_prefetch");
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(AutotuneBufferSizes, "autotune_buffer_sizes");
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -13,24 +13,32 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// This optimization adds `Prefetch(AUTOTUNE)` after all asynchronous tf.data
|
||||
// transformations. This reduces the problem of tuning buffer sizes of these
|
||||
// asynchronous transformations to tuning buffer sizes of the prefetch
|
||||
// transformation.
|
||||
class InjectPrefetch : public TFDataOptimizerBase {
|
||||
// This optimization does the following:
|
||||
//
|
||||
// 1. Adds `prefetch(AUTOTUNE)` after all asynchronous tf.data transformations
|
||||
// (e.g. parallel map, parallel interleave, and map + batch) if they are not
|
||||
// followed by a `prefetch` yet.
|
||||
//
|
||||
// 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace
|
||||
// the transformation with autotunable version of `prefetch` which uses N as
|
||||
// the minimum size of the buffer.
|
||||
//
|
||||
// 3. Switches from using legacy autotuning for `prefetch` to using an algorithm
|
||||
// based on the performance model.
|
||||
class AutotuneBufferSizes : public TFDataOptimizerBase {
|
||||
public:
|
||||
InjectPrefetch() = default;
|
||||
~InjectPrefetch() override = default;
|
||||
AutotuneBufferSizes() = default;
|
||||
~AutotuneBufferSizes() override = default;
|
||||
|
||||
string name() const override { return "inject_prefetch"; };
|
||||
string name() const override { return "autotune_buffer_sizes"; };
|
||||
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
|
||||
@ -50,4 +58,4 @@ class InjectPrefetch : public TFDataOptimizerBase {
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_
|
@ -0,0 +1,207 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class SimpleInject : public ::testing::TestWithParam<string> {};
|
||||
|
||||
TEST_P(SimpleInject, AutotuneBufferSizesTest) {
|
||||
const string async_dataset = GetParam();
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
if (async_dataset == "map") {
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT32}}),
|
||||
graph_tests_utils::MakeParallelMapNode(
|
||||
"map", "range", "num_parallel_calls", "XTimesTwo",
|
||||
/*sloppy=*/false)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
} else if (async_dataset == "interleave") {
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT32}}),
|
||||
graph_tests_utils::MakeParallelInterleaveV2Node(
|
||||
"interleave", "range", "cycle_length", "block_length",
|
||||
"num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
} else if (async_dataset == "map_and_batch") {
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("batch_size", "Const", {}, {{"value", 32}, {"dtype", DT_INT64}}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT64}}),
|
||||
NDef("drop_remainder", "Const", {},
|
||||
{{"value", false}, {"dtype", DT_BOOL}}),
|
||||
graph_tests_utils::MakeMapAndBatchNode(
|
||||
"map_and_batch", "range", "batch_size", "num_parallel_calls",
|
||||
"drop_remainder", "XTimesTwo")},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
}
|
||||
|
||||
AutotuneBufferSizes optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
|
||||
int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
|
||||
const NodeDef prefetch_node = output.node(index);
|
||||
EXPECT_FALSE(prefetch_node.attr().at("legacy_autotune").b());
|
||||
EXPECT_EQ(prefetch_node.input_size(), 2);
|
||||
NodeDef async_node = output.node(
|
||||
graph_utils::FindGraphNodeWithName(prefetch_node.input(0), output));
|
||||
EXPECT_EQ(async_node.name(), async_dataset);
|
||||
NodeDef buffer_size_val = output.node(
|
||||
graph_utils::FindGraphNodeWithName(prefetch_node.input(1), output));
|
||||
EXPECT_EQ(buffer_size_val.attr().at("value").tensor().int64_val(0), -1);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, SimpleInject,
|
||||
::testing::Values("map", "interleave",
|
||||
"map_and_batch"));
|
||||
|
||||
class MultipleNodes : public ::testing::TestWithParam<std::tuple<bool, int64>> {
|
||||
};
|
||||
|
||||
TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
|
||||
const bool legacy_autotune = std::get<0>(GetParam());
|
||||
const int64 initial_buffer_size = std::get<1>(GetParam());
|
||||
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
NodeDef *start_val = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_val = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_val->name();
|
||||
range_inputs[1] = stop_val->name();
|
||||
range_inputs[2] = step_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> range_attrs;
|
||||
NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset",
|
||||
range_inputs, range_attrs, &graph);
|
||||
|
||||
NodeDef *parallelism_val = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
std::vector<string> map_inputs1(2);
|
||||
map_inputs1[0] = range_node->name();
|
||||
map_inputs1[1] = parallelism_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> map_attrs(4);
|
||||
AttrValue attr_val;
|
||||
SetAttrValue("value", &attr_val);
|
||||
map_attrs[0] = std::make_pair("f", attr_val);
|
||||
map_attrs[1] = std::make_pair("Targuments", attr_val);
|
||||
map_attrs[2] = std::make_pair("output_types", attr_val);
|
||||
map_attrs[3] = std::make_pair("output_shapes", attr_val);
|
||||
NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2",
|
||||
map_inputs1, map_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_val =
|
||||
graph_utils::AddScalarConstNode<int64>(initial_buffer_size, &graph);
|
||||
std::vector<string> prefetch_inputs(2);
|
||||
prefetch_inputs[0] = map_node1->name();
|
||||
prefetch_inputs[1] = buffer_size_val->name();
|
||||
std::vector<std::pair<string, AttrValue>> prefetch_attrs(2);
|
||||
AttrValue legacy_autotune_attr;
|
||||
SetAttrValue(legacy_autotune, &legacy_autotune_attr);
|
||||
prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr);
|
||||
AttrValue buffer_size_min_attr;
|
||||
SetAttrValue(0, &buffer_size_min_attr);
|
||||
prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr);
|
||||
NodeDef *prefetch_node = graph_utils::AddNode(
|
||||
"prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph);
|
||||
|
||||
std::vector<string> map_inputs2(2);
|
||||
map_inputs2[0] = prefetch_node->name();
|
||||
map_inputs2[1] = parallelism_val->name();
|
||||
graph_utils::AddNode("map2", "ParallelMapDatasetV2", map_inputs2, map_attrs,
|
||||
&graph);
|
||||
|
||||
EXPECT_EQ(item.graph.node_size(), 9);
|
||||
|
||||
AutotuneBufferSizes optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_EQ(output.node_size(), 11);
|
||||
|
||||
std::vector<int> prefetch_indices =
|
||||
graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output);
|
||||
EXPECT_EQ(prefetch_indices.size(), 2);
|
||||
NodeDef new_prefetch_node1 = output.node(prefetch_indices[0]);
|
||||
NodeDef new_prefetch_node2 = output.node(prefetch_indices[1]);
|
||||
|
||||
EXPECT_EQ(new_prefetch_node1.input_size(), 2);
|
||||
EXPECT_FALSE(new_prefetch_node1.attr().at("legacy_autotune").b());
|
||||
EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
|
||||
(initial_buffer_size == -1 ? 0 : initial_buffer_size));
|
||||
NodeDef new_map_node1 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output));
|
||||
EXPECT_EQ(new_map_node1.name(), "map1");
|
||||
NodeDef new_buffer_size_val1 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output));
|
||||
EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1);
|
||||
|
||||
EXPECT_EQ(new_prefetch_node2.input_size(), 2);
|
||||
EXPECT_FALSE(new_prefetch_node2.attr().at("legacy_autotune").b());
|
||||
NodeDef new_map_node2 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output));
|
||||
EXPECT_EQ(new_map_node2.name(), "map2");
|
||||
NodeDef new_buffer_size_val2 = output.node(
|
||||
graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output));
|
||||
EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes,
|
||||
::testing::Combine(::testing::Values(true, false),
|
||||
::testing::Values(-1, 3)));
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -1,116 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/inject_prefetch.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
TEST(MakeStateless, ParallelMap) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT32}}),
|
||||
graph_tests_utils::MakeParallelMapNode("map", "range",
|
||||
"num_parallel_calls", "XTimesTwo",
|
||||
/*sloppy=*/false)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
|
||||
InjectPrefetch optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
|
||||
int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
|
||||
EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
|
||||
}
|
||||
|
||||
TEST(MakeStateless, ParallelInterleave) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("cycle_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("block_length", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT32}}),
|
||||
graph_tests_utils::MakeParallelInterleaveV2Node(
|
||||
"interleave", "range", "cycle_length", "block_length",
|
||||
"num_parallel_calls", "XTimesTwo", /*sloppy=*/false)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
|
||||
InjectPrefetch optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
|
||||
int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
|
||||
EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
|
||||
}
|
||||
|
||||
TEST(MakeStateless, MapAndBatch) {
|
||||
using test::function::NDef;
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
|
||||
NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
|
||||
NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
|
||||
NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
|
||||
NDef("batch_size", "Const", {}, {{"value", 32}, {"dtype", DT_INT64}}),
|
||||
NDef("num_parallel_calls", "Const", {},
|
||||
{{"value", 1}, {"dtype", DT_INT64}}),
|
||||
NDef("drop_remainder", "Const", {},
|
||||
{{"value", false}, {"dtype", DT_BOOL}}),
|
||||
graph_tests_utils::MakeMapAndBatchNode(
|
||||
"map_and_batch", "range", "batch_size", "num_parallel_calls",
|
||||
"drop_remainder", "XTimesTwo")},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
});
|
||||
|
||||
InjectPrefetch optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
|
||||
int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
|
||||
EXPECT_FALSE(output.node(index).attr().at("legacy_autotune").b());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -52,7 +52,7 @@ constexpr std::array<const char*, 17> kTFDataOptimizations = {
|
||||
"parallel_batch",
|
||||
"reorder_data_discarding_ops",
|
||||
"slack",
|
||||
"inject_prefetch"};
|
||||
"autotune_buffer_sizes"};
|
||||
|
||||
// Parses a list of string optimizer configurations into a map from
|
||||
// optimizer name -> rewriter config for that optimizer.
|
||||
|
@ -20,11 +20,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
|
||||
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size,
|
||||
int64 buffer_size_min)
|
||||
: buffer_limit_(initial_buffer_size) {
|
||||
if (initial_buffer_size == model::kAutotune) {
|
||||
mode_ = Mode::kUpswing;
|
||||
buffer_limit_ = 1;
|
||||
buffer_limit_ = std::max(1LL, buffer_size_min);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ namespace data {
|
||||
// PrefetchAutotuner is NOT thread safe.
|
||||
class PrefetchAutotuner {
|
||||
public:
|
||||
explicit PrefetchAutotuner(int64 initial_buffer_size);
|
||||
explicit PrefetchAutotuner(int64 initial_buffer_size, int64 buffer_size_min);
|
||||
|
||||
int64 buffer_limit() const { return buffer_limit_; }
|
||||
|
||||
|
@ -23,7 +23,7 @@ namespace data {
|
||||
namespace {
|
||||
|
||||
TEST(PrefetchAutotuner, Disabled) {
|
||||
PrefetchAutotuner t(2);
|
||||
PrefetchAutotuner t(2, 0);
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0);
|
||||
t.RecordConsumption(2);
|
||||
@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) {
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, Enabled) {
|
||||
PrefetchAutotuner t(model::kAutotune);
|
||||
PrefetchAutotuner t(model::kAutotune, 0);
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same.
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
@ -58,9 +58,9 @@ TEST(PrefetchAutotuner, Enabled) {
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, EnabledSteady) {
|
||||
PrefetchAutotuner t(model::kAutotune);
|
||||
PrefetchAutotuner t(model::kAutotune, 0);
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to increase.
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
t.RecordConsumption(1);
|
||||
EXPECT_EQ(1, t.buffer_limit());
|
||||
@ -80,6 +80,29 @@ TEST(PrefetchAutotuner, EnabledSteady) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PrefetchAutotuner, StartWithMin) {
|
||||
PrefetchAutotuner t(model::kAutotune, 2);
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(2);
|
||||
EXPECT_EQ(2, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to increase.
|
||||
EXPECT_EQ(4, t.buffer_limit());
|
||||
t.RecordConsumption(4); // Expect buffer limit to stay the same!
|
||||
EXPECT_EQ(4, t.buffer_limit());
|
||||
t.RecordConsumption(0); // Expect buffer limit to increase.
|
||||
EXPECT_EQ(8, t.buffer_limit());
|
||||
|
||||
// Never reach zero again.
|
||||
std::vector<size_t> consumption_values = {3, 5, 7, 1, 4, 6, 8, 3, 5, 1, 2, 4};
|
||||
for (int i = 0; i < consumption_values.size(); ++i) {
|
||||
t.RecordConsumption(consumption_values[i]);
|
||||
EXPECT_EQ(8, t.buffer_limit())
|
||||
<< "Failed at index " << i << " with value: " << consumption_values[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -57,12 +57,13 @@ constexpr char kErrorMessageSuffix[] = ".error_message";
|
||||
class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
|
||||
int64 slack_period, bool legacy_autotune)
|
||||
int64 slack_period, bool legacy_autotune, int64 buffer_size_min)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
buffer_size_(buffer_size),
|
||||
slack_period_(slack_period),
|
||||
legacy_autotune_(legacy_autotune) {
|
||||
legacy_autotune_(legacy_autotune),
|
||||
buffer_size_min_(buffer_size_min) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
@ -109,10 +110,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
b->BuildAttrValue(slack_period_, &slack_period_attr);
|
||||
AttrValue legacy_autotune_attr;
|
||||
b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
|
||||
AttrValue buffer_size_min_attr;
|
||||
b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node, buffer_size},
|
||||
{std::make_pair(kSlackPeriod, slack_period_attr),
|
||||
std::make_pair(kLegacyAutotune, legacy_autotune_attr)},
|
||||
std::make_pair(kLegacyAutotune, legacy_autotune_attr),
|
||||
std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -124,11 +129,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
: DatasetIterator<Dataset>(params),
|
||||
mu_(std::make_shared<mutex>()),
|
||||
cond_var_(std::make_shared<condition_variable>()),
|
||||
auto_tuner_(params.dataset->buffer_size_),
|
||||
buffer_size_min_(params.dataset->buffer_size_min_),
|
||||
auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
|
||||
legacy_autotune_(params.dataset->legacy_autotune_),
|
||||
buffer_size_(std::make_shared<model::SharedState>(
|
||||
legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
|
||||
cond_var_)) {
|
||||
params.dataset->buffer_size_, mu_, cond_var_)) {
|
||||
slack_us_ = 0;
|
||||
}
|
||||
|
||||
@ -140,7 +145,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
mutex_lock l(*mu_);
|
||||
if (buffer_size_->value == model::kAutotune) {
|
||||
buffer_size_->value = 0;
|
||||
buffer_size_->value = buffer_size_min_;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(), [this]() { CancelThreads(); },
|
||||
@ -213,7 +218,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeAsyncKnownRatioNode(
|
||||
std::move(args),
|
||||
/*ratio=*/1,
|
||||
{model::MakeParameter(kBufferSize, buffer_size_, /*min=*/0,
|
||||
{model::MakeParameter(kBufferSize, buffer_size_,
|
||||
/*min=*/buffer_size_min_,
|
||||
/*max=*/std::numeric_limits<int64>::max())});
|
||||
}
|
||||
|
||||
@ -517,6 +523,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
|
||||
const std::shared_ptr<condition_variable> cond_var_;
|
||||
const int64 buffer_size_min_;
|
||||
PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
|
||||
std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
|
||||
@ -542,6 +549,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
// Determines whether legacy autotuning should be used.
|
||||
const bool legacy_autotune_ = true;
|
||||
|
||||
// If autotune is enabled, determines the minimal value of `buffer_size`
|
||||
// parameter.
|
||||
const int64 buffer_size_min_ = 0;
|
||||
|
||||
TraceMeMetadata traceme_metadata_;
|
||||
};
|
||||
|
||||
@ -553,6 +564,9 @@ PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
|
||||
if (ctx->HasAttr(kLegacyAutotune)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
|
||||
}
|
||||
if (ctx->HasAttr(kBufferSizeMin)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
|
||||
}
|
||||
}
|
||||
|
||||
void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
@ -569,8 +583,8 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
metrics::RecordTFDataAutotune(kDatasetType);
|
||||
}
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_);
|
||||
*output = new Dataset(ctx, input, buffer_size, slack_period_,
|
||||
legacy_autotune_, buffer_size_min_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -31,6 +31,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kSlackPeriod = "slack_period";
|
||||
static constexpr const char* const kLegacyAutotune = "legacy_autotune";
|
||||
static constexpr const char* const kBufferSizeMin = "buffer_size_min";
|
||||
|
||||
explicit PrefetchDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
@ -42,6 +43,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
|
||||
class Dataset;
|
||||
int64 slack_period_ = 0;
|
||||
bool legacy_autotune_ = true;
|
||||
int64 buffer_size_min_ = 0;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -28,13 +28,14 @@ class PrefetchDatasetParams : public DatasetParams {
|
||||
PrefetchDatasetParams(T input_dataset_params, int64 buffer_size,
|
||||
DataTypeVector output_dtypes,
|
||||
std::vector<PartialTensorShape> output_shapes,
|
||||
int slack_period, bool legacy_autotune,
|
||||
string node_name)
|
||||
int64 slack_period, bool legacy_autotune,
|
||||
int64 buffer_size_min, string node_name)
|
||||
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
|
||||
std::move(node_name)),
|
||||
buffer_size_(buffer_size),
|
||||
slack_period_(slack_period),
|
||||
legacy_autotune_(legacy_autotune) {
|
||||
legacy_autotune_(legacy_autotune),
|
||||
buffer_size_min_(buffer_size_min) {
|
||||
input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
|
||||
iterator_prefix_ =
|
||||
name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
|
||||
@ -59,6 +60,8 @@ class PrefetchDatasetParams : public DatasetParams {
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kSlackPeriod, slack_period_);
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kLegacyAutotune,
|
||||
legacy_autotune_);
|
||||
attr_vector->emplace_back(PrefetchDatasetOp::kBufferSizeMin,
|
||||
buffer_size_min_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -68,8 +71,9 @@ class PrefetchDatasetParams : public DatasetParams {
|
||||
|
||||
private:
|
||||
int64 buffer_size_;
|
||||
int slack_period_;
|
||||
int64 slack_period_;
|
||||
bool legacy_autotune_;
|
||||
int64 buffer_size_min_;
|
||||
};
|
||||
|
||||
// Test case 1: positive buffer size.
|
||||
@ -85,6 +89,7 @@ PrefetchDatasetParams PrefetchDatasetParams1() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/0,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -101,6 +106,7 @@ PrefetchDatasetParams PrefetchDatasetParams2() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/0,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -117,6 +123,7 @@ PrefetchDatasetParams PrefetchDatasetParams3() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/0,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -133,6 +140,7 @@ PrefetchDatasetParams PrefetchDatasetParams4() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/5,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -149,6 +157,24 @@ PrefetchDatasetParams PrefetchDatasetParams5() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/5,
|
||||
/*legacy_autotune=*/false,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
// Test case 6: buffer_size_min > 0.
|
||||
PrefetchDatasetParams PrefetchDatasetParams6() {
|
||||
auto tensor_slice_dataset_params = TensorSliceDatasetParams(
|
||||
/*components=*/{CreateTensor<int64>(TensorShape{10, 1},
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})},
|
||||
/*node_name=*/"tensor_slice");
|
||||
return PrefetchDatasetParams(
|
||||
/*input_dataset_params=*/tensor_slice_dataset_params,
|
||||
/*buffer_size=*/-1,
|
||||
/*output_dtypes=*/{DT_INT64},
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/0,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/3,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -164,6 +190,7 @@ PrefetchDatasetParams InvalidBufferSizePrefetchDatasetParams() {
|
||||
/*output_shapes=*/{PartialTensorShape({1})},
|
||||
/*slack_period=*/0,
|
||||
/*legacy_autotune=*/true,
|
||||
/*buffer_size_min=*/0,
|
||||
/*node_name=*/kNodeName);
|
||||
}
|
||||
|
||||
@ -190,6 +217,11 @@ std::vector<GetNextTestCase<PrefetchDatasetParams>> GetNextTestCases() {
|
||||
{/*dataset_params=*/
|
||||
PrefetchDatasetParams5(),
|
||||
/*expected_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape{1}, {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}})},
|
||||
{/*dataset_params=*/
|
||||
PrefetchDatasetParams6(),
|
||||
/*expected_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape{1},
|
||||
{{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}})}};
|
||||
|
@ -184,6 +184,7 @@ REGISTER_OP("PrefetchDataset")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("slack_period: int = 0")
|
||||
.Attr("legacy_autotune: bool = true")
|
||||
.Attr("buffer_size_min: int = 0")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size should be a scalar.
|
||||
|
@ -8,6 +8,24 @@ package(
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
tf_py_test(
|
||||
name = "autotune_buffer_sizes_test",
|
||||
size = "small",
|
||||
srcs = ["autotune_buffer_sizes_test.py"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "choose_fastest_dataset_test",
|
||||
size = "small",
|
||||
@ -123,24 +141,6 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "inject_prefetch_test",
|
||||
size = "small",
|
||||
srcs = ["inject_prefetch_test.py"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "latency_all_edges_test",
|
||||
size = "small",
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the `AutotuneBuffers` rewrite."""
|
||||
"""Tests for the `AutotuneBufferSizes` rewrite."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -26,7 +26,8 @@ from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
class AutotuneBufferSizesTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def _enable_autotune_buffers(self, dataset):
|
||||
options = dataset_ops.Options()
|
||||
@ -74,11 +75,18 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.range(100)
|
||||
dataset = dataset.apply(
|
||||
testing.assert_next([
|
||||
"ParallelMap", "Prefetch", "ParallelInterleave", "Prefetch",
|
||||
"MapAndBatch", "Prefetch", "FiniteTake"
|
||||
"ParallelMap", "Prefetch", "ParallelMap", "Prefetch", "ParallelMap",
|
||||
"Prefetch", "ParallelInterleave", "Prefetch", "MapAndBatch",
|
||||
"Prefetch", "FiniteTake"
|
||||
]))
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.prefetch(buffer_size=3)
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.map(
|
||||
lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE)
|
||||
dataset = dataset.interleave(
|
||||
lambda x: dataset_ops.Dataset.from_tensors(x + 1),
|
||||
num_parallel_calls=dataset_ops.AUTOTUNE)
|
||||
@ -87,7 +95,7 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.batch(1)
|
||||
dataset = dataset.take(50)
|
||||
dataset = self._enable_autotune_buffers(dataset)
|
||||
self.assertDatasetProduces(dataset, [[i] for i in range(3, 53)])
|
||||
self.assertDatasetProduces(dataset, [[i] for i in range(5, 55)])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNoRegularMap(self):
|
@ -355,7 +355,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"parallel_batch",
|
||||
"shuffle_and_repeat_fusion",
|
||||
"map_vectorization",
|
||||
"inject_prefetch",
|
||||
"autotune_buffer_sizes",
|
||||
"make_sloppy",
|
||||
"latency_all_edges",
|
||||
"slack",
|
||||
@ -403,7 +403,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"parallel_batch",
|
||||
"shuffle_and_repeat_fusion",
|
||||
"map_vectorization",
|
||||
"inject_prefetch",
|
||||
"autotune_buffer_sizes",
|
||||
"make_sloppy",
|
||||
"latency_all_edges",
|
||||
"slack",
|
||||
@ -435,7 +435,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
options.experimental_optimization.autotune_cpu_budget = 1000
|
||||
options.experimental_optimization.autotune_ram_budget = 999999999
|
||||
options.experimental_optimization.autotune_buffers = True
|
||||
self.assertIn("inject_prefetch", options._graph_rewrites().enabled)
|
||||
self.assertIn("autotune_buffer_sizes", options._graph_rewrites().enabled)
|
||||
autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()
|
||||
self.assertTrue(autotune)
|
||||
self.assertEqual(algorithm,
|
||||
|
@ -299,9 +299,9 @@ class OptimizationOptions(options.OptionsBase):
|
||||
# prefetch transformations will be autotuned, though this is practically
|
||||
# equivalent to tuning the buffer sizes of the other asynchronous
|
||||
# transformations.
|
||||
result.enabled.append("inject_prefetch")
|
||||
result.enabled.append("autotune_buffer_sizes")
|
||||
if self.autotune is False: # pylint: disable=g-bool-id-comparison
|
||||
result.disabled.append("inject_prefetch")
|
||||
result.disabled.append("autotune_buffer_sizes")
|
||||
|
||||
return result
|
||||
|
||||
|
@ -2870,7 +2870,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "PrefetchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'buffer_size_min\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Prelinearize"
|
||||
|
@ -2870,7 +2870,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "PrefetchDataset"
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input_dataset\', \'buffer_size\', \'output_types\', \'output_shapes\', \'slack_period\', \'legacy_autotune\', \'buffer_size_min\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Prelinearize"
|
||||
|
Loading…
Reference in New Issue
Block a user