Added graph structure output to summarize_graph
Change: 154606362
This commit is contained in:
parent
736a2eb3de
commit
ad3c84b58b
@ -102,7 +102,18 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*> placeholders,
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
|
||||
Status PrintStructure(const GraphDef& graph) {
|
||||
GraphDef sorted_graph;
|
||||
TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
|
||||
for (const NodeDef& node : sorted_graph.node()) {
|
||||
std::cout << node.name() << " (" << node.op() << "): ["
|
||||
<< str_util::Join(node.input(), ", ") << "]" << std::endl;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
|
||||
bool print_structure) {
|
||||
std::vector<const NodeDef*> placeholders;
|
||||
std::vector<const NodeDef*> variables;
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
@ -233,13 +244,20 @@ Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
|
||||
|
||||
PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
|
||||
|
||||
if (print_structure) {
|
||||
TF_RETURN_IF_ERROR(PrintStructure(graph));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
||||
string in_graph = "";
|
||||
bool print_structure = false;
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("in_graph", &in_graph, "input graph file name"),
|
||||
Flag("print_structure", &print_structure,
|
||||
"whether to print the network connections of the graph"),
|
||||
};
|
||||
string usage = Flags::Usage(argv[0], flag_list);
|
||||
|
||||
@ -269,7 +287,8 @@ int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
Status summarize_result = SummarizeGraph(graph_def, in_graph);
|
||||
Status summarize_result =
|
||||
SummarizeGraph(graph_def, in_graph, print_structure);
|
||||
if (!summarize_result.ok()) {
|
||||
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
|
||||
return -1;
|
||||
|
Loading…
Reference in New Issue
Block a user