Added graph structure output to summarize_graph

Change: 154606362
This commit is contained in:
Pete Warden 2017-04-28 17:43:15 -08:00 committed by TensorFlower Gardener
parent 736a2eb3de
commit ad3c84b58b

View File

@ -102,7 +102,18 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*> placeholders,
std::cout << std::endl; std::cout << std::endl;
} }
Status SummarizeGraph(const GraphDef& graph, const string& graph_path) { Status PrintStructure(const GraphDef& graph) {
GraphDef sorted_graph;
TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
for (const NodeDef& node : sorted_graph.node()) {
std::cout << node.name() << " (" << node.op() << "): ["
<< str_util::Join(node.input(), ", ") << "]" << std::endl;
}
return Status::OK();
}
Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
bool print_structure) {
std::vector<const NodeDef*> placeholders; std::vector<const NodeDef*> placeholders;
std::vector<const NodeDef*> variables; std::vector<const NodeDef*> variables;
for (const NodeDef& node : graph.node()) { for (const NodeDef& node : graph.node()) {
@ -233,13 +244,20 @@ Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
PrintBenchmarkUsage(placeholders, variables, outputs, graph_path); PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
if (print_structure) {
TF_RETURN_IF_ERROR(PrintStructure(graph));
}
return Status::OK(); return Status::OK();
} }
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) { int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
string in_graph = ""; string in_graph = "";
bool print_structure = false;
std::vector<Flag> flag_list = { std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"), Flag("in_graph", &in_graph, "input graph file name"),
Flag("print_structure", &print_structure,
"whether to print the network connections of the graph"),
}; };
string usage = Flags::Usage(argv[0], flag_list); string usage = Flags::Usage(argv[0], flag_list);
@ -269,7 +287,8 @@ int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
return -1; return -1;
} }
Status summarize_result = SummarizeGraph(graph_def, in_graph); Status summarize_result =
SummarizeGraph(graph_def, in_graph, print_structure);
if (!summarize_result.ok()) { if (!summarize_result.ok()) {
LOG(ERROR) << summarize_result.error_message() << "\n" << usage; LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
return -1; return -1;