Replace the references to PredictCosts() with PredictCostsAndReturnRunMetadata().

PiperOrigin-RevId: 228368607
This commit is contained in:
Doe Hyun Yoon 2019-01-08 11:37:17 -08:00 committed by TensorFlower Gardener
parent 1bdf37e722
commit 35279e5c61
2 changed files with 9 additions and 4 deletions

View File

@ -98,9 +98,10 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
AnalyticalCostEstimator estimator(cluster_.get(), true);
TF_ASSERT_OK(estimator.Initialize(item));
CostGraphDef cost_graph;
RunMetadata run_metadata;
Costs summary;
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary));
TF_ASSERT_OK(estimator.PredictCostsAndReturnRunMetadata(
item.graph, &run_metadata, &summary));
EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
// Note there are totally 17 nodes (RandomUniform creates 2 nodes), but

View File

@ -42,9 +42,13 @@ Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
CostGraphDef* cost_graph, int64* total_time) {
TF_CHECK_OK(cost_estimator->Initialize(*item_));
RunMetadata run_metadata;
Costs costs;
const Status status =
cost_estimator->PredictCosts(item_->graph, cost_graph, &costs);
const Status status = cost_estimator->PredictCostsAndReturnRunMetadata(
item_->graph, &run_metadata, &costs);
if (cost_graph) {
cost_graph->Swap(run_metadata.mutable_cost_graph());
}
*total_time = costs.execution_time.count();
if (!status.ok()) {
LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "