Allow exporting and loading graphs with no variables.

Change: 146257820
This commit is contained in:
David Soergel 2017-02-01 10:14:13 -08:00 committed by TensorFlower Gardener
parent cfdd541f3f
commit eef537ce8d
3 changed files with 55 additions and 12 deletions

View File

@ -336,7 +336,7 @@ class SavedModelBuilder(object):
"""
if not self._has_saved_variables:
raise AssertionError(
"Variables and assets have not been saved yet. "
"Graph state including variables and assets has not been saved yet. "
"Please invoke `add_meta_graph_and_variables()` first.")
# Validate the signature def map to ensure all included TensorInfos are
@ -357,7 +357,8 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.global_variables(),
sharded=True,
write_version=saver_pb2.SaverDef.V2)
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
@ -394,8 +395,9 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded.
"""
if self._has_saved_variables:
raise AssertionError("Variables and assets have already been saved. "
"Please invoke `add_meta_graph()` instead.")
raise AssertionError("Graph state including variables and assets has "
"already been saved. Please invoke "
"`add_meta_graph()` instead.")
# Validate the signature def map to ensure all included TensorInfos are
# properly populated.
@ -426,7 +428,8 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.global_variables(),
sharded=True,
write_version=saver_pb2.SaverDef.V2)
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
# Save the variables. Also, disable writing the checkpoint state proto. The
# file is not used during SavedModel loading. In addition, since a

View File

@ -27,6 +27,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
@ -210,14 +211,18 @@ def load(sess, tags, export_dir, **saver_kwargs):
# Build a saver by importing the meta graph def to load.
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
# Build the checkpoint path where the variables are located.
variables_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY),
compat.as_bytes(constants.VARIABLES_FILENAME))
if saver:
# Build the checkpoint path where the variables are located.
variables_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY),
compat.as_bytes(constants.VARIABLES_FILENAME))
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)
else:
tf_logging.info("The specified SavedModel has no variables; no "
"checkpoints were restored.")
# Get asset tensors, if any.
asset_tensors_dictionary = _get_asset_tensors(export_dir,

View File

@ -246,6 +246,41 @@ class SavedModelTest(test.TestCase):
self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
export_dir)
def testGraphWithoutVariables(self):
export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Graph with no variables.
with self.test_session(graph=ops.Graph()) as sess:
constant_5_name = constant_op.constant(5.0).name
builder.add_meta_graph_and_variables(sess, ["foo"])
# Second graph with no variables
with self.test_session(graph=ops.Graph()) as sess:
constant_6_name = constant_op.constant(6.0).name
builder.add_meta_graph(["bar"])
# Save the SavedModel to disk.
builder.save()
# Restore the graph with tag "foo".
with self.test_session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
b = constant_op.constant(6.0)
c = a * b
self.assertEqual(30.0, sess.run(c))
# Restore the graph with tag "bar".
with self.test_session(graph=ops.Graph()) as sess:
loader.load(sess, ["bar"], export_dir)
# Read the constant a from the graph.
a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
b = constant_op.constant(5.0)
c = a * b
self.assertEqual(30.0, sess.run(c))
def testNoOverwrite(self):
export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite")
builder = saved_model_builder.SavedModelBuilder(export_dir)