Allow list of initializer nodes to run before freezing the graph.
Change: 115703618
This commit is contained in:
parent
f417134a5a
commit
25a5fbcf64
@ -63,11 +63,13 @@ tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0",
|
||||
"""The name of the tensor holding the save path.""")
|
||||
tf.app.flags.DEFINE_boolean("clear_devices", True,
|
||||
"""Whether to remove device specifications.""")
|
||||
tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
|
||||
"initializer nodes to run before freezing.")
|
||||
|
||||
|
||||
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
||||
output_node_names, restore_op_name, filename_tensor_name,
|
||||
output_graph, clear_devices):
|
||||
output_graph, clear_devices, initializer_nodes):
|
||||
"""Converts all variables in a graph and checkpoint into constants."""
|
||||
|
||||
if not tf.gfile.Exists(input_graph):
|
||||
@ -112,10 +114,12 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
||||
saver.restore(sess, input_checkpoint)
|
||||
else:
|
||||
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
|
||||
if initializer_nodes:
|
||||
sess.run(initializer_nodes)
|
||||
output_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, input_graph_def, output_node_names.split(","))
|
||||
|
||||
with tf.gfile.FastGFile(output_graph, "wb") as f:
|
||||
with tf.gfile.GFile(output_graph, "wb") as f:
|
||||
f.write(output_graph_def.SerializeToString())
|
||||
print("%d ops in the final graph." % len(output_graph_def.node))
|
||||
|
||||
@ -124,7 +128,7 @@ def main(unused_args):
|
||||
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
|
||||
FLAGS.input_checkpoint, FLAGS.output_node_names,
|
||||
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
|
||||
FLAGS.output_graph, FLAGS.clear_devices)
|
||||
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
|
@ -66,7 +66,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
input_binary, input_checkpoint_path,
|
||||
output_node_names, restore_op_name,
|
||||
filename_tensor_name, output_graph_path,
|
||||
clear_devices)
|
||||
clear_devices, "")
|
||||
|
||||
# Now we make sure the variable is now a constant, and that the graph still
|
||||
# produces the expected result.
|
||||
|
Loading…
Reference in New Issue
Block a user