[tf.data] Refactoring of tf.data static optimizations.
After this CL, Grappler optimizations are applied to tf.data-based tf.functions during function instantiation as opposed to during input pipeline graph construction. This makes the application of Grappler optimizations consistent between tf.data and non-tf.data tf.functions. PiperOrigin-RevId: 313466120 Change-Id: I43a6f9bdedc12baad4aca344b462735b066f58e0
This commit is contained in:
parent
0ef1057c2d
commit
f82acdd576
@ -603,16 +603,8 @@ cc_library(
|
|||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"//tensorflow/core/grappler/clusters:cluster",
|
"//tensorflow/core/grappler/clusters:cluster",
|
||||||
"//tensorflow/core/grappler/optimizers:arithmetic_optimizer",
|
|
||||||
"//tensorflow/core/grappler/optimizers:common_subgraph_elimination",
|
|
||||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||||
"//tensorflow/core/grappler/optimizers:dependency_optimizer",
|
|
||||||
"//tensorflow/core/grappler/optimizers:function_optimizer",
|
|
||||||
"//tensorflow/core/grappler/optimizers:loop_optimizer",
|
|
||||||
"//tensorflow/core/grappler/optimizers:model_pruner",
|
|
||||||
"//tensorflow/core/grappler/optimizers:remapper",
|
|
||||||
"//tensorflow/core/grappler/optimizers:shape_optimizer",
|
|
||||||
"//tensorflow/core/grappler/utils:functions",
|
"//tensorflow/core/grappler/utils:functions",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -21,15 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/remapper.h"
|
|
||||||
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
|
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
@ -60,14 +52,6 @@ constexpr std::array<const char*, 15> kTFDataOptimizations = {
|
|||||||
"slack",
|
"slack",
|
||||||
"inject_prefetch"};
|
"inject_prefetch"};
|
||||||
|
|
||||||
// Standard grappler optimizations, in the order we want to perform them.
|
|
||||||
// The order matches the order in the generic meta optimizer.
|
|
||||||
constexpr std::array<const char*, 9> kGrapplerOptimizations = {
|
|
||||||
"pruning", "function", "common_subgraph_elimination",
|
|
||||||
"shape", "arithmetic", "layout_optimizer",
|
|
||||||
"remapper", "loop", "dependency",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Parses a list of string optimizer configurations into a map from
|
// Parses a list of string optimizer configurations into a map from
|
||||||
// optimizer name -> rewriter config for that optimizer.
|
// optimizer name -> rewriter config for that optimizer.
|
||||||
Status ToConfigMap(
|
Status ToConfigMap(
|
||||||
@ -118,11 +102,6 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
ApplyOptimization(optimization, cluster, &optimized_item));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& optimization : kGrapplerOptimizations) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
ApplyOptimization(optimization, cluster, &optimized_item));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the final result of all the optimizations in `output`.
|
// Store the final result of all the optimizations in `output`.
|
||||||
output->Swap(&optimized_item.graph);
|
output->Swap(&optimized_item.graph);
|
||||||
|
|
||||||
@ -132,16 +111,17 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
.ReachableDefinitions(*output);
|
.ReachableDefinitions(*output);
|
||||||
const auto producer = output->versions().producer();
|
const auto producer = output->versions().producer();
|
||||||
bool optimized_functions = false;
|
bool optimized_functions = false;
|
||||||
for (const FunctionDef& func : output->library().function()) {
|
for (const auto& name : flib.ListFunctionNames()) {
|
||||||
|
auto* func = flib.Find(name);
|
||||||
// Skip non tf.data functions.
|
// Skip non tf.data functions.
|
||||||
if (!func.attr().contains(data::kTFDataFunction)) continue;
|
if (!func->attr().contains(data::kTFDataFunction)) continue;
|
||||||
VLOG(3) << "Optimize function: function=" << func.signature().name();
|
VLOG(3) << "Optimize function: function=" << func->signature().name();
|
||||||
optimized_functions = true;
|
optimized_functions = true;
|
||||||
|
|
||||||
// Make a GrapplerItem from a FunctionDef.
|
// Make a GrapplerItem from a FunctionDef.
|
||||||
GrapplerFunctionItem func_item;
|
GrapplerFunctionItem func_item;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
MakeGrapplerFunctionItem(func, flib, producer, &func_item));
|
MakeGrapplerFunctionItem(*func, flib, producer, &func_item));
|
||||||
|
|
||||||
GraphDef optimized_func_graph;
|
GraphDef optimized_func_graph;
|
||||||
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
|
TF_RETURN_IF_ERROR(Optimize(cluster, func_item, &optimized_func_graph));
|
||||||
@ -162,7 +142,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
|
|
||||||
// Replace optimized function with a new FunctionDef.
|
// Replace optimized function with a new FunctionDef.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
flib.ReplaceFunction(func.signature().name(), optimized_func));
|
flib.ReplaceFunction(func->signature().name(), optimized_func));
|
||||||
}
|
}
|
||||||
if (optimized_functions) {
|
if (optimized_functions) {
|
||||||
*output->mutable_library() = flib.ToProto();
|
*output->mutable_library() = flib.ToProto();
|
||||||
@ -221,27 +201,6 @@ Status TFDataMetaOptimizer::Init(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enable a subset of grappler optimization that are enabled by default.
|
|
||||||
//
|
|
||||||
// Layout optimizations are excluded because they assume that ops without
|
|
||||||
// explicit device assignment will be placed on GPU (if available) but that's
|
|
||||||
// not the case for operations within tf.data functions.
|
|
||||||
//
|
|
||||||
// TODO(b/120437209): Re-enable constant folding.
|
|
||||||
//
|
|
||||||
// TODO(jsimsa): Make the set of generic Grappler optimization applied to
|
|
||||||
// tf.data functions configurable.
|
|
||||||
enabled_optimizers_["pruning"] = MakeUnique<ModelPruner>();
|
|
||||||
enabled_optimizers_["shape"] = MakeUnique<ShapeOptimizer>();
|
|
||||||
enabled_optimizers_["remapping"] = MakeUnique<Remapper>(RewriterConfig::ON);
|
|
||||||
enabled_optimizers_["common_subgraph_elimination"] =
|
|
||||||
MakeUnique<CommonSubgraphElimination>();
|
|
||||||
enabled_optimizers_["arithmetic"] = MakeUnique<ArithmeticOptimizer>();
|
|
||||||
enabled_optimizers_["dependency"] = MakeUnique<DependencyOptimizer>();
|
|
||||||
enabled_optimizers_["loop"] = MakeUnique<LoopOptimizer>();
|
|
||||||
enabled_optimizers_["function"] = MakeUnique<FunctionOptimizer>(
|
|
||||||
RewriterConfig::ON, /*lower_control_flow=*/true);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"if_not_mobile",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
"tf_kernel_library",
|
"tf_kernel_library",
|
||||||
)
|
)
|
||||||
@ -150,6 +151,7 @@ cc_library(
|
|||||||
":dataset_utils",
|
":dataset_utils",
|
||||||
":single_threaded_executor",
|
":single_threaded_executor",
|
||||||
":stats_utils",
|
":stats_utils",
|
||||||
|
"@com_google_absl//absl/time",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
@ -158,8 +160,10 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/kernels:variable_ops",
|
"//tensorflow/core/kernels:variable_ops",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"@com_google_absl//absl/time",
|
] + if_not_mobile([
|
||||||
],
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler/optimizers:meta_optimizer",
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -35,6 +35,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/notification.h"
|
#include "tensorflow/core/platform/notification.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||||
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
namespace {
|
namespace {
|
||||||
@ -612,6 +617,28 @@ Status CapturedFunction::Instantiate(
|
|||||||
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
for (size_t i = 0; i < fdef->signature().output_arg_size(); ++i) {
|
||||||
inst_opts.output_devices.push_back(inst_opts.target);
|
inst_opts.output_devices.push_back(inst_opts.target);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
|
grappler::GrapplerItem::OptimizationOptions optimization_options;
|
||||||
|
optimization_options.allow_pruning_stateful_and_dataset_ops = false;
|
||||||
|
ConfigProto config_proto = inst_opts.config_proto;
|
||||||
|
// Layout optimizations are excluded because they assume that ops without
|
||||||
|
// explicit device assignment will be placed on GPU (if available) but
|
||||||
|
// that's not the case for operations within tf.data functions.
|
||||||
|
config_proto.mutable_graph_options()
|
||||||
|
->mutable_rewrite_options()
|
||||||
|
->set_layout_optimizer(RewriterConfig::OFF);
|
||||||
|
// TODO(b/120437209): Re-enable constant folding.
|
||||||
|
config_proto.mutable_graph_options()
|
||||||
|
->mutable_rewrite_options()
|
||||||
|
->set_constant_folding(RewriterConfig::OFF);
|
||||||
|
inst_opts.optimize_graph_fn =
|
||||||
|
std::bind(tensorflow::grappler::OptimizeGraph, std::placeholders::_1,
|
||||||
|
std::placeholders::_2, std::placeholders::_3,
|
||||||
|
std::placeholders::_4, std::placeholders::_5,
|
||||||
|
std::move(config_proto), fdef->signature().name(),
|
||||||
|
std::move(optimization_options), std::placeholders::_6);
|
||||||
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime::Handle f_handle;
|
FunctionLibraryRuntime::Handle f_handle;
|
||||||
|
Loading…
Reference in New Issue
Block a user