diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index bab28d44686..a927afc5b30 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -603,16 +603,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:arithmetic_optimizer", - "//tensorflow/core/grappler/optimizers:common_subgraph_elimination", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler/optimizers:dependency_optimizer", - "//tensorflow/core/grappler/optimizers:function_optimizer", - "//tensorflow/core/grappler/optimizers:loop_optimizer", - "//tensorflow/core/grappler/optimizers:model_pruner", - "//tensorflow/core/grappler/optimizers:remapper", - "//tensorflow/core/grappler/optimizers:shape_optimizer", "//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core:framework", diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index 3591cd525ac..5804c3ee01a 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -21,15 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" -#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" -#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" -#include "tensorflow/core/grappler/optimizers/function_optimizer.h" -#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" -#include "tensorflow/core/grappler/optimizers/model_pruner.h" -#include "tensorflow/core/grappler/optimizers/remapper.h" -#include "tensorflow/core/grappler/optimizers/shape_optimizer.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" @@ -60,14 +52,6 @@ constexpr std::array kTFDataOptimizations = { "slack", "inject_prefetch"}; -// Standard grappler optimizations, in the order we want to perform them. -// The order matches the order in the generic meta optimizer. -constexpr std::array kGrapplerOptimizations = { - "pruning", "function", "common_subgraph_elimination", - "shape", "arithmetic", "layout_optimizer", - "remapper", "loop", "dependency", -}; - // Parses a list of string optimizer configurations into a map from // optimizer name -> rewriter config for that optimizer. Status ToConfigMap( @@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, ApplyOptimization(optimization, cluster, &optimized_item)); } - for (const auto& optimization : kGrapplerOptimizations) { - TF_RETURN_IF_ERROR( - ApplyOptimization(optimization, cluster, &optimized_item)); - } - // Store the final result of all the optimizations in `output`. output->Swap(&optimized_item.graph); @@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, .ReachableDefinitions(*output); const auto producer = output->versions().producer(); bool optimized_functions = false; - for (const FunctionDef& func : output->library().function()) { + for (const auto& name : flib.ListFunctionNames()) { + auto* func = flib.Find(name); // Skip non tf.data functions. - if (!func.attr().contains(data::kTFDataFunction)) continue; - VLOG(3) << "Optimize function: function=" << func.signature().name(); + if (!func->attr().contains(data::kTFDataFunction)) continue; + VLOG(3) << "Optimize function: function=" << func->signature().name(); optimized_functions = true; // Make a GrapplerItem from a FunctionDef. GrapplerFunctionItem func_item; TF_RETURN_IF_ERROR( - MakeGrapplerFunctionItem(func, flib, producer, &func_item)); + MakeGrapplerFunctionItem(*func, flib, producer, &func_item)); GraphDef optimized_func_graph; TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph)); @@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Replace optimized function with a new FunctionDef. TF_RETURN_IF_ERROR( - flib.ReplaceFunction(func.signature().name(), optimized_func)); + flib.ReplaceFunction(func->signature().name(), optimized_func)); } if (optimized_functions) { *output->mutable_library() = flib.ToProto(); @@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init( } } - // Enable a subset of grappler optimization that are enabled by default. - // - // Layout optimizations are excluded because they assume that ops without - // explicit device assignment will be placed on GPU (if available) but that's - // not the case for operations within tf.data functions. - // - // TODO(b/120437209): Re-enable constant folding. - // - // TODO(jsimsa): Make the set of generic Grappler optimization applied to - // tf.data functions configurable. - enabled_optimizers_["pruning"] = MakeUnique(); - enabled_optimizers_["shape"] = MakeUnique(); - enabled_optimizers_["remapping"] = MakeUnique(RewriterConfig::ON); - enabled_optimizers_["common_subgraph_elimination"] = - MakeUnique(); - enabled_optimizers_["arithmetic"] = MakeUnique(); - enabled_optimizers_["dependency"] = MakeUnique(); - enabled_optimizers_["loop"] = MakeUnique(); - enabled_optimizers_["function"] = MakeUnique( - RewriterConfig::ON, /*lower_control_flow=*/true); - return Status::OK(); } diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index d088abc00e6..6d0351202df 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -3,6 +3,7 @@ load( "//tensorflow:tensorflow.bzl", + "if_not_mobile", "tf_cc_test", "tf_kernel_library", ) @@ -150,6 +151,7 @@ cc_library( ":dataset_utils", ":single_threaded_executor", ":stats_utils", + "@com_google_absl//absl/time", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -158,8 +160,10 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/time", - ], + ] + if_not_mobile([ + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ]), ) cc_library( diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index dd64475d7d6..d79cb25ec8b 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -35,6 +35,11 @@ limitations under the License. #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/profiler/lib/traceme.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#endif // !IS_MOBILE_PLATFORM + namespace tensorflow { namespace data { namespace { @@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate( for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) { inst_opts.output_devices.push_back(inst_opts.target); } + +#if !defined(IS_MOBILE_PLATFORM) + grappler::GrapplerItem::OptimizationOptions optimization_options; + optimization_options.allow_pruning_stateful_and_dataset_ops = false; + ConfigProto config_proto = inst_opts.config_proto; + // Layout optimizations are excluded because they assume that ops without + // explicit device assignment will be placed on GPU (if available) but + // that's not the case for operations within tf.data functions. + config_proto.mutable_graph_options() + ->mutable_rewrite_options() + ->set_layout_optimizer(RewriterConfig::OFF); + // TODO(b/120437209): Re-enable constant folding. + config_proto.mutable_graph_options() + ->mutable_rewrite_options() + ->set_constant_folding(RewriterConfig::OFF); + inst_opts.optimize_graph_fn = + std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, + std::placeholders::_4, std::placeholders::_5, + std::move(config_proto), fdef->signature().name(), + std::move(optimization_options), std::placeholders::_6); +#endif // !IS_MOBILE_PLATFORM } FunctionLibraryRuntime::Handle f_handle;