Added a new method to extract the graph properties from a cost graph without

having to run the model. This will simplify the process of creating regression
tests

PiperOrigin-RevId: 158050327
This commit is contained in:
Benoit Steiner 2017-06-05 13:00:48 -07:00 committed by TensorFlower Gardener
parent 27f1b80c22
commit 2ccfe8e764
2 changed files with 6 additions and 1 deletions

View File

@ -218,9 +218,13 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
TF_RETURN_IF_ERROR(
cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
return InferFromCostGraph(metadata.cost_graph());
}
Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
std::unordered_map<string, const NodeDef*> name_to_node; // Empty
for (auto& node : metadata.cost_graph().node()) {
for (auto& node : cost_graph.node()) {
name_to_cost[node.name()] = &node;
std::vector<OpInfo::TensorProperties> output_properties;

View File

@ -36,6 +36,7 @@ class GraphProperties {
Status InferStatically();
Status InferDynamically(Cluster* cluster);
Status InferFromCostGraph(const CostGraphDef& cost_graph);
bool HasOutputProperties(const string& name) const;
std::vector<OpInfo::TensorProperties> GetInputProperties(