diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 721cf535b9c..5c0515803d3 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -218,9 +218,13 @@ Status GraphProperties::InferDynamically(Cluster* cluster) { TF_RETURN_IF_ERROR( cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata)); + return InferFromCostGraph(metadata.cost_graph()); +} + +Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { std::unordered_map name_to_cost; std::unordered_map name_to_node; // Empty - for (auto& node : metadata.cost_graph().node()) { + for (auto& node : cost_graph.node()) { name_to_cost[node.name()] = &node; std::vector output_properties; diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index d2b466175be..f2f6ad19a7e 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -36,6 +36,7 @@ class GraphProperties { Status InferStatically(); Status InferDynamically(Cluster* cluster); + Status InferFromCostGraph(const CostGraphDef& cost_graph); bool HasOutputProperties(const string& name) const; std::vector GetInputProperties(