From 2ae0d560ee17654f17b488b93afe2563b2dc84ca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 14 Oct 2020 21:23:33 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 337234276 Change-Id: Idd6c4db5c9658a468e1d2758b2e84bbcf401e1a6 --- tensorflow/core/framework/dataset.h | 5 - tensorflow/core/grappler/optimizers/BUILD | 1 - .../grappler/optimizers/data/fusion_utils.cc | 5 +- .../optimizers/data/meta_optimizer.cc | 2 +- .../grappler/optimizers/meta_optimizer.cc | 65 +------- .../optimizers/meta_optimizer_test.cc | 139 ------------------ 6 files changed, 10 insertions(+), 207 deletions(-) diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 6aa31909197..ba0c2b84a1a 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -75,11 +75,6 @@ constexpr char kTFDataResourceTag[] = "tfdata"; class DatasetBase; class SerializationContext; -inline bool IsTFDataFunction(const FunctionDef& func) { - return (func.attr().contains(data::kTFDataFunction) && - func.attr().at(data::kTFDataFunction).b()); -} - // Interface for reading values from a key-value store. // Used for restoring iterator state. This class is thread safe. // Please see comment on IteratorStateWriter for guidance around using the diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 70a5d59e2b3..7fc74b0aca5 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -717,7 +717,6 @@ tf_cuda_cc_test( ":custom_graph_optimizer_registry", ":meta_optimizer", "//tensorflow/cc:cc_ops", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index e09ea575ce4..d70a1ca486e 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -430,8 +430,9 @@ FunctionDef* FuseFunctions( const SetInputFn& set_input, const SetOutputFn& set_output, const SetNodesFn& set_nodes, FunctionDefLibrary* library) { auto has_attrs = [](const FunctionDef& func) { - return !(func.attr_size() == 0 || - (func.attr_size() == 1 && data::IsTFDataFunction(func))); + return !( + func.attr_size() == 0 || + (func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction))); }; if (has_attrs(first_function) || has_attrs(second_function)) { return nullptr; // Functions with attributes are currently not supported. diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index b4317577cb8..cd46a7356ac 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -118,7 +118,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, for (const auto& name : flib.ListFunctionNames()) { auto* func = flib.Find(name); // Skip non tf.data functions. - if (!data::IsTFDataFunction(*func)) continue; + if (!func->attr().contains(data::kTFDataFunction)) continue; VLOG(3) << "Optimize function: function=" << func->signature().name(); optimized_functions = true; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index a36aee8ffbf..8f18dfdeef4 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -93,6 +93,10 @@ bool IsRunOnceOptimizer(const string& name) { name == "auto_mixed_precision_mkl"; } +bool IsTFDataFunction(const FunctionDef& func) { + return func.attr().contains(data::kTFDataFunction); +} + // Creates a function library stub from a real function library: copy only // signatures and attributes of all the function defined in fdef_lib. This stub // can be swapped with real function library in a graph, before passing it to @@ -611,61 +615,6 @@ Status MetaOptimizer::RunOptimizer( return Status::OK(); } -// Propagates _tf_data_function` attributes from functions to their callees. -void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib, - FunctionDefLibrary& fdef_lib) { - // Collect functions that need the attribute in this set. - absl::flat_hash_set tf_data_functions; - std::function collect_tf_data_functions_dfs = - [&](const string& func_name) -> void { - // Return if we already found and added this function. - if (tf_data_functions.contains(func_name)) return; - - // We only get here if the function is (directly or indirectly) called from - // a tf.data function, so add it to the set. - tf_data_functions.insert(func_name); - - const FunctionDef* func_def = flib.Find(func_name); - // Skip functions that are not reachable from the optimized graph. - if (func_def == nullptr) return; - - // Proceed with DFS for functions called from current function. - for (const NodeDef& node : func_def->node_def()) { - if (flib.Contains(node.op())) { - // This is a function call node. - collect_tf_data_functions_dfs(node.op()); - } - // Check if there are functions in attributes. - for (const auto& attr : node.attr()) { - const AttrValue& attr_value = attr.second; - if (attr_value.has_func()) { - collect_tf_data_functions_dfs(attr_value.func().name()); - } - if (attr_value.has_list()) { - for (const auto& func : attr_value.list().func()) { - collect_tf_data_functions_dfs(func.name()); - } - } - } - } - }; - // Perform DFS for all tf.data functions in `fdef_lib`. - for (const auto& func_def : fdef_lib.function()) { - const string& func_name = func_def.signature().name(); - if (data::IsTFDataFunction(func_def)) - collect_tf_data_functions_dfs(func_name); - } - // Set attribute for tf.data functions. We cannot do this in the DFS directly - // because `FunctionLibraryDefinition` does not seem to provide mutable access - // to a `FunctionDef`. - for (FunctionDef& func_def : *fdef_lib.mutable_function()) { - const string& func_name = func_def.signature().name(); - if (tf_data_functions.contains(func_name)) { - (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true); - } - } -} - Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) { const uint64 start_us = Env::Default()->NowMicros(); @@ -773,7 +722,6 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, for (const FunctionDef& function : optimized_graph->library().function()) { find_xla_compiled_functions(function.node_def()); } - PropagateTFDataAttrs(flib, *optimized_graph->mutable_library()); // Optimize each function only once. absl::flat_hash_set optimized_funcs; @@ -799,9 +747,8 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, // the function optimizer, before we can optimize function body. if (IsParametrized(func)) continue; - // Skip tf.data functions as they are optimized by tf.data meta optimizer - // and in function instantiation. - if (data::IsTFDataFunction(func)) continue; + // Skip tf.data functions as they are optimized by tf.data meta optimizer. + if (IsTFDataFunction(func)) continue; VLOG(3) << "Optimize function: function=" << func_name << " [" << function_idx++ << " of " diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 85f7f911635..595b636c7a9 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/substitute.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -1017,144 +1016,6 @@ TEST_F(MetaOptimizerTest, CompressConstants) { } } -// Tests for checking expected behavior when skipping tf.data functions in -// meta optimizer. - -// Custom optimizer which counts its calls. -class TfDataTestOptimizer : public CustomGraphOptimizer { - public: - static void InitCount() { cnt_ = 0; } - static int GetCount() { return cnt_; } - - TfDataTestOptimizer() {} - string name() const override { return "tf_data_test_optimizer"; } - bool UsesFunctionLibrary() const override { return false; } - - Status Init( - const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { - return Status::OK(); - } - - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) override { - ++cnt_; - *optimized_graph = item.graph; - return Status::OK(); - } - - void Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& optimized_graph, double result) override {} - - private: - static int cnt_; -}; - -int TfDataTestOptimizer::cnt_; - -REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer); - -// Test fixture for parametrized testing. -class TfDataTestFixture - : public ::testing::TestWithParam> { - protected: - void SetUp() override { - is_my_mul_tf_data_ = std::get<0>(GetParam()); - is_my_square_tf_data_ = std::get<1>(GetParam()); - } - void RunTest(); - - private: - // controls which of the functions is flagged as tf.data function - bool is_my_mul_tf_data_ = false; - bool is_my_square_tf_data_ = false; -}; - -TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); } - -// Core test function. -void TfDataTestFixture::RunTest() { - using test::function::NDef; - - // Define function library: - // - // MyMul(x, y) = x * y - // MySquare(x) = MyMul(x, x) - - FunctionDef mul_func = FunctionDefHelper::Create( - "MyMul", {"x:float", "y:float"}, {"z:float"}, {}, - {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}}, - /*ret_def=*/ - {{"z", "mul:z:0"}}); - (*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_); - - FunctionDef square_func = FunctionDefHelper::Create( - "MySquare", {"x:float"}, {"z:float"}, {}, - {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}}, - /*ret_def=*/ - {{"z", "my_mul:z:0"}}); - (*square_func.mutable_attr())[data::kTFDataFunction].set_b( - is_my_square_tf_data_); - - // Tensorflow graph: - // - // a = tf.Placeholder(tf.float); - // square = MySquare(a); // a^2 - GrapplerItem item; - item.id = "tf_graph"; - item.graph = test::function::GDef( - {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), - // Calls into function library - NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice), - // Forward outputs - NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)}, - /*funcs=*/ - {mul_func, square_func}); - - // Use only custom optimizer which counts its calls. - TfDataTestOptimizer::InitCount(); - ConfigProto config_proto; - auto& rewriter_config = - *(config_proto.mutable_graph_options()->mutable_rewrite_options()); - rewriter_config.add_optimizers("TfDataTestOptimizer"); - rewriter_config.set_min_graph_nodes(-1); - rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE); - - MetaOptimizer optimizer(nullptr, config_proto); - GraphDef output; - const Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - - // We expect one graph optimization + one optimization for each non-tf.data - // function. Note that if `MySquare` is flagged as a tf.data function, then - // `MyMul` is implicitly also considered a tf.data function because it is - // called from `MySquare`. - int expected_count = 3; - if (is_my_square_tf_data_) - expected_count -= 2; - else if (is_my_mul_tf_data_) - expected_count -= 1; - EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count); - - // We expect that the tf.data-attribute has been propagated from `MySquare` - // to its callee `MyMul` if the value is `true`. Otherwise, the attribute - // values should be unchanged. - FunctionLibraryDefinition flib(OpRegistry::Global(), output.library()); - const FunctionDef* square_func_after_opt = flib.Find("MySquare"); - const FunctionDef* mul_func_after_opt = flib.Find("MyMul"); - - EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt), - is_my_square_tf_data_); - if (is_my_square_tf_data_ || is_my_mul_tf_data_) { - EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true); - } else { - EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false); - } -} - -INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture, - ::testing::Combine(::testing::Bool(), - ::testing::Bool())); - } // namespace } // namespace grappler } // namespace tensorflow