[tf.data] Add the disable_prefetch_legacy_autotune optimization and decouple the legacy-autotune switching from autotune_buffer_sizes.

PiperOrigin-RevId: 334500451
Change-Id: I4945477ff53a8b3d6eb5dc6511675b39ff0e0b4c
parent 1f63a2e610
commit 7e25c5c8b2

Changed files:
  tensorflow/core/grappler/optimizers/data/BUILD
  tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc
  tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h
  tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc
  tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc
  tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h
  tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune_test.cc
  tensorflow/core/grappler/optimizers/data/meta_optimizer.cc
  tensorflow/python/data/experimental (dataset optimization options and tests)
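For context, the user-facing switch for these rewrites is the tf.data options surface touched in the Python part of this change: enabling autotuned buffer sizes now also enables the new disable_prefetch_legacy_autotune rewrite, and turning autotune off disables both. A minimal usage sketch (TF 2.x API of this era; the exact set of rewrites reported by the private _graph_rewrites() helper may vary by version):

  import tensorflow as tf

  ds = tf.data.Dataset.range(10).prefetch(1)

  options = tf.data.Options()
  # Per the optimization_options.py change below, this enables both the
  # autotune_buffer_sizes and disable_prefetch_legacy_autotune graph rewrites.
  options.experimental_optimization.autotune_buffers = True
  ds = ds.with_options(options)

  # Setting autotune to False disables both rewrites instead.
  options_off = tf.data.Options()
  options_off.experimental_optimization.autotune = False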
tensorflow/core/grappler/optimizers/data/BUILD

@@ -17,6 +17,7 @@ cc_library(
     deps = [
         ":autotune_buffer_sizes",
         ":disable_intra_op_parallelism",
+        ":disable_prefetch_legacy_autotune",
         ":enable_gradient_descent",
         ":filter_fusion",
         ":filter_with_random_uniform_fusion",

@@ -131,6 +132,41 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "disable_prefetch_legacy_autotune",
+    srcs = ["disable_prefetch_legacy_autotune.cc"],
+    hdrs = ["disable_prefetch_legacy_autotune.h"],
+    deps = [
+        ":graph_utils",
+        ":optimizer_base",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "//tensorflow/core:lib_internal",
+    ] + tf_protos_all(),
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "disable_prefetch_legacy_autotune_test",
+    srcs = ["disable_prefetch_legacy_autotune_test.cc"],
+    deps = [
+        ":disable_prefetch_legacy_autotune",
+        ":graph_test_utils",
+        ":graph_utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
+
 cc_library(
     name = "enable_gradient_descent",
     srcs = ["enable_gradient_descent.cc"],
tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc

@@ -74,6 +74,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
             {buffer_size_node->name(), 0},
             {autotune_value->name(), 0}));
+        node.mutable_attr()->at(kBufferSizeMin).set_i(initial_buffer_size);
         stats->num_changes++;
       }
     } else {
       return errors::FailedPrecondition(

@@ -97,6 +98,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
     for (const auto& async_dataset_op : kAsyncDatasetOps) {
       if (node.op() == async_dataset_op) {
         async_datasets.push_back(&node);
+        stats->num_changes++;
         break;
       }
     }

@@ -124,15 +126,6 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster,
         graph.UpdateFanouts(async_dataset_node->name(), added_node->name()));
   }

-  for (NodeDef& node : *output->mutable_node()) {
-    // 3) Switch from using legacy algorithm to using performance model
-    // based algorithm for autotuning of all `prefetch` nodes.
-    if (node.op() == kPrefetchDataset) {
-      (*node.mutable_attr())[kLegacyAutotune].set_b(false);
-      stats->num_changes++;
-    }
-  }
-
   return Status::OK();
 }
tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h

@@ -33,9 +33,6 @@ constexpr char kAutotune[] = "autotune";
 // 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace
 // the transformation with autotunable version of `prefetch` which uses N as
 // the minimum size of the buffer.
-//
-// 3. Switches from using legacy autotuning for `prefetch` to using an algorithm
-// based on the performance model.
 class AutotuneBufferSizes : public TFDataOptimizerBase {
  public:
   AutotuneBufferSizes() = default;
tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc

@@ -105,7 +105,8 @@ TEST_P(SimpleInject, AutotuneBufferSizesTest) {
   EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output));
   int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output);
   const NodeDef prefetch_node = output.node(index);
-  EXPECT_FALSE(prefetch_node.attr().at("legacy_autotune").b());
+  EXPECT_TRUE(prefetch_node.attr().find("legacy_autotune") ==
+              prefetch_node.attr().end());
   EXPECT_EQ(prefetch_node.input_size(), 2);
   NodeDef async_node = output.node(
       graph_utils::FindGraphNodeWithName(prefetch_node.input(0), output));

@@ -226,7 +227,8 @@ TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
       graph_utils::FindGraphNodeWithName(new_map_node3.input(0), output));
   EXPECT_EQ(new_prefetch_node2.op(), "PrefetchDataset");
   EXPECT_EQ(new_prefetch_node2.input_size(), 2);
-  EXPECT_FALSE(new_prefetch_node2.attr().at("legacy_autotune").b());
+  EXPECT_TRUE(new_prefetch_node2.attr().find("legacy_autotune") ==
+              new_prefetch_node2.attr().end());
   EXPECT_TRUE(new_prefetch_node2.attr().find("buffer_size_min") ==
               new_prefetch_node2.attr().end());
   NodeDef new_buffer_size_val2 = output.node(

@@ -241,7 +243,8 @@ TEST_P(MultipleNodes, AutotuneBufferSizesTest) {
       graph_utils::FindGraphNodeWithName(new_map_node2.input(0), output));
   EXPECT_EQ(new_prefetch_node1.op(), "PrefetchDataset");
   EXPECT_EQ(new_prefetch_node1.input_size(), 2);
-  EXPECT_FALSE(new_prefetch_node1.attr().at("legacy_autotune").b());
+  EXPECT_EQ(new_prefetch_node1.attr().at("legacy_autotune").b(),
+            legacy_autotune);
   EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(),
             (initial_buffer_size == -1 ? 0 : initial_buffer_size));
   NodeDef new_buffer_size_val1 = output.node(
tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc (new file)

@@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kLegacyAutotune[] = "legacy_autotune";
constexpr char kPrefetchDataset[] = "PrefetchDataset";

}  // namespace

Status DisablePrefetchLegacyAutotune::OptimizeAndCollectStats(
    Cluster* cluster, const GrapplerItem& item, GraphDef* output,
    OptimizationStats* stats) {
  *output = item.graph;
  if (!autotune_) {
    VLOG(1) << "The optimization disable_prefetch_legacy_autotune is not "
               "applied if autotune is off.";
    return Status::OK();
  }

  MutableGraphView graph(output);

  for (NodeDef& node : *output->mutable_node()) {
    if (node.op() == kPrefetchDataset) {
      if (node.attr().find(kLegacyAutotune) == node.attr().end() ||
          node.attr().at(kLegacyAutotune).b()) {
        // If `legacy_autotune` does not exist as attr or it is true, set it to
        // false.
        (*node.mutable_attr())[kLegacyAutotune].set_b(false);
        stats->num_changes++;
      }
    }
  }

  return Status::OK();
}

void DisablePrefetchLegacyAutotune::Feedback(Cluster* cluster,
                                             const GrapplerItem& item,
                                             const GraphDef& optimize_output,
                                             double result) {
  // no-op
}

REGISTER_GRAPH_OPTIMIZER_AS(DisablePrefetchLegacyAutotune,
                            "disable_prefetch_legacy_autotune");

}  // namespace grappler
}  // namespace tensorflow
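The pass itself is small: it walks the GraphDef and flips `legacy_autotune` to false on every PrefetchDataset node that either omits the attr or has it set to true. A rough Python equivalent (not part of the commit, shown only to illustrate the rewrite; uses the public TensorFlow proto modules):

  from tensorflow.core.framework import graph_pb2

  def disable_prefetch_legacy_autotune(graph_def: graph_pb2.GraphDef) -> int:
    """Returns the number of PrefetchDataset nodes that were rewritten."""
    num_changes = 0
    for node in graph_def.node:
      if node.op != "PrefetchDataset":
        continue
      # Rewrite when the attr is missing (it defaults to true) or explicitly true.
      if "legacy_autotune" not in node.attr or node.attr["legacy_autotune"].b:
        node.attr["legacy_autotune"].b = False
        num_changes += 1
    return num_changes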
tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h (new file)

@@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h"

namespace tensorflow {
namespace grappler {

constexpr char kAutotune[] = "autotune";

// This optimization disables the legacy autotune option for PrefetchDataset.
class DisablePrefetchLegacyAutotune : public TFDataOptimizerBase {
 public:
  DisablePrefetchLegacyAutotune() = default;
  ~DisablePrefetchLegacyAutotune() override = default;

  string name() const override { return "disable_prefetch_legacy_autotune"; };

  bool UsesFunctionLibrary() const override { return false; }

  Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    if (!config) return Status::OK();

    const string& autotune = config->parameter_map().at(kAutotune).s();
    if (autotune == "true") {
      autotune_ = true;
    } else if (autotune == "false") {
      autotune_ = false;
    } else {
      return errors::InvalidArgument("Received an invalid value for parameter ",
                                     kAutotune, ": ", autotune);
    }
    return Status::OK();
  }

  Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* output,
                                 OptimizationStats* stats) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimize_output, double result) override;

 private:
  bool autotune_ = true;
};

}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_
tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune_test.cc (new file)

@@ -0,0 +1,91 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h"

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

using test::function::NDef;

Status OptimizeWithDisablePrefetchLegacyAutotune(const GrapplerItem &item,
                                                 GraphDef *output,
                                                 bool autotune) {
  DisablePrefetchLegacyAutotune optimizer;
  RewriterConfig_CustomGraphOptimizer config;
  if (autotune) {
    (*config.mutable_parameter_map())["autotune"].set_s("true");
  } else {
    (*config.mutable_parameter_map())["autotune"].set_s("false");
  }
  TF_RETURN_IF_ERROR(optimizer.Init(&config));
  return optimizer.Optimize(nullptr, item, output);
}

class RewriteTest : public ::testing::TestWithParam<bool> {};

TEST_P(RewriteTest, DisablePrefetchLegacyAutotune) {
  const bool autotune = GetParam();
  GrapplerItem item;

  item.graph = test::function::GDef({
      NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
      NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
      NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
      NDef("range", "RangeDataset", {"start", "stop", "step"},
           {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
            {"output_types", gtl::ArraySlice<DataType>{}}}),
      NDef("prefetch1", "PrefetchDataset", {"range"},
           {{"legacy_autotune", true}}),
      NDef("prefetch2", "PrefetchDataset", {"prefetch1"},
           {{"legacy_autotune", false}}),
      NDef("prefetch3", "PrefetchDataset", {"prefetch2"}, {}),
  });

  GraphDef output;
  TF_ASSERT_OK(
      OptimizeWithDisablePrefetchLegacyAutotune(item, &output, autotune));

  NodeDef prefetch_node1 =
      output.node(graph_utils::FindGraphNodeWithName("prefetch1", output));
  EXPECT_EQ(prefetch_node1.attr().at("legacy_autotune").b(), !autotune);
  NodeDef prefetch_node2 =
      output.node(graph_utils::FindGraphNodeWithName("prefetch2", output));
  EXPECT_FALSE(prefetch_node2.attr().at("legacy_autotune").b());
  NodeDef prefetch_node3 =
      output.node(graph_utils::FindGraphNodeWithName("prefetch3", output));
  if (autotune) {
    EXPECT_FALSE(prefetch_node3.attr().at("legacy_autotune").b());
  } else {
    EXPECT_TRUE(prefetch_node3.attr().find("legacy_autotune") ==
                prefetch_node3.attr().end());
  }
}

INSTANTIATE_TEST_SUITE_P(Test, RewriteTest, ::testing::Values(false, true));

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
tensorflow/core/grappler/optimizers/data/meta_optimizer.cc

@@ -35,7 +35,7 @@ using ConfigMap =
     std::map<string, tensorflow::RewriterConfig_CustomGraphOptimizer>;

 // tf.data optimizations, in the order we want to perform them.
-constexpr std::array<const char*, 18> kTFDataOptimizations = {
+constexpr std::array<const char*, 19> kTFDataOptimizations = {
     "noop_elimination",
     "disable_intra_op_parallelism",
     "shuffle_and_repeat_fusion",

@@ -53,6 +53,7 @@ constexpr std::array<const char*, 18> kTFDataOptimizations = {
     "reorder_data_discarding_ops",
     "slack",
     "autotune_buffer_sizes",
+    "disable_prefetch_legacy_autotune",
     "enable_gradient_descent"};

 // Parses a list of string optimizer configurations into a map from
tensorflow/python/data/experimental (OptimizeDatasetTest)

@@ -393,6 +393,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
        "make_sloppy",
        "latency_all_edges",
        "slack",
+       "disable_prefetch_legacy_autotune",
    ]
    expected_optimizations_disabled = []
    expected_optimizations_default = []

@@ -441,6 +442,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
        "make_sloppy",
        "latency_all_edges",
        "slack",
+       "disable_prefetch_legacy_autotune",
    ]
    expected_optimizations_default = []
    graph_rewrites = options._graph_rewrites()

@@ -470,6 +472,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    options.experimental_optimization.autotune_ram_budget = 999999999
    options.experimental_optimization.autotune_buffers = True
    self.assertIn("autotune_buffer_sizes", options._graph_rewrites().enabled)
+   self.assertIn("disable_prefetch_legacy_autotune",
+                 options._graph_rewrites().enabled)

    autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings()
    self.assertTrue(autotune)
    self.assertEqual(algorithm,
tensorflow/python/data/experimental (OptimizationOptions)

@@ -300,8 +300,11 @@ class OptimizationOptions(options.OptionsBase):
      # equivalent to tuning the buffer sizes of the other asynchronous
      # transformations.
      result.enabled.append("autotune_buffer_sizes")
+     result.enabled.append("disable_prefetch_legacy_autotune")

    if self.autotune is False:  # pylint: disable=g-bool-id-comparison
      result.disabled.append("autotune_buffer_sizes")
+     result.disabled.append("disable_prefetch_legacy_autotune")

    return result

@@ -312,6 +315,7 @@ class OptimizationOptions(options.OptionsBase):
    graph_rewrite_configs = []
    autotune_only_optimizations = [
        "autotune_buffer_sizes",
+       "disable_prefetch_legacy_autotune",
        "enable_gradient_descent",
        "map_parallelization"
    ]