diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 4c56eccdad5..42247c664ec 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -732,8 +732,9 @@ Status GraphExecutionState::OptimizeGraph(
   }
   grappler::VirtualCluster cluster(device_set_);
   GraphDef new_graph;
-  TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
-      item, session_options_->config, cpu_device, &cluster, &new_graph));
+  TF_RETURN_IF_ERROR(
+      grappler::RunMetaOptimizer(std::move(item), session_options_->config,
+                                 cpu_device, &cluster, &new_graph));
 
   // Merge optimized graph function library with an original library.
   // Optimized graph might have new functions specialized for it's
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index de678d0a390..bff200cbe18 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -58,6 +58,12 @@ class GraphOptimizer {
   virtual Status Optimize(Cluster* cluster, const GrapplerItem& item,
                           GraphDef* optimized_graph) = 0;
 
+  // Subclasses may define a version of Optimize that consumes item.
+  virtual Status Optimize(Cluster* cluster, GrapplerItem&& item,
+                          GraphDef* optimized_graph) {
+    return Optimize(cluster, item, optimized_graph);
+  }
+
   // Method invoked by the framework so that it can provide feedback
   // on how well the "optimized_graph" (produced as *optimized_graph from a
   // call to Optimize) performed. Lower "result" scores are better.
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 82758d1f970..0c8fa0449f5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -379,7 +379,7 @@ void MetaOptimizer::InitializeVerifiers(
   }
 }
 
-Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                     GraphDef* optimized_graph) {
   int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                     : cfg_.min_graph_nodes();
@@ -426,8 +426,8 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
 
   // Invariant: optimized_graph contains the most recently optimized version of
   // the graph.
-  GrapplerItem optimized_item = item;
-  optimized_graph->Swap(&optimized_item.graph);
+  auto original_producer = item.graph.versions().producer();
+  optimized_graph->Swap(&item.graph);
   GraphOptimizationResult optimization_result(item.id);
   GraphOptimizer* sa_optimizer = nullptr;
 
@@ -465,7 +465,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
       continue;
     }
 
-    TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &optimized_item,
+    TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                     optimized_graph, &optimization_result));
 
     if (iteration == 0 && optimizer->name() == "model_pruner") {
@@ -498,7 +498,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
 
   // ScopedAllocatorOptimizer must run last.
   if (sa_optimizer != nullptr) {
-    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &optimized_item,
+    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                     optimized_graph, &optimization_result));
     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
   }
@@ -516,8 +516,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
     TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
     ReassignColocation(optimized_graph);
     // Make sure that the optimizers preserved the graph version.
-    DCHECK_EQ(optimized_graph->versions().producer(),
-              item.graph.versions().producer());
+    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
   }
 
   return Status::OK();
@@ -590,8 +589,8 @@ Status MetaOptimizer::RunOptimizer(
   return Status::OK();
 }
 
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
-                               GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
+                                          GraphDef* optimized_graph) {
   VLOG(1) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
@@ -609,21 +608,21 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   // remove all the unreachable functions.
   // TODO(ezhulenev): Construct reachable function library definition directly
   // from the proto without constructing temporary FunctionLibraryDefinition.
-  GraphDef trimmed_graph;  // do not copy graph with a potentially huge library
-  *trimmed_graph.mutable_node() = item.graph.node();
-  *trimmed_graph.mutable_versions() = item.graph.versions();
-  *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();
-
-  GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
+  *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
 
   VLOG(1) << absl::Substitute(
       "Deleted $0 unreachable functions from the graph (library size = $1)",
       item.graph.library().function_size() -
-          trimmed_item.graph.library().function_size(),
-      trimmed_item.graph.library().function_size());
+          item.graph.library().function_size(),
+      item.graph.library().function_size());
+
+  // Save a few small fields from item before we move it.
+  bool optimize_function_library =
+      item.optimization_options().optimize_function_library;
+  const auto producer = item.graph.versions().producer();
 
   // 1. Optimize main graph
-  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, trimmed_item, optimized_graph));
+  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
   VLOG(1) << "Optimized main graph.";
   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
 
@@ -675,9 +674,6 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   // Optimize each function only once.
   absl::flat_hash_set<string> optimized_funcs;
 
-  bool optimize_function_library =
-      item.optimization_options().optimize_function_library;
-
   while (optimize_function_library) {
     optimize_function_library = false;
 
@@ -711,8 +707,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
       // Make a GrapplerItem from a FunctionDef.
       GrapplerFunctionItem func_item;
-      TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
-          func, flib, trimmed_item.graph.versions().producer(), &func_item));
+      TF_RETURN_IF_ERROR(
+          MakeGrapplerFunctionItem(func, flib, producer, &func_item));
 
       // If we need to compute the gradient of optimized function at runtime, we
       // can't perform non-differentiable rewrites.
@@ -760,8 +756,9 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         TF_RETURN_IF_ERROR(implementation_selector.Optimize(
             cluster, func_item, &optimized_func_graph));
       } else {
-        TF_RETURN_IF_ERROR(
-            OptimizeGraph(cluster, func_item, &optimized_func_graph));
+        GrapplerFunctionItem func_item_copy = func_item;
+        TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
+                                         &optimized_func_graph));
       }
 
       // Function body optimization might have created new specialized
@@ -834,13 +831,14 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
          !rewrite_cfg.custom_optimizers().empty();
 }
 
-Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
+Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                         DeviceBase* cpu_device, Cluster* cluster,
                         GraphDef* optimized_graph) {
   MetaOptimizer optimizer(cpu_device, cfg);
   optimizer.set_deadline_usec(
       DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
-  return optimizer.Optimize(cluster, item, optimized_graph);
+  return optimizer.OptimizeConsumeItem(cluster, std::move(item),
+                                       optimized_graph);
 }
 
 Status OptimizeGraph(
@@ -883,7 +881,7 @@ Status OptimizeGraph(
   // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
   // proto (which also contain the OptimizerOptions).
   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
-      item, config_proto, cpu_device, &cluster, &out_graph));
+      std::move(item), config_proto, cpu_device, &cluster, &out_graph));
 
   std::unique_ptr<tensorflow::Graph> optimized_graph(
       new tensorflow::Graph(OpRegistry::Global()));
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 18392f667b4..f39f0b62bb6 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -42,7 +42,13 @@ class MetaOptimizer : public GraphOptimizer {
   bool UsesFunctionLibrary() const override { return true; }
 
   Status Optimize(Cluster* cluster, const GrapplerItem& item,
-                  GraphDef* optimized_graph) override;
+                  GraphDef* optimized_graph) override {
+    GrapplerItem copy(item);
+    return OptimizeConsumeItem(cluster, std::move(copy), optimized_graph);
+  }
+
+  Status OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
+                             GraphDef* optimized_graph);
 
   void PrintResult();
 
@@ -77,7 +83,7 @@ class MetaOptimizer : public GraphOptimizer {
 
   // Run optimization pass over a single GrapplerItem. Meta optimizer might run
   // multiple such passes: 1) for the main graph 2) for the function library
-  Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
+  Status OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                        GraphDef* optimized_graph);
 
   DeviceBase* const cpu_device_;  // may be NULL
@@ -111,7 +117,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg);
 // during constant folding; if NULL, a new device is created for doing constant
 // folding. For performance, it is recommended to pass in an existing cpu_device
 // when possible.
-Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
+Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                         DeviceBase* cpu_device, Cluster* cluster,
                         GraphDef* optimized_graph);
 
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 0b40363ac7d..595b636c7a9 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -722,12 +722,13 @@ TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
 
   GraphDef output;
+  GraphDef original = item.graph;
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
   // Make sure the graph was reverted to the original regardless of when the
   // optimizer timed out.
-  CompareGraphs(item.graph, output);
+  CompareGraphs(original, output);
 }
 
 TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
@@ -744,11 +745,12 @@ TEST_F(MetaOptimizerTest, MetaOptimizerTimesOut) {
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
 
   GraphDef output;
+  const int original_node_size = item.graph.node_size();
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   EXPECT_EQ(status.error_message(), "meta_optimizer exceeded deadline.");
   // The meta optimizer should manage to finish one iteration.
-  EXPECT_EQ(item.graph.node_size() + 1, output.node_size());
+  EXPECT_EQ(original_node_size + 1, output.node_size());
 }
 
 TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
@@ -764,11 +766,12 @@ TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
   rewriter_config.set_meta_optimizer_timeout_ms(2500);
   rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
   GraphDef output;
+  const int original_node_size = item.graph.node_size();
   const Status status =
-      RunMetaOptimizer(item, config, nullptr, nullptr, &output);
+      RunMetaOptimizer(std::move(item), config, nullptr, nullptr, &output);
   TF_EXPECT_OK(status);
   // The meta optimizer should manage to finish two iterations.
-  EXPECT_EQ(item.graph.node_size() + 2, output.node_size());
+  EXPECT_EQ(original_node_size + 2, output.node_size());
 }
 
 TEST_F(MetaOptimizerTest, RunPostOptimizationVerifiersOnValidGraph) {
diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc
index 383d25998db..3717016bba4 100644
--- a/tensorflow/core/kernels/data/rewrite_utils.cc
+++ b/tensorflow/core/kernels/data/rewrite_utils.cc
@@ -126,7 +126,7 @@ Status ApplyRewrites(OpKernelContext* ctx,
   tensorflow::ConfigProto config;
   *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
-      *grappler_item, config, ctx->device(), &cluster, graph_def));
+      std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));
 
   // Remove fake sinks after optimizations are done.
   //
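
Note: taken as a whole, the patch replaces copy-and-swap of GrapplerItem with move semantics: GraphOptimizer gains an rvalue-reference Optimize overload that falls back to the existing const-reference one, and callers hand the item over with std::move. The standalone sketch below illustrates only that overload pattern, outside of TensorFlow; ToyOptimizer, MovingOptimizer, and Item are hypothetical stand-ins (with int standing in for Status), not Grappler types.

// Minimal sketch of the "consume or fall back to copy" overload pattern.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Item {
  std::string id;
  std::vector<std::string> graph;  // stand-in for a large GraphDef
};

class ToyOptimizer {
 public:
  virtual ~ToyOptimizer() = default;

  // Existing interface: borrows the item and must leave it untouched,
  // which forces a copy of the (potentially huge) graph.
  virtual int Optimize(const Item& item, std::vector<std::string>* out) {
    *out = item.graph;
    return 0;
  }

  // New interface: consumes the item. The default forwards to the borrowing
  // overload ('item' is an lvalue inside this function), mirroring the
  // GraphOptimizer::Optimize(Cluster*, GrapplerItem&&, GraphDef*) default.
  virtual int Optimize(Item&& item, std::vector<std::string>* out) {
    return Optimize(item, out);
  }
};

// An optimizer that takes advantage of the consuming overload.
class MovingOptimizer : public ToyOptimizer {
 public:
  using ToyOptimizer::Optimize;  // keep the const& overload visible
  int Optimize(Item&& item, std::vector<std::string>* out) override {
    *out = std::move(item.graph);  // steal the graph instead of copying it
    return 0;
  }
};

int main() {
  Item item{"main", {"a", "b", "c"}};
  std::vector<std::string> optimized;

  MovingOptimizer opt;
  // Caller side mirrors RunMetaOptimizer(std::move(item), ...): after this
  // call, 'item' must be treated as moved-from.
  opt.Optimize(std::move(item), &optimized);

  std::cout << "optimized graph has " << optimized.size() << " nodes\n";
  return 0;
}

Because the base class provides the copying fallback, optimizers that never implement the consuming overload keep working unchanged, which is the same compatibility story the patch relies on for third-party GraphOptimizer subclasses.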