Adding support for text output.

Change: 131799027
This commit is contained in:
A. Unique TensorFlower 2016-08-30 23:21:27 -08:00 committed by TensorFlower Gardener
parent 64e6b7fef4
commit 70401bdb5a
3 changed files with 12 additions and 4 deletions

View File

@ -53,6 +53,8 @@ tf.app.flags.DEFINE_boolean("input_binary", False,
"""Whether the input files are in binary format.""")
tf.app.flags.DEFINE_string("output_graph", "",
"""Output 'GraphDef' file name.""")
tf.app.flags.DEFINE_boolean("output_binary", True,
"""Whether to write a binary format graph.""")
tf.app.flags.DEFINE_string("input_node_names", "",
"""The name of the input nodes, comma separated.""")
tf.app.flags.DEFINE_string("output_node_names", "",
@ -66,6 +68,7 @@ def main(unused_args):
strip_unused_lib.strip_unused_from_files(FLAGS.input_graph,
FLAGS.input_binary,
FLAGS.output_graph,
FLAGS.output_binary,
FLAGS.input_node_names,
FLAGS.output_node_names,
FLAGS.placeholder_type_enum)

View File

@ -61,7 +61,7 @@ def strip_unused(input_graph_def, input_node_names, output_node_names,
def strip_unused_from_files(input_graph, input_binary, output_graph,
input_node_names, output_node_names,
output_binary, input_node_names, output_node_names,
placeholder_type_enum):
"""Removes unused nodes from a graph file."""
@ -85,6 +85,10 @@ def strip_unused_from_files(input_graph, input_binary, output_graph,
output_node_names.split(","),
placeholder_type_enum)
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
if output_binary:
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
else:
with tf.gfile.GFile(output_graph, "w") as f:
f.write(text_format.MessageToString(output_graph_def))
print("%d ops in the final graph." % len(output_graph_def.node))

View File

@ -49,11 +49,12 @@ class StripUnusedTest(test_util.TensorFlowTestCase):
input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
input_binary = False
input_node_names = "wanted_input_node"
output_binary = True
output_node_names = "output_node"
output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
output_graph_path,
output_graph_path, output_binary,
input_node_names,
output_node_names,
tf.float32.as_datatype_enum)