Allow list of initializer nodes to run before freezing the graph.
Change: 115703618
This commit is contained in:
parent
f417134a5a
commit
25a5fbcf64
@ -63,11 +63,13 @@ tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
|
|||||||
"""The name of the tensor holding the save path.""")
|
"""The name of the tensor holding the save path.""")
|
||||||
tf.app.flags.DEFINE_boolean("clear_devices", True,
|
tf.app.flags.DEFINE_boolean("clear_devices", True,
|
||||||
"""Whether to remove device specifications.""")
|
"""Whether to remove device specifications.""")
|
||||||
|
tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
|
||||||
|
"initializer nodes to run before freezing.")
|
||||||
|
|
||||||
|
|
||||||
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
||||||
output_node_names, restore_op_name, filename_tensor_name,
|
output_node_names, restore_op_name, filename_tensor_name,
|
||||||
output_graph, clear_devices):
|
output_graph, clear_devices, initializer_nodes):
|
||||||
"""Converts all variables in a graph and checkpoint into constants."""
|
"""Converts all variables in a graph and checkpoint into constants."""
|
||||||
|
|
||||||
if not tf.gfile.Exists(input_graph):
|
if not tf.gfile.Exists(input_graph):
|
||||||
@ -112,10 +114,12 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
|||||||
saver.restore(sess, input_checkpoint)
|
saver.restore(sess, input_checkpoint)
|
||||||
else:
|
else:
|
||||||
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
|
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
|
||||||
|
if initializer_nodes:
|
||||||
|
sess.run(initializer_nodes)
|
||||||
output_graph_def = graph_util.convert_variables_to_constants(
|
output_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess, input_graph_def, output_node_names.split(","))
|
sess, input_graph_def, output_node_names.split(","))
|
||||||
|
|
||||||
with tf.gfile.FastGFile(output_graph, "wb") as f:
|
with tf.gfile.GFile(output_graph, "wb") as f:
|
||||||
f.write(output_graph_def.SerializeToString())
|
f.write(output_graph_def.SerializeToString())
|
||||||
print("%d ops in the final graph." % len(output_graph_def.node))
|
print("%d ops in the final graph." % len(output_graph_def.node))
|
||||||
|
|
||||||
@ -124,7 +128,7 @@ def main(unused_args):
|
|||||||
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
|
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
|
||||||
FLAGS.input_checkpoint, FLAGS.output_node_names,
|
FLAGS.input_checkpoint, FLAGS.output_node_names,
|
||||||
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
|
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
|
||||||
FLAGS.output_graph, FLAGS.clear_devices)
|
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.app.run()
|
tf.app.run()
|
||||||
|
@ -66,7 +66,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
input_binary, input_checkpoint_path,
|
input_binary, input_checkpoint_path,
|
||||||
output_node_names, restore_op_name,
|
output_node_names, restore_op_name,
|
||||||
filename_tensor_name, output_graph_path,
|
filename_tensor_name, output_graph_path,
|
||||||
clear_devices)
|
clear_devices, "")
|
||||||
|
|
||||||
# Now we make sure the variable is now a constant, and that the graph still
|
# Now we make sure the variable is now a constant, and that the graph still
|
||||||
# produces the expected result.
|
# produces the expected result.
|
||||||
|
Loading…
Reference in New Issue
Block a user