diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index e76472291f9..3ef6c2ae954 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -85,9 +85,8 @@ Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) { } TF_RETURN_IF_ERROR(estimator_->Initialize(item)); - Costs ignored_costs; TF_RETURN_IF_ERROR( - estimator_->PredictCosts(item.graph, metadata, &ignored_costs)); + estimator_->PredictCosts(item.graph, metadata, /*cost=*/nullptr)); const std::unordered_map& device = GetDevices(); std::unordered_map peak_mem_usage = diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index a7e81847cac..a85e293ac00 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -149,12 +149,24 @@ Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, RunMetadata* run_metadata, Costs* costs) const { - GraphDef graph_copy = optimized_graph; - GrapplerItem item = item_->WithGraph(std::move(graph_copy)); + std::unique_ptr item_storage; + const GrapplerItem* item; + // Many callers to PredictCosts() pass the same optimized_graph as was used + // to initialize the estimator. + if (&optimized_graph == &item_->graph) { + item = item_; + } else { + GraphDef graph_copy = optimized_graph; + item_storage = absl::make_unique( + item_->WithGraph(std::move(graph_copy))); + item = item_storage.get(); + } - auto status = scheduler_->Init(&item); + auto status = scheduler_->Init(item); if (!status.ok()) { - costs->execution_time = Costs::Duration::max(); + if (costs) { + costs->execution_time = Costs::Duration::max(); + } return status; } @@ -203,7 +215,11 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, } // run_metadata gets step_stats and partition_graphs from Summary. - *costs = scheduler_->Summary(run_metadata); + if (costs) { + *costs = scheduler_->Summary(run_metadata); + } else if (run_metadata) { + scheduler_->GenerateRunMetadata(run_metadata); + } if (VLOG_IS_ON(1)) { bool verbose = VLOG_IS_ON(2);