Allow exporting and loading graphs with no variables.
Change: 146257820
This commit is contained in:
parent
cfdd541f3f
commit
eef537ce8d
tensorflow/python/saved_model
@ -336,7 +336,7 @@ class SavedModelBuilder(object):
|
||||
"""
|
||||
if not self._has_saved_variables:
|
||||
raise AssertionError(
|
||||
"Variables and assets have not been saved yet. "
|
||||
"Graph state including variables and assets has not been saved yet. "
|
||||
"Please invoke `add_meta_graph_and_variables()` first.")
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
@ -357,7 +357,8 @@ class SavedModelBuilder(object):
|
||||
saver = tf_saver.Saver(
|
||||
variables.global_variables(),
|
||||
sharded=True,
|
||||
write_version=saver_pb2.SaverDef.V2)
|
||||
write_version=saver_pb2.SaverDef.V2,
|
||||
allow_empty=True)
|
||||
|
||||
meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
|
||||
|
||||
@ -394,8 +395,9 @@ class SavedModelBuilder(object):
|
||||
main_op: Op or group of ops to execute when the graph is loaded.
|
||||
"""
|
||||
if self._has_saved_variables:
|
||||
raise AssertionError("Variables and assets have already been saved. "
|
||||
"Please invoke `add_meta_graph()` instead.")
|
||||
raise AssertionError("Graph state including variables and assets has "
|
||||
"already been saved. Please invoke "
|
||||
"`add_meta_graph()` instead.")
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
@ -426,7 +428,8 @@ class SavedModelBuilder(object):
|
||||
saver = tf_saver.Saver(
|
||||
variables.global_variables(),
|
||||
sharded=True,
|
||||
write_version=saver_pb2.SaverDef.V2)
|
||||
write_version=saver_pb2.SaverDef.V2,
|
||||
allow_empty=True)
|
||||
|
||||
# Save the variables. Also, disable writing the checkpoint state proto. The
|
||||
# file is not used during SavedModel loading. In addition, since a
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.core.protobuf import saved_model_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.util import compat
|
||||
@ -210,14 +211,18 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
# Build a saver by importing the meta graph def to load.
|
||||
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
|
||||
|
||||
# Build the checkpoint path where the variables are located.
|
||||
variables_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.VARIABLES_DIRECTORY),
|
||||
compat.as_bytes(constants.VARIABLES_FILENAME))
|
||||
if saver:
|
||||
# Build the checkpoint path where the variables are located.
|
||||
variables_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.VARIABLES_DIRECTORY),
|
||||
compat.as_bytes(constants.VARIABLES_FILENAME))
|
||||
|
||||
# Restore the variables using the built saver in the provided session.
|
||||
saver.restore(sess, variables_path)
|
||||
# Restore the variables using the built saver in the provided session.
|
||||
saver.restore(sess, variables_path)
|
||||
else:
|
||||
tf_logging.info("The specified SavedModel has no variables; no "
|
||||
"checkpoints were restored.")
|
||||
|
||||
# Get asset tensors, if any.
|
||||
asset_tensors_dictionary = _get_asset_tensors(export_dir,
|
||||
|
@ -246,6 +246,41 @@ class SavedModelTest(test.TestCase):
|
||||
self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
|
||||
export_dir)
|
||||
|
||||
def testGraphWithoutVariables(self):
|
||||
export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with no variables.
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
constant_5_name = constant_op.constant(5.0).name
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"])
|
||||
|
||||
# Second graph with no variables
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
constant_6_name = constant_op.constant(6.0).name
|
||||
builder.add_meta_graph(["bar"])
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
# Restore the graph with tag "foo".
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
# Read the constant a from the graph.
|
||||
a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
|
||||
b = constant_op.constant(6.0)
|
||||
c = a * b
|
||||
self.assertEqual(30.0, sess.run(c))
|
||||
|
||||
# Restore the graph with tag "bar".
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["bar"], export_dir)
|
||||
# Read the constant a from the graph.
|
||||
a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
|
||||
b = constant_op.constant(5.0)
|
||||
c = a * b
|
||||
self.assertEqual(30.0, sess.run(c))
|
||||
|
||||
def testNoOverwrite(self):
|
||||
export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
Loading…
Reference in New Issue
Block a user