Added graph structure output to summarize_graph
Change: 154606362
This commit is contained in:
parent
736a2eb3de
commit
ad3c84b58b
@ -102,7 +102,18 @@ void PrintBenchmarkUsage(const std::vector<const NodeDef*> placeholders,
|
|||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
|
Status PrintStructure(const GraphDef& graph) {
|
||||||
|
GraphDef sorted_graph;
|
||||||
|
TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
|
||||||
|
for (const NodeDef& node : sorted_graph.node()) {
|
||||||
|
std::cout << node.name() << " (" << node.op() << "): ["
|
||||||
|
<< str_util::Join(node.input(), ", ") << "]" << std::endl;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
|
||||||
|
bool print_structure) {
|
||||||
std::vector<const NodeDef*> placeholders;
|
std::vector<const NodeDef*> placeholders;
|
||||||
std::vector<const NodeDef*> variables;
|
std::vector<const NodeDef*> variables;
|
||||||
for (const NodeDef& node : graph.node()) {
|
for (const NodeDef& node : graph.node()) {
|
||||||
@ -233,13 +244,20 @@ Status SummarizeGraph(const GraphDef& graph, const string& graph_path) {
|
|||||||
|
|
||||||
PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
|
PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
|
||||||
|
|
||||||
|
if (print_structure) {
|
||||||
|
TF_RETURN_IF_ERROR(PrintStructure(graph));
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
||||||
string in_graph = "";
|
string in_graph = "";
|
||||||
|
bool print_structure = false;
|
||||||
std::vector<Flag> flag_list = {
|
std::vector<Flag> flag_list = {
|
||||||
Flag("in_graph", &in_graph, "input graph file name"),
|
Flag("in_graph", &in_graph, "input graph file name"),
|
||||||
|
Flag("print_structure", &print_structure,
|
||||||
|
"whether to print the network connections of the graph"),
|
||||||
};
|
};
|
||||||
string usage = Flags::Usage(argv[0], flag_list);
|
string usage = Flags::Usage(argv[0], flag_list);
|
||||||
|
|
||||||
@ -269,7 +287,8 @@ int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status summarize_result = SummarizeGraph(graph_def, in_graph);
|
Status summarize_result =
|
||||||
|
SummarizeGraph(graph_def, in_graph, print_structure);
|
||||||
if (!summarize_result.ok()) {
|
if (!summarize_result.ok()) {
|
||||||
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
|
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
|
||||||
return -1;
|
return -1;
|
||||||
|
Loading…
Reference in New Issue
Block a user