Adding support for text output.
Change: 131799027
This commit is contained in:
parent
64e6b7fef4
commit
70401bdb5a
@ -53,6 +53,8 @@ tf.app.flags.DEFINE_boolean("input_binary", False,
|
||||
"""Whether the input files are in binary format.""")
|
||||
tf.app.flags.DEFINE_string("output_graph", "",
|
||||
"""Output 'GraphDef' file name.""")
|
||||
tf.app.flags.DEFINE_boolean("output_binary", True,
|
||||
"""Whether to write a binary format graph.""")
|
||||
tf.app.flags.DEFINE_string("input_node_names", "",
|
||||
"""The name of the input nodes, comma separated.""")
|
||||
tf.app.flags.DEFINE_string("output_node_names", "",
|
||||
@ -66,6 +68,7 @@ def main(unused_args):
|
||||
strip_unused_lib.strip_unused_from_files(FLAGS.input_graph,
|
||||
FLAGS.input_binary,
|
||||
FLAGS.output_graph,
|
||||
FLAGS.output_binary,
|
||||
FLAGS.input_node_names,
|
||||
FLAGS.output_node_names,
|
||||
FLAGS.placeholder_type_enum)
|
||||
|
@ -61,7 +61,7 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
|
||||
|
||||
|
||||
def strip_unused_from_files(input_graph, input_binary, output_graph,
|
||||
input_node_names, output_node_names,
|
||||
output_binary, input_node_names, output_node_names,
|
||||
placeholder_type_enum):
|
||||
"""Removes unused nodes from a graph file."""
|
||||
|
||||
@ -85,6 +85,10 @@ def strip_unused_from_files(input_graph, input_binary, output_graph,
|
||||
output_node_names.split(","),
|
||||
placeholder_type_enum)
|
||||
|
||||
with tf.gfile.GFile(output_graph, "wb") as f:
|
||||
f.write(output_graph_def.SerializeToString())
|
||||
if output_binary:
|
||||
with tf.gfile.GFile(output_graph, "wb") as f:
|
||||
f.write(output_graph_def.SerializeToString())
|
||||
else:
|
||||
with tf.gfile.GFile(output_graph, "w") as f:
|
||||
f.write(text_format.MessageToString(output_graph_def))
|
||||
print("%d ops in the final graph." % len(output_graph_def.node))
|
||||
|
@ -49,11 +49,12 @@ class StripUnusedTest(test_util.TensorFlowTestCase):
|
||||
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
|
||||
input_binary = False
|
||||
input_node_names = "wanted_input_node"
|
||||
output_binary = True
|
||||
output_node_names = "output_node"
|
||||
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
|
||||
|
||||
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
|
||||
output_graph_path,
|
||||
output_graph_path, output_binary,
|
||||
input_node_names,
|
||||
output_node_names,
|
||||
tf.float32.as_datatype_enum)
|
||||
|
Loading…
Reference in New Issue
Block a user