Allows for importing of graphs with no variable nodes.
Change: 117365836
This commit is contained in:
parent
0d0d92ca52
commit
8ae38cb7d8
@ -404,8 +404,9 @@ class BaseSaverBuilder(object):
|
||||
if slice_name is None:
|
||||
slice_name = variable._save_slice_info.full_name
|
||||
elif slice_name != variable._save_slice_info.full_name:
|
||||
raise ValueError("Slices must all be from the same tensor: %s != %s"
|
||||
% (slice_name, variable._save_slice_info.full_name))
|
||||
raise ValueError(
|
||||
"Slices must all be from the same tensor: %s != %s"
|
||||
% (slice_name, variable._save_slice_info.full_name))
|
||||
self._AddVarToSave(vars_to_save, seen_variables,
|
||||
variable, variable._save_slice_info.spec, name)
|
||||
# pylint: enable=protected-access
|
||||
@ -1292,7 +1293,10 @@ def _import_meta_graph_def(meta_graph_def):
|
||||
meta_graph_def: `MetaGraphDef` protocol buffer.
|
||||
|
||||
Returns:
|
||||
A saver constructed rom `saver_def` in `meta_graph_def`.
|
||||
A saver constructed from `saver_def` in `meta_graph_def` or None.
|
||||
|
||||
A None value is returned if no variables exist in the `meta_graph_def`
|
||||
(i.e., no variables to restore).
|
||||
"""
|
||||
# Gathers the list of nodes we are interested in.
|
||||
importer.import_graph_def(meta_graph_def.graph_def, name="")
|
||||
@ -1331,7 +1335,14 @@ def _import_meta_graph_def(meta_graph_def):
|
||||
if meta_graph_def.HasField("saver_def"):
|
||||
return Saver(saver_def=meta_graph_def.saver_def)
|
||||
else:
|
||||
return Saver()
|
||||
if variables.all_variables():
|
||||
# Return the default saver instance for all graph variables.
|
||||
return Saver()
|
||||
else:
|
||||
# If not graph variables exist, then a Saver cannot be constructed.
|
||||
logging.info("Saver not created because there are no variables in the"
|
||||
" graph to restore")
|
||||
return None
|
||||
|
||||
|
||||
def import_meta_graph(meta_graph_or_file):
|
||||
@ -1390,7 +1401,10 @@ def import_meta_graph(meta_graph_or_file):
|
||||
the path) containing a `MetaGraphDef`.
|
||||
|
||||
Returns:
|
||||
A saver constructed rom `saver_def` in `MetaGraphDef`.
|
||||
A saver constructed rom `saver_def` in `MetaGraphDef` or None.
|
||||
|
||||
A None value is returned if no variables exist in the `MetaGraphDef`
|
||||
(i.e., there are no variables to restore).
|
||||
"""
|
||||
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
|
||||
return _import_meta_graph_def(meta_graph_or_file)
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import queue_runner_pb2
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.training import saver as saver_module
|
||||
|
||||
|
||||
def _TestDir(test_name):
|
||||
@ -760,6 +761,58 @@ class CheckpointStateTest(tf.test.TestCase):
|
||||
|
||||
class MetaGraphTest(tf.test.TestCase):
|
||||
|
||||
def testNoVariables(self):
|
||||
test_dir = _TestDir("no_variables")
|
||||
filename = os.path.join(test_dir, "metafile")
|
||||
|
||||
input_feed_value = -10 # Arbitrary input value for feed_dict.
|
||||
|
||||
orig_graph = tf.Graph()
|
||||
with self.test_session(graph=orig_graph) as sess:
|
||||
# Create a minimal graph with zero variables.
|
||||
input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
|
||||
offset = tf.constant(42, dtype=tf.float32, name="offset")
|
||||
output_tensor = tf.add(input_tensor, offset, name="add_offset")
|
||||
|
||||
# Add input and output tensors to graph collections.
|
||||
tf.add_to_collection("input_tensor", input_tensor)
|
||||
tf.add_to_collection("output_tensor", output_tensor)
|
||||
|
||||
output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
|
||||
self.assertEqual(output_value, 32)
|
||||
|
||||
# Generates MetaGraphDef.
|
||||
#
|
||||
# Note that this is calling the saver *module-level* export_meta_graph and
|
||||
# not the Saver.export_meta_graph instance-level method.
|
||||
meta_graph_def = saver_module.export_meta_graph(
|
||||
filename=filename,
|
||||
graph_def=tf.get_default_graph().as_graph_def(),
|
||||
collection_list=["input_tensor", "output_tensor"],
|
||||
saver_def=None,
|
||||
)
|
||||
|
||||
# Create a clean graph and import the MetaGraphDef nodes.
|
||||
new_graph = tf.Graph()
|
||||
with self.test_session(graph=new_graph) as sess:
|
||||
# Import the previously export meta graph.
|
||||
saver_instance = saver_module.import_meta_graph(filename)
|
||||
# The saver instance should be None since there are no graph variables
|
||||
# to be restored in this case.
|
||||
self.assertIsNone(saver_instance)
|
||||
|
||||
# Re-exports the current graph state for comparison to the original.
|
||||
new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
|
||||
self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
|
||||
|
||||
# Ensures that we can still get a reference to our graph collections.
|
||||
new_input_tensor = tf.get_collection("input_tensor")[0]
|
||||
new_output_tensor = tf.get_collection("output_tensor")[0]
|
||||
# Verifies that the new graph computes the same result as the original.
|
||||
new_output_value = sess.run(
|
||||
new_output_tensor, {new_input_tensor: input_feed_value})
|
||||
self.assertEqual(new_output_value, output_value)
|
||||
|
||||
def testAddCollectionDef(self):
|
||||
test_dir = _TestDir("good_collection")
|
||||
filename = os.path.join(test_dir, "metafile")
|
||||
@ -911,6 +964,7 @@ class MetaGraphTest(tf.test.TestCase):
|
||||
with self.test_session(graph=tf.Graph()):
|
||||
# Imports the binary format graph.
|
||||
saver = tf.train.import_meta_graph(filename)
|
||||
self.assertIsNotNone(saver)
|
||||
# Exports the graph as text format.
|
||||
saver.export_meta_graph(filename, as_text=True)
|
||||
with self.test_session(graph=tf.Graph()):
|
||||
@ -947,6 +1001,7 @@ class MetaGraphTest(tf.test.TestCase):
|
||||
with tf.Graph().as_default():
|
||||
# Restores from MetaGraphDef.
|
||||
new_saver = tf.train.import_meta_graph(filename)
|
||||
self.assertIsNotNone(new_saver)
|
||||
# Generates a new MetaGraphDef.
|
||||
new_meta_graph_def = new_saver.export_meta_graph()
|
||||
# It should be the same as the original.
|
||||
|
@ -257,6 +257,7 @@ class SupervisorTest(tf.test.TestCase):
|
||||
# Create a new Graph and Supervisor and recover.
|
||||
with tf.Graph().as_default():
|
||||
new_saver = tf.train.import_meta_graph(".".join([filename, "meta"]))
|
||||
self.assertIsNotNone(new_saver)
|
||||
sv2 = tf.train.Supervisor(logdir=logdir, saver=new_saver)
|
||||
sess = sv2.prepare_or_wait_for_session("")
|
||||
self.assertEquals(1, sess.run("v0:0"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user