From 35279e5c61d3ff5882c31d3d291256321ba27267 Mon Sep 17 00:00:00 2001 From: Doe Hyun Yoon Date: Tue, 8 Jan 2019 11:37:17 -0800 Subject: [PATCH] Replace the references to PredictCosts() with PredictCostsAndReturnRunMetadata(). PiperOrigin-RevId: 228368607 --- .../core/grappler/costs/analytical_cost_estimator_test.cc | 5 +++-- tensorflow/python/grappler/cost_analyzer.cc | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index a9a1abfa989..8a563443782 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -98,9 +98,10 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { AnalyticalCostEstimator estimator(cluster_.get(), true); TF_ASSERT_OK(estimator.Initialize(item)); - CostGraphDef cost_graph; + RunMetadata run_metadata; Costs summary; - TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary)); + TF_ASSERT_OK(estimator.PredictCostsAndReturnRunMetadata( + item.graph, &run_metadata, &summary)); EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time); // Note there are totally 17 nodes (RandomUniform creates 2 nodes), but diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index b474e198949..bb8c6d5b855 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -42,9 +42,13 @@ Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report, void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph, int64* total_time) { TF_CHECK_OK(cost_estimator->Initialize(*item_)); + RunMetadata run_metadata; Costs costs; - const Status status = - cost_estimator->PredictCosts(item_->graph, cost_graph, &costs); + const Status status = cost_estimator->PredictCostsAndReturnRunMetadata( + item_->graph, &run_metadata, &costs); + if (cost_graph) { + cost_graph->Swap(run_metadata.mutable_cost_graph()); + } *total_time = costs.execution_time.count(); if (!status.ok()) { LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "