Allow list of initializer nodes to run before freezing the graph.

Change: 115703618
This commit is contained in:
A. Unique TensorFlower 2016-02-26 13:34:56 -08:00 committed by TensorFlower Gardener
parent f417134a5a
commit 25a5fbcf64
2 changed files with 8 additions and 4 deletions

View File

@ -63,11 +63,13 @@ tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
"""The name of the tensor holding the save path.""")
tf.app.flags.DEFINE_boolean("clear_devices", True,
"""Whether to remove device specifications.""")
tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
"initializer nodes to run before freezing.")
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
output_node_names, restore_op_name, filename_tensor_name,
output_graph, clear_devices):
output_graph, clear_devices, initializer_nodes):
"""Converts all variables in a graph and checkpoint into constants."""
if not tf.gfile.Exists(input_graph):
@ -112,10 +114,12 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
saver.restore(sess, input_checkpoint)
else:
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
if initializer_nodes:
sess.run(initializer_nodes)
output_graph_def = graph_util.convert_variables_to_constants(
sess, input_graph_def, output_node_names.split(","))
with tf.gfile.FastGFile(output_graph, "wb") as f:
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
@ -124,7 +128,7 @@ def main(unused_args):
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
FLAGS.input_checkpoint, FLAGS.output_node_names,
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
FLAGS.output_graph, FLAGS.clear_devices)
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
if __name__ == "__main__":
tf.app.run()

View File

@ -66,7 +66,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_binary, input_checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices)
clear_devices, "")
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.