[tf.data] Refactoring of tf.data static optimizations.
After this CL, Grappler optimizations are applied to tf.data-based tf.functions during function instantiation as opposed to during input pipeline graph construction. This makes the application of Grappler optimizations consistent between tf.data and non-tf.data tf.functions. PiperOrigin-RevId: 313466120 Change-Id: I43a6f9bdedc12baad4aca344b462735b066f58e0
This commit is contained in:
parent
0ef1057c2d
commit
f82acdd576
@ -603,16 +603,8 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:common_subgraph_elimination",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:loop_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
||||
"//tensorflow/core/grappler/optimizers:remapper",
|
||||
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -21,15 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/optimizers/remapper.h"
|
||||
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
@ -60,14 +52,6 @@ constexpr std::array<const char*, 15> kTFDataOptimizations = {
|
||||
"slack",
|
||||
"inject_prefetch"};
|
||||
|
||||
// Standard grappler optimizations, in the order we want to perform them.
|
||||
// The order matches the order in the generic meta optimizer.
|
||||
constexpr std::array<const char*, 9> kGrapplerOptimizations = {
|
||||
"pruning", "function", "common_subgraph_elimination",
|
||||
"shape", "arithmetic", "layout_optimizer",
|
||||
"remapper", "loop", "dependency",
|
||||
};
|
||||
|
||||
// Parses a list of string optimizer configurations into a map from
|
||||
// optimizer name -> rewriter config for that optimizer.
|
||||
Status ToConfigMap(
|
||||
@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||
}
|
||||
|
||||
for (const auto& optimization : kGrapplerOptimizations) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||
}
|
||||
|
||||
// Store the final result of all the optimizations in `output`.
|
||||
output->Swap(&optimized_item.graph);
|
||||
|
||||
@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
.ReachableDefinitions(*output);
|
||||
const auto producer = output->versions().producer();
|
||||
bool optimized_functions = false;
|
||||
for (const FunctionDef& func : output->library().function()) {
|
||||
for (const auto& name : flib.ListFunctionNames()) {
|
||||
auto* func = flib.Find(name);
|
||||
// Skip non tf.data functions.
|
||||
if (!func.attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func.signature().name();
|
||||
if (!func->attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func->signature().name();
|
||||
optimized_functions = true;
|
||||
|
||||
// Make a GrapplerItem from a FunctionDef.
|
||||
GrapplerFunctionItem func_item;
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeGrapplerFunctionItem(func, flib, producer, &func_item));
|
||||
MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
|
||||
|
||||
GraphDef optimized_func_graph;
|
||||
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
|
||||
@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// Replace optimized function with a new FunctionDef.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flib.ReplaceFunction(func.signature().name(), optimized_func));
|
||||
flib.ReplaceFunction(func->signature().name(), optimized_func));
|
||||
}
|
||||
if (optimized_functions) {
|
||||
*output->mutable_library() = flib.ToProto();
|
||||
@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init(
|
||||
}
|
||||
}
|
||||
|
||||
// Enable a subset of grappler optimization that are enabled by default.
|
||||
//
|
||||
// Layout optimizations are excluded because they assume that ops without
|
||||
// explicit device assignment will be placed on GPU (if available) but that's
|
||||
// not the case for operations within tf.data functions.
|
||||
//
|
||||
// TODO(b/120437209): Re-enable constant folding.
|
||||
//
|
||||
// TODO(jsimsa): Make the set of generic Grappler optimization applied to
|
||||
// tf.data functions configurable.
|
||||
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
||||
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
||||
enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
|
||||
enabled_optimizers_["common_subgraph_elimination"] =
|
||||
MakeUnique<CommonSubgraphElimination>();
|
||||
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
||||
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
||||
enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
|
||||
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
|
||||
RewriterConfig::ON, /*lower_control_flow=*/true);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_not_mobile",
|
||||
"tf_cc_test",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
@ -150,6 +151,7 @@ cc_library(
|
||||
":dataset_utils",
|
||||
":single_threaded_executor",
|
||||
":stats_utils",
|
||||
"@com_google_absl//absl/time",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -158,8 +160,10 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
] + if_not_mobile([
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -35,6 +35,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate(
|
||||
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
||||
inst_opts.output_devices.push_back(inst_opts.target);
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||
optimization_options.allow_pruning_stateful_and_dataset_ops = false;
|
||||
ConfigProto config_proto = inst_opts.config_proto;
|
||||
// Layout optimizations are excluded because they assume that ops without
|
||||
// explicit device assignment will be placed on GPU (if available) but
|
||||
// that's not the case for operations within tf.data functions.
|
||||
config_proto.mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->set_layout_optimizer(RewriterConfig::OFF);
|
||||
// TODO(b/120437209): Re-enable constant folding.
|
||||
config_proto.mutable_graph_options()
|
||||
->mutable_rewrite_options()
|
||||
->set_constant_folding(RewriterConfig::OFF);
|
||||
inst_opts.optimize_graph_fn =
|
||||
std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3,
|
||||
std::placeholders::_4, std::placeholders::_5,
|
||||
std::move(config_proto), fdef->signature().name(),
|
||||
std::move(optimization_options), std::placeholders::_6);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle f_handle;
|
||||
|
Loading…
Reference in New Issue
Block a user