From 25a5fbcf642c5b65ce1302557c3e77db377a62dd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Feb 2016 13:34:56 -0800 Subject: [PATCH] Allow list of initializer nodes to run before freezing the graph. Change: 115703618 --- tensorflow/python/tools/freeze_graph.py | 10 +++++++--- tensorflow/python/tools/freeze_graph_test.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 68fa832cb78..112bd0540e0 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -63,11 +63,13 @@ tf.app.flags.DEFINE_string("filename_tensor_name", "save/Const:0", """The name of the tensor holding the save path.""") tf.app.flags.DEFINE_boolean("clear_devices", True, """Whether to remove device specifications.""") +tf.app.flags.DEFINE_string("initializer_nodes", "", "comma separated list of " + "initializer nodes to run before freezing.") def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, - output_graph, clear_devices): + output_graph, clear_devices, initializer_nodes): """Converts all variables in a graph and checkpoint into constants.""" if not tf.gfile.Exists(input_graph): @@ -112,10 +114,12 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, saver.restore(sess, input_checkpoint) else: sess.run([restore_op_name], {filename_tensor_name: input_checkpoint}) + if initializer_nodes: + sess.run(initializer_nodes) output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, output_node_names.split(",")) - with tf.gfile.FastGFile(output_graph, "wb") as f: + with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph." % len(output_graph_def.node)) @@ -124,7 +128,7 @@ def main(unused_args): freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, FLAGS.input_checkpoint, FLAGS.output_node_names, FLAGS.restore_op_name, FLAGS.filename_tensor_name, - FLAGS.output_graph, FLAGS.clear_devices) + FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) if __name__ == "__main__": tf.app.run() diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index 9ed67f00bfc..caa046b0a56 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -66,7 +66,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_binary, input_checkpoint_path, output_node_names, restore_op_name, filename_tensor_name, output_graph_path, - clear_devices) + clear_devices, "") # Now we make sure the variable is now a constant, and that the graph still # produces the expected result.