[Grappler] Remove several gratuitous graph and function library copies from the Grappler-related callstacks in TensorFlow.
This reduces the time spent optimizing a particular model I am benchmarking by about 8%.

PiperOrigin-RevId: 299457198
Change-Id: I46688f3e215f5ab9ec55520d0ef324e04aa49e31
commit 8054b80990
parent 33f9e8d283

Changed paths: tensorflow/core/common_runtime, tensorflow/core/grappler/optimizers, tensorflow/core/kernels/data
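The core of the change is mechanical: entry points that previously took a `const GrapplerItem&` (and then copied the item, including its potentially multi-megabyte `GraphDef` and function library) now also accept `GrapplerItem&&`, so call sites that are finished with the item can hand it over instead of paying for a copy. The sketch below is a minimal, self-contained illustration of that pattern; `Item`, `OptimizeCopying`, and `OptimizeConsuming` are hypothetical stand-ins, not code from this commit.

#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for GrapplerItem: small metadata plus a large payload.
struct Item {
  std::string id;
  std::vector<std::string> graph;  // imagine a multi-megabyte GraphDef here
};

// Before: a const-reference parameter forces a deep copy of the whole item.
void OptimizeCopying(const Item& item, std::vector<std::string>* out) {
  Item working_copy = item;  // copies every node and the function library
  *out = std::move(working_copy.graph);
}

// After: an rvalue-reference overload consumes the item the caller no longer needs.
void OptimizeConsuming(Item&& item, std::vector<std::string>* out) {
  *out = std::move(item.graph);  // steals the payload, no copy
}

int main() {
  Item item{"demo", std::vector<std::string>(100000, "node")};
  std::vector<std::string> optimized;
  OptimizeConsuming(std::move(item), &optimized);  // call site opts in with std::move
  return optimized.empty() ? 1 : 0;
}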
@@ -732,8 +732,9 @@ Status GraphExecutionState::OptimizeGraph(
   }
   grappler::VirtualCluster cluster(device_set_);
   GraphDef new_graph;
-  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
-      item, session_options_->config, cpu_device, &cluster, &new_graph));
+  TF_RETURN_IF_ERROR(
+      grappler::RunMetaOptimizer(std::move(item), session_options_->config,
+                                 cpu_device, &cluster, &new_graph));
 
   // Merge optimized graph function library with an original library.
   // Optimized graph might have new functions specialized for it's
@@ -58,6 +58,12 @@ class GraphOptimizer {
   virtual Status Optimize(Cluster* cluster, const GrapplerItem& item,
                           GraphDef* optimized_graph) = 0;
 
+  // Subclasses may define a version of Optimize that consumes item.
+  virtual Status Optimize(Cluster* cluster, GrapplerItem&& item,
+                          GraphDef* optimized_graph) {
+    return Optimize(cluster, item, optimized_graph);
+  }
+
   // Method invoked by the framework so that it can provide feedback
   // on how well the "optimized_graph" (produced as *optimized_graph from a
   // call to Optimize) performed. Lower "result" scores are better.
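The added default overload keeps every existing `GraphOptimizer` subclass working: unless a subclass overrides the rvalue version, moving an item into `Optimize` simply falls back to the const-reference virtual. A rough, hypothetical sketch of the same shape follows; the names are simplified stand-ins, not TensorFlow's classes.

#include <iostream>
#include <string>
#include <utility>

struct Graph { std::string nodes; };  // hypothetical large payload

class Optimizer {
 public:
  virtual ~Optimizer() = default;
  // Existing API every subclass already implements.
  virtual bool Optimize(const Graph& in, Graph* out) = 0;
  // New overload: by default it forwards to the const-ref version, so only
  // subclasses that can profit from consuming `in` need to override it.
  virtual bool Optimize(Graph&& in, Graph* out) { return Optimize(in, out); }
};

class ConsumingOptimizer : public Optimizer {
 public:
  bool Optimize(const Graph& in, Graph* out) override {
    *out = in;  // must copy: the caller still owns `in`
    return true;
  }
  bool Optimize(Graph&& in, Graph* out) override {
    *out = std::move(in);  // caller gave `in` up, so steal its storage
    return true;
  }
};

int main() {
  ConsumingOptimizer opt;
  Graph g{"a very large graph"}, result;
  opt.Optimize(std::move(g), &result);  // resolves to the consuming overload
  std::cout << result.nodes << "\n";
}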
@@ -379,7 +379,7 @@ void MetaOptimizer::InitializeVerifiers(
   }
 }
 
-Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                     GraphDef* optimized_graph) {
   int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                     : cfg_.min_graph_nodes();
@@ -426,8 +426,8 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
 
   // Invariant: optimized_graph contains the most recently optimized version of
   // the graph.
-  GrapplerItem optimized_item = item;
-  optimized_graph->Swap(&optimized_item.graph);
+  auto original_producer = item.graph.versions().producer();
+  optimized_graph->Swap(&item.graph);
 
   GraphOptimizationResult optimization_result(item.id);
   GraphOptimizer* sa_optimizer = nullptr;
@@ -465,7 +465,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
       continue;
     }
 
-    TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &optimized_item,
+    TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                     optimized_graph, &optimization_result));
 
     if (iteration == 0 && optimizer->name() == "model_pruner") {
@@ -498,7 +498,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
 
   // ScopedAllocatorOptimizer must run last.
   if (sa_optimizer != nullptr) {
-    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &optimized_item,
+    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                     optimized_graph, &optimization_result));
     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
   }
@@ -516,8 +516,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
     TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
     ReassignColocation(optimized_graph);
     // Make sure that the optimizers preserved the graph version.
-    DCHECK_EQ(optimized_graph->versions().producer(),
-              item.graph.versions().producer());
+    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
   }
 
   return Status::OK();
@@ -590,8 +589,8 @@ Status MetaOptimizer::RunOptimizer(
   return Status::OK();
 }
 
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
-                               GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
+                                          GraphDef* optimized_graph) {
   VLOG(1) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
@@ -609,21 +608,21 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   // remove all the unreachable functions.
   // TODO(ezhulenev): Construct reachable function library definition directly
   // from the proto without constructing temporary FunctionLibraryDefinition.
-  GraphDef trimmed_graph;  // do not copy graph with a potentially huge library
-  *trimmed_graph.mutable_node() = item.graph.node();
-  *trimmed_graph.mutable_versions() = item.graph.versions();
-  *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();
-
-  GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
+  *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
 
   VLOG(1) << absl::Substitute(
       "Deleted $0 unreachable functions from the graph (library size = $1)",
       item.graph.library().function_size() -
-          trimmed_item.graph.library().function_size(),
-      trimmed_item.graph.library().function_size());
+          item.graph.library().function_size(),
+      item.graph.library().function_size());
 
+  // Save a few small fields from item before we move it.
+  bool optimize_function_library =
+      item.optimization_options().optimize_function_library;
+  const auto producer = item.graph.versions().producer();
+
   // 1. Optimize main graph
-  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, trimmed_item, optimized_graph));
+  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
   VLOG(1) << "Optimized main graph.";
   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
 
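Because `OptimizeGraph` now consumes the item, anything still needed afterwards (here the graph version producer and the `optimize_function_library` flag) has to be copied out before the `std::move`. The following is a small, generic illustration of that save-before-move discipline with hypothetical types, not this file's code.

#include <string>
#include <utility>
#include <vector>

struct Versions { int producer = 0; };
struct Graph { std::vector<std::string> nodes; Versions versions; };
struct Item { std::string id; Graph graph; bool optimize_functions = true; };

// Hypothetical consumer: leaves `item` in a moved-from state.
void Consume(Item&& item, Graph* out) { *out = std::move(item.graph); }

int Run(Item&& item, Graph* out) {
  // Capture the small fields we still need *before* the item is moved away;
  // touching item.graph after Consume(std::move(item), ...) would be a bug.
  const bool optimize_functions = item.optimize_functions;
  const int producer = item.graph.versions.producer;

  Consume(std::move(item), out);

  // Safe: these locals were saved before the move.
  return optimize_functions ? producer : -1;
}

int main() {
  Item item{"demo", Graph{{"a", "b"}, Versions{27}}, true};
  Graph out;
  return Run(std::move(item), &out) == 27 ? 0 : 1;
}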
@@ -675,9 +674,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   // Optimize each function only once.
   absl::flat_hash_set<string> optimized_funcs;
-  bool optimize_function_library =
-      item.optimization_options().optimize_function_library;
-
   while (optimize_function_library) {
     optimize_function_library = false;
 
@@ -711,8 +707,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
     // Make a GrapplerItem from a FunctionDef.
     GrapplerFunctionItem func_item;
-    TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
-        func, flib, trimmed_item.graph.versions().producer(), &func_item));
+    TF_RETURN_IF_ERROR(
+        MakeGrapplerFunctionItem(func, flib, producer, &func_item));
 
     // If we need to compute the gradient of optimized function at runtime, we
     // can't perform non-differentiable rewrites.
@@ -760,8 +756,9 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         TF_RETURN_IF_ERROR(implementation_selector.Optimize(
             cluster, func_item, &optimized_func_graph));
       } else {
-        TF_RETURN_IF_ERROR(
-            OptimizeGraph(cluster, func_item, &optimized_func_graph));
+        GrapplerFunctionItem func_item_copy = func_item;
+        TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
+                                         &optimized_func_graph));
       }
 
       // Function body optimization might have created new specialized
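Where the item must survive the call, as with `func_item` above, the hunk makes the copy explicit at the call site and moves the copy. The cost is unchanged, but it is now visible where it is paid rather than hidden inside the callee. A hedged, hypothetical sketch of that idiom:

#include <string>
#include <utility>
#include <vector>

struct FuncItem { std::string name; std::vector<std::string> body; };

// Hypothetical consuming optimizer.
void OptimizeConsume(FuncItem&& item, std::vector<std::string>* out) {
  *out = std::move(item.body);
}

int main() {
  FuncItem func_item{"my_func", {"x", "relu", "y"}};
  std::vector<std::string> optimized_body;

  // func_item is still needed afterwards, so copy explicitly and move the copy.
  FuncItem func_item_copy = func_item;
  OptimizeConsume(std::move(func_item_copy), &optimized_body);

  // Safe: func_item itself was never moved from.
  return (func_item.body.size() == 3 && optimized_body.size() == 3) ? 0 : 1;
}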
@@ -834,13 +831,14 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
          !rewrite_cfg.custom_optimizers().empty();
 }
 
-Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
+Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                         DeviceBase* cpu_device, Cluster* cluster,
                         GraphDef* optimized_graph) {
   MetaOptimizer optimizer(cpu_device, cfg);
   optimizer.set_deadline_usec(
       DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
-  return optimizer.Optimize(cluster, item, optimized_graph);
+  return optimizer.OptimizeConsumeItem(cluster, std::move(item),
+                                       optimized_graph);
 }
 
 Status OptimizeGraph(
@@ -883,7 +881,7 @@ Status OptimizeGraph(
   // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
   // proto (which also contain the OptimizerOptions).
   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
-      item, config_proto, cpu_device, &cluster, &out_graph));
+      std::move(item), config_proto, cpu_device, &cluster, &out_graph));
 
   std::unique_ptr<tensorflow::Graph> optimized_graph(
       new tensorflow::Graph(OpRegistry::Global()));
@@ -42,7 +42,13 @@ class MetaOptimizer : public GraphOptimizer {
   bool UsesFunctionLibrary() const override { return true; }
 
   Status Optimize(Cluster* cluster, const GrapplerItem& item,
-                  GraphDef* optimized_graph) override;
+                  GraphDef* optimized_graph) override {
+    GrapplerItem copy(item);
+    return OptimizeConsumeItem(cluster, std::move(copy), optimized_graph);
+  }
+
+  Status OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
+                             GraphDef* optimized_graph);
 
   void PrintResult();
 
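Note how the header keeps the old copy semantics available: the overriding `Optimize(const GrapplerItem&, ...)` makes one explicit copy and delegates to the new consuming entry point, so the copy happens exactly once, at the API boundary, instead of being repeated deeper in the call stack. A hedged sketch of that wrapper idiom with placeholder types:

#include <string>
#include <utility>

struct Item { std::string payload; };  // stand-in for GrapplerItem
struct Result { std::string payload; };

class Engine {
 public:
  // Core implementation owns and may gut the item it is given.
  bool OptimizeConsume(Item&& item, Result* out) {
    out->payload = std::move(item.payload);
    return true;
  }

  // Compatibility shim for callers that cannot give their item up:
  // the single copy is made here, visibly, and nowhere else.
  bool Optimize(const Item& item, Result* out) {
    Item copy(item);
    return OptimizeConsume(std::move(copy), out);
  }
};

int main() {
  Engine engine;
  Item keep_me{"graph"};
  Result r1, r2;
  engine.Optimize(keep_me, &r1);                    // caller keeps keep_me
  engine.OptimizeConsume(std::move(keep_me), &r2);  // caller is done with it
  return (r1.payload == "graph" && r2.payload == "graph") ? 0 : 1;
}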
@@ -77,7 +83,7 @@ class MetaOptimizer : public GraphOptimizer {
 
   // Run optimization pass over a single GrapplerItem. Meta optimizer might run
   // multiple such passes: 1) for the main graph 2) for the function library
-  Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+  Status OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                        GraphDef* optimized_graph);
 
   DeviceBase* const cpu_device_;  // may be NULL
@@ -111,7 +117,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg);
 // during constant folding; if NULL, a new device is created for doing constant
 // folding. For performance, it is recommended to pass in an existing cpu_device
 // when possible.
-Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
+Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                         DeviceBase* cpu_device, Cluster* cluster,
                         GraphDef* optimized_graph);
 
@@ -722,12 +722,13 @@ TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
 
   GraphDef output;
+  GraphDef original = item.graph;
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
   // Make sure the graph was reverted to the original regardless of when the
   // optimizer timed out.
-  CompareGraphs(item.graph, output);
+  CompareGraphs(original, output);
 }
 
 TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
@@ -744,11 +745,12 @@ TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
 
   GraphDef output;
+  const int original_node_size = item.graph.node_size();
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
   // The meta optimizer should manage to finish one iteration.
-  EXPECT_EQ(item.graph.node_size() + 1, output.node_size());
+  EXPECT_EQ(original_node_size + 1, output.node_size());
 }
 
 TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
@@ -764,11 +766,12 @@ TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
   rewriter_config.set_meta_optimizer_timeout_ms(2500);
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
   GraphDef output;
+  const int original_node_size = item.graph.node_size();
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   TF_EXPECT_OK(status);
   // The meta optimizer should manage to finish two iterations.
-  EXPECT_EQ(item.graph.node_size() + 2, output.node_size());
+  EXPECT_EQ(original_node_size + 2, output.node_size());
 }
 
 TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnValidGraph) {
@@ -126,7 +126,7 @@ Status ApplyRewrites(OpKernelContext* ctx,
   tensorflow::ConfigProto config;
   *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
-      *grappler_item, config, ctx->device(), &cluster, graph_def));
+      std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));
 
   // Remove fake sinks after optimizations are done.
   //