[tf.data] Refactoring of tf.data static optimizations.

After this CL, Grappler optimizations are applied to tf.data-based tf.functions during function instantiation as opposed to during input pipeline graph construction. This makes the application of Grappler optimizations consistent between tf.data and non-tf.data tf.functions.

PiperOrigin-RevId: 313466120
Change-Id: I43a6f9bdedc12baad4aca344b462735b066f58e0
Author: Jiri Simsa, 2020-05-27 14:37:12 -07:00 (committed by TensorFlower Gardener)
Parent: 0ef1057c2d
Commit: f82acdd576
4 changed files with 39 additions and 57 deletions
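
Concretely, the generic Grappler passes that the tf.data meta optimizer used to run itself are now requested through the function library runtime: CapturedFunction::Instantiate builds a ConfigProto and hands the runtime an optimize_graph_fn callback. A condensed sketch of that flow, distilled from the captured_function.cc hunk at the end of this diff:

    // Sketch: how a tf.data function now opts into Grappler at instantiation.
    grappler::GrapplerItem::OptimizationOptions optimization_options;
    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
    ConfigProto config_proto = inst_opts.config_proto;
    // Layout optimization and constant folding are explicitly disabled for
    // tf.data functions (see the comments in the hunk below for the rationale).
    config_proto.mutable_graph_options()->mutable_rewrite_options()
        ->set_layout_optimizer(RewriterConfig::OFF);
    config_proto.mutable_graph_options()->mutable_rewrite_options()
        ->set_constant_folding(RewriterConfig::OFF);
    // The generic meta optimizer then runs when the function is instantiated,
    // just as it does for non-tf.data functions.
    inst_opts.optimize_graph_fn = std::bind(
        tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
        std::placeholders::_2, std::placeholders::_3, std::placeholders::_4,
        std::placeholders::_5, std::move(config_proto),
        fdef->signature().name(), std::move(optimization_options),
        std::placeholders::_6);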

tensorflow/core/grappler/optimizers/data/BUILD

@@ -603,16 +603,8 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
         "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
-        "//tensorflow/core/grappler/optimizers:common_subgraph_elimination",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-        "//tensorflow/core/grappler/optimizers:dependency_optimizer",
-        "//tensorflow/core/grappler/optimizers:function_optimizer",
-        "//tensorflow/core/grappler/optimizers:loop_optimizer",
-        "//tensorflow/core/grappler/optimizers:model_pruner",
-        "//tensorflow/core/grappler/optimizers:remapper",
-        "//tensorflow/core/grappler/optimizers:shape_optimizer",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core:framework",

tensorflow/core/grappler/optimizers/data/meta_optimizer.cc

@@ -21,15 +21,7 @@ limitations under the License.
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
-#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
-#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
-#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
-#include "tensorflow/core/grappler/optimizers/model_pruner.h"
-#include "tensorflow/core/grappler/optimizers/remapper.h"
-#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
 #include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/util/ptr_util.h"
@@ -60,14 +52,6 @@ constexpr std::array<const char*, 15> kTFDataOptimizations = {
     "slack",
     "inject_prefetch"};

-// Standard grappler optimizations, in the order we want to perform them.
-// The order matches the order in the generic meta optimizer.
-constexpr std::array<const char*, 9> kGrapplerOptimizations = {
-    "pruning",  "function", "common_subgraph_elimination",
-    "shape",    "arithmetic", "layout_optimizer",
-    "remapper", "loop",       "dependency",
-};
-
 // Parses a list of string optimizer configurations into a map from
 // optimizer name -> rewriter config for that optimizer.
 Status ToConfigMap(
@@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         ApplyOptimization(optimization, cluster, &optimized_item));
   }

-  for (const auto& optimization : kGrapplerOptimizations) {
-    TF_RETURN_IF_ERROR(
-        ApplyOptimization(optimization, cluster, &optimized_item));
-  }
-
   // Store the final result of all the optimizations in `output`.
   output->Swap(&optimized_item.graph);
@@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           .ReachableDefinitions(*output);
   const auto producer = output->versions().producer();
   bool optimized_functions = false;
-  for (const FunctionDef& func : output->library().function()) {
+  for (const auto& name : flib.ListFunctionNames()) {
+    auto* func = flib.Find(name);
     // Skip non tf.data functions.
-    if (!func.attr().contains(data::kTFDataFunction)) continue;
-    VLOG(3) << "Optimize function: function=" << func.signature().name();
+    if (!func->attr().contains(data::kTFDataFunction)) continue;
+    VLOG(3) << "Optimize function: function=" << func->signature().name();
     optimized_functions = true;

     // Make a GrapplerItem from a FunctionDef.
     GrapplerFunctionItem func_item;
     TF_RETURN_IF_ERROR(
-        MakeGrapplerFunctionItem(func, flib, producer, &func_item));
+        MakeGrapplerFunctionItem(*func, flib, producer, &func_item));

     GraphDef optimized_func_graph;
     TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
@@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     // Replace optimized function with a new FunctionDef.
     TF_RETURN_IF_ERROR(
-        flib.ReplaceFunction(func.signature().name(), optimized_func));
+        flib.ReplaceFunction(func->signature().name(), optimized_func));
   }

   if (optimized_functions) {
     *output->mutable_library() = flib.ToProto();
@@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init(
     }
   }

-  // Enable a subset of grappler optimization that are enabled by default.
-  //
-  // Layout optimizations are excluded because they assume that ops without
-  // explicit device assignment will be placed on GPU (if available) but that's
-  // not the case for operations within tf.data functions.
-  //
-  // TODO(b/120437209): Re-enable constant folding.
-  //
-  // TODO(jsimsa): Make the set of generic Grappler optimization applied to
-  // tf.data functions configurable.
-  enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
-  enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
-  enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
-  enabled_optimizers_["common_subgraph_elimination"] =
-      MakeUnique<CommonSubgraphElimination>();
-  enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
-  enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
-  enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
-  enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
-      RewriterConfig::ON, /*lower_control_flow=*/true);
-
   return Status::OK();
 }
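
What remains of TFDataMetaOptimizer::Optimize after this change, paraphrased from the context lines above (the exact code is in the hunks; this is only a summary sketch):

    // 1. Apply only the tf.data-specific passes to the dataset graph.
    for (const auto& optimization : kTFDataOptimizations) {
      TF_RETURN_IF_ERROR(
          ApplyOptimization(optimization, cluster, &optimized_item));
    }
    output->Swap(&optimized_item.graph);

    // 2. Recurse into functions tagged with data::kTFDataFunction so nested
    //    tf.data functions still receive the tf.data passes.
    for (const auto& name : flib.ListFunctionNames()) {
      auto* func = flib.Find(name);
      if (!func->attr().contains(data::kTFDataFunction)) continue;
      GrapplerFunctionItem func_item;
      TF_RETURN_IF_ERROR(
          MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
      GraphDef optimized_func_graph;
      TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
      // ... rebuild a FunctionDef from optimized_func_graph and
      // flib.ReplaceFunction(func->signature().name(), optimized_func);
    }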

tensorflow/core/kernels/data/BUILD

@@ -3,6 +3,7 @@
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_not_mobile",
     "tf_cc_test",
     "tf_kernel_library",
 )
@@ -150,6 +151,7 @@ cc_library(
         ":dataset_utils",
         ":single_threaded_executor",
         ":stats_utils",
+        "@com_google_absl//absl/time",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -158,8 +160,10 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:variable_ops",
         "//tensorflow/core/profiler/lib:traceme",
-        "@com_google_absl//absl/time",
-    ],
+    ] + if_not_mobile([
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/optimizers:meta_optimizer",
+    ]),
 )

 cc_library(

tensorflow/core/kernels/data/captured_function.cc

@@ -35,6 +35,11 @@ limitations under the License.
 #include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#endif  // !IS_MOBILE_PLATFORM
+
 namespace tensorflow {
 namespace data {
 namespace {
@@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate(
     for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
       inst_opts.output_devices.push_back(inst_opts.target);
     }

+#if !defined(IS_MOBILE_PLATFORM)
+    grappler::GrapplerItem::OptimizationOptions optimization_options;
+    optimization_options.allow_pruning_stateful_and_dataset_ops = false;
+    ConfigProto config_proto = inst_opts.config_proto;
+    // Layout optimizations are excluded because they assume that ops without
+    // explicit device assignment will be placed on GPU (if available) but
+    // that's not the case for operations within tf.data functions.
+    config_proto.mutable_graph_options()
+        ->mutable_rewrite_options()
+        ->set_layout_optimizer(RewriterConfig::OFF);
+    // TODO(b/120437209): Re-enable constant folding.
+    config_proto.mutable_graph_options()
+        ->mutable_rewrite_options()
+        ->set_constant_folding(RewriterConfig::OFF);
+    inst_opts.optimize_graph_fn =
+        std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
+                  std::placeholders::_2, std::placeholders::_3,
+                  std::placeholders::_4, std::placeholders::_5,
+                  std::move(config_proto), fdef->signature().name(),
+                  std::move(optimization_options), std::placeholders::_6);
+#endif  // !IS_MOBILE_PLATFORM
   }

   FunctionLibraryRuntime::Handle f_handle;
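
A note on the std::bind pattern used for optimize_graph_fn above: the ConfigProto, the item id, and the OptimizationOptions are bound up front, while placeholders _1 through _6 are left for the arguments the runtime supplies each time it instantiates the function. A minimal, self-contained illustration of that pattern in plain C++ (a toy function, not the TensorFlow API):

    #include <functional>
    #include <iostream>
    #include <string>

    // "Runtime" arguments (a, b) mixed with a pre-bound "configuration"
    // argument (prefix) -- the same shape as binding config_proto and
    // optimization_options into OptimizeGraph above.
    int Annotate(int a, int b, const std::string& prefix) {
      std::cout << prefix << a + b << "\n";
      return a + b;
    }

    int main() {
      using std::placeholders::_1;
      using std::placeholders::_2;
      // Bind the configuration now; leave the runtime arguments as placeholders.
      auto bound = std::bind(Annotate, _1, _2, std::string("sum = "));
      bound(2, 3);  // prints "sum = 5"
      return 0;
    }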