Allows for importing of graphs with no variable nodes.

Change: 117365836
This commit is contained in:
A. Unique TensorFlower 2016-03-16 10:47:18 -08:00 committed by TensorFlower Gardener
parent 0d0d92ca52
commit 8ae38cb7d8
3 changed files with 75 additions and 5 deletions

View File

@ -404,8 +404,9 @@ class BaseSaverBuilder(object):
if slice_name is None:
slice_name = variable._save_slice_info.full_name
elif slice_name != variable._save_slice_info.full_name:
raise ValueError("Slices must all be from the same tensor: %s != %s"
% (slice_name, variable._save_slice_info.full_name))
raise ValueError(
"Slices must all be from the same tensor: %s != %s"
% (slice_name, variable._save_slice_info.full_name))
self._AddVarToSave(vars_to_save, seen_variables,
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
@ -1292,7 +1293,10 @@ def _import_meta_graph_def(meta_graph_def):
meta_graph_def: `MetaGraphDef` protocol buffer.
Returns:
A saver constructed rom `saver_def` in `meta_graph_def`.
A saver constructed from `saver_def` in `meta_graph_def` or None.
A None value is returned if no variables exist in the `meta_graph_def`
(i.e., no variables to restore).
"""
# Gathers the list of nodes we are interested in.
importer.import_graph_def(meta_graph_def.graph_def, name="")
@ -1331,7 +1335,14 @@ def _import_meta_graph_def(meta_graph_def):
if meta_graph_def.HasField("saver_def"):
return Saver(saver_def=meta_graph_def.saver_def)
else:
return Saver()
if variables.all_variables():
# Return the default saver instance for all graph variables.
return Saver()
else:
# If not graph variables exist, then a Saver cannot be constructed.
logging.info("Saver not created because there are no variables in the"
" graph to restore")
return None
def import_meta_graph(meta_graph_or_file):
@ -1390,7 +1401,10 @@ def import_meta_graph(meta_graph_or_file):
the path) containing a `MetaGraphDef`.
Returns:
A saver constructed rom `saver_def` in `MetaGraphDef`.
A saver constructed rom `saver_def` in `MetaGraphDef` or None.
A None value is returned if no variables exist in the `MetaGraphDef`
(i.e., there are no variables to restore).
"""
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
return _import_meta_graph_def(meta_graph_or_file)

View File

@ -36,6 +36,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import function
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_module
def _TestDir(test_name):
@ -760,6 +761,58 @@ class CheckpointStateTest(tf.test.TestCase):
class MetaGraphTest(tf.test.TestCase):
def testNoVariables(self):
test_dir = _TestDir("no_variables")
filename = os.path.join(test_dir, "metafile")
input_feed_value = -10 # Arbitrary input value for feed_dict.
orig_graph = tf.Graph()
with self.test_session(graph=orig_graph) as sess:
# Create a minimal graph with zero variables.
input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
offset = tf.constant(42, dtype=tf.float32, name="offset")
output_tensor = tf.add(input_tensor, offset, name="add_offset")
# Add input and output tensors to graph collections.
tf.add_to_collection("input_tensor", input_tensor)
tf.add_to_collection("output_tensor", output_tensor)
output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
self.assertEqual(output_value, 32)
# Generates MetaGraphDef.
#
# Note that this is calling the saver *module-level* export_meta_graph and
# not the Saver.export_meta_graph instance-level method.
meta_graph_def = saver_module.export_meta_graph(
filename=filename,
graph_def=tf.get_default_graph().as_graph_def(),
collection_list=["input_tensor", "output_tensor"],
saver_def=None,
)
# Create a clean graph and import the MetaGraphDef nodes.
new_graph = tf.Graph()
with self.test_session(graph=new_graph) as sess:
# Import the previously export meta graph.
saver_instance = saver_module.import_meta_graph(filename)
# The saver instance should be None since there are no graph variables
# to be restored in this case.
self.assertIsNone(saver_instance)
# Re-exports the current graph state for comparison to the original.
new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
# Ensures that we can still get a reference to our graph collections.
new_input_tensor = tf.get_collection("input_tensor")[0]
new_output_tensor = tf.get_collection("output_tensor")[0]
# Verifies that the new graph computes the same result as the original.
new_output_value = sess.run(
new_output_tensor, {new_input_tensor: input_feed_value})
self.assertEqual(new_output_value, output_value)
def testAddCollectionDef(self):
test_dir = _TestDir("good_collection")
filename = os.path.join(test_dir, "metafile")
@ -911,6 +964,7 @@ class MetaGraphTest(tf.test.TestCase):
with self.test_session(graph=tf.Graph()):
# Imports the binary format graph.
saver = tf.train.import_meta_graph(filename)
self.assertIsNotNone(saver)
# Exports the graph as text format.
saver.export_meta_graph(filename, as_text=True)
with self.test_session(graph=tf.Graph()):
@ -947,6 +1001,7 @@ class MetaGraphTest(tf.test.TestCase):
with tf.Graph().as_default():
# Restores from MetaGraphDef.
new_saver = tf.train.import_meta_graph(filename)
self.assertIsNotNone(new_saver)
# Generates a new MetaGraphDef.
new_meta_graph_def = new_saver.export_meta_graph()
# It should be the same as the original.

View File

@ -257,6 +257,7 @@ class SupervisorTest(tf.test.TestCase):
# Create a new Graph and Supervisor and recover.
with tf.Graph().as_default():
new_saver = tf.train.import_meta_graph(".".join([filename, "meta"]))
self.assertIsNotNone(new_saver)
sv2 = tf.train.Supervisor(logdir=logdir, saver=new_saver)
sess = sv2.prepare_or_wait_for_session("")
self.assertEquals(1, sess.run("v0:0"))