From 8ae38cb7d80ee4155c54ae3b704063def227b102 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Mar 2016 10:47:18 -0800 Subject: [PATCH] Allows for importing of graphs with no variable nodes. Change: 117365836 --- tensorflow/python/training/saver.py | 24 ++++++-- tensorflow/python/training/saver_test.py | 55 +++++++++++++++++++ tensorflow/python/training/supervisor_test.py | 1 + 3 files changed, 75 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 38ba307c31b..e1e399054d7 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -404,8 +404,9 @@ class BaseSaverBuilder(object): if slice_name is None: slice_name = variable._save_slice_info.full_name elif slice_name != variable._save_slice_info.full_name: - raise ValueError("Slices must all be from the same tensor: %s != %s" - % (slice_name, variable._save_slice_info.full_name)) + raise ValueError( + "Slices must all be from the same tensor: %s != %s" + % (slice_name, variable._save_slice_info.full_name)) self._AddVarToSave(vars_to_save, seen_variables, variable, variable._save_slice_info.spec, name) # pylint: enable=protected-access @@ -1292,7 +1293,10 @@ def _import_meta_graph_def(meta_graph_def): meta_graph_def: `MetaGraphDef` protocol buffer. Returns: - A saver constructed rom `saver_def` in `meta_graph_def`. + A saver constructed from `saver_def` in `meta_graph_def` or None. + + A None value is returned if no variables exist in the `meta_graph_def` + (i.e., no variables to restore). """ # Gathers the list of nodes we are interested in. importer.import_graph_def(meta_graph_def.graph_def, name="") @@ -1331,7 +1335,14 @@ def _import_meta_graph_def(meta_graph_def): if meta_graph_def.HasField("saver_def"): return Saver(saver_def=meta_graph_def.saver_def) else: - return Saver() + if variables.all_variables(): + # Return the default saver instance for all graph variables. + return Saver() + else: + # If not graph variables exist, then a Saver cannot be constructed. + logging.info("Saver not created because there are no variables in the" + " graph to restore") + return None def import_meta_graph(meta_graph_or_file): @@ -1390,7 +1401,10 @@ def import_meta_graph(meta_graph_or_file): the path) containing a `MetaGraphDef`. Returns: - A saver constructed rom `saver_def` in `MetaGraphDef`. + A saver constructed rom `saver_def` in `MetaGraphDef` or None. + + A None value is returned if no variables exist in the `MetaGraphDef` + (i.e., there are no variables to restore). """ if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef): return _import_meta_graph_def(meta_graph_or_file) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index ffb3fb80d39..0fac611aecc 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -36,6 +36,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import queue_runner_pb2 from tensorflow.python.framework import function from tensorflow.python.platform import gfile +from tensorflow.python.training import saver as saver_module def _TestDir(test_name): @@ -760,6 +761,58 @@ class CheckpointStateTest(tf.test.TestCase): class MetaGraphTest(tf.test.TestCase): + def testNoVariables(self): + test_dir = _TestDir("no_variables") + filename = os.path.join(test_dir, "metafile") + + input_feed_value = -10 # Arbitrary input value for feed_dict. + + orig_graph = tf.Graph() + with self.test_session(graph=orig_graph) as sess: + # Create a minimal graph with zero variables. + input_tensor = tf.placeholder(tf.float32, shape=[], name="input") + offset = tf.constant(42, dtype=tf.float32, name="offset") + output_tensor = tf.add(input_tensor, offset, name="add_offset") + + # Add input and output tensors to graph collections. + tf.add_to_collection("input_tensor", input_tensor) + tf.add_to_collection("output_tensor", output_tensor) + + output_value = sess.run(output_tensor, {input_tensor: input_feed_value}) + self.assertEqual(output_value, 32) + + # Generates MetaGraphDef. + # + # Note that this is calling the saver *module-level* export_meta_graph and + # not the Saver.export_meta_graph instance-level method. + meta_graph_def = saver_module.export_meta_graph( + filename=filename, + graph_def=tf.get_default_graph().as_graph_def(), + collection_list=["input_tensor", "output_tensor"], + saver_def=None, + ) + + # Create a clean graph and import the MetaGraphDef nodes. + new_graph = tf.Graph() + with self.test_session(graph=new_graph) as sess: + # Import the previously export meta graph. + saver_instance = saver_module.import_meta_graph(filename) + # The saver instance should be None since there are no graph variables + # to be restored in this case. + self.assertIsNone(saver_instance) + + # Re-exports the current graph state for comparison to the original. + new_meta_graph_def = saver_module.export_meta_graph(filename + "_new") + self.assertProtoEquals(meta_graph_def, new_meta_graph_def) + + # Ensures that we can still get a reference to our graph collections. + new_input_tensor = tf.get_collection("input_tensor")[0] + new_output_tensor = tf.get_collection("output_tensor")[0] + # Verifies that the new graph computes the same result as the original. + new_output_value = sess.run( + new_output_tensor, {new_input_tensor: input_feed_value}) + self.assertEqual(new_output_value, output_value) + def testAddCollectionDef(self): test_dir = _TestDir("good_collection") filename = os.path.join(test_dir, "metafile") @@ -911,6 +964,7 @@ class MetaGraphTest(tf.test.TestCase): with self.test_session(graph=tf.Graph()): # Imports the binary format graph. saver = tf.train.import_meta_graph(filename) + self.assertIsNotNone(saver) # Exports the graph as text format. saver.export_meta_graph(filename, as_text=True) with self.test_session(graph=tf.Graph()): @@ -947,6 +1001,7 @@ class MetaGraphTest(tf.test.TestCase): with tf.Graph().as_default(): # Restores from MetaGraphDef. new_saver = tf.train.import_meta_graph(filename) + self.assertIsNotNone(new_saver) # Generates a new MetaGraphDef. new_meta_graph_def = new_saver.export_meta_graph() # It should be the same as the original. diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 33fec68c1dd..e1b8cb80909 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -257,6 +257,7 @@ class SupervisorTest(tf.test.TestCase): # Create a new Graph and Supervisor and recover. with tf.Graph().as_default(): new_saver = tf.train.import_meta_graph(".".join([filename, "meta"])) + self.assertIsNotNone(new_saver) sv2 = tf.train.Supervisor(logdir=logdir, saver=new_saver) sess = sv2.prepare_or_wait_for_session("") self.assertEquals(1, sess.run("v0:0"))