Apply import_scope to asset and variable tensors during tf.saved_model.loader.load

This change explicitly declares import_scope as a kwarg for tf.saved_model.loader.load. Previously, tf.saved_model.loader.load implicitly accepted import_scope and passed it through to import_meta_graph through **saver_kwargs.

PiperOrigin-RevId: 200249417
This commit is contained in:
A. Unique TensorFlower 2018-06-12 11:26:38 -07:00 committed by TensorFlower Gardener
parent ba9422a8ad
commit dc7821ccf4
5 changed files with 111 additions and 8 deletions

View File

@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
constants.SAVED_MODEL_FILENAME_PB))
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
"""Gets the asset tensors, if defined in the meta graph def to load.
Args:
export_dir: Directory where the SavedModel is located.
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
import_scope: Optional `string` -- if specified, prepend this followed by
'/' to all returned asset tensor names.
Returns:
A dictionary of asset tensors, keyed by the name of the asset tensor. The
@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
for asset_any_proto in assets_any_proto:
asset_proto = meta_graph_pb2.AssetFileDef()
asset_any_proto.Unpack(asset_proto)
asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
tensor_name = asset_proto.tensor_info.name
if import_scope:
tensor_name = "%s/%s" % (import_scope, tensor_name)
asset_tensor_dict[tensor_name] = os.path.join(
compat.as_bytes(assets_directory),
compat.as_bytes(asset_proto.filename))
return asset_tensor_dict
@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):
@tf_export("saved_model.loader.load")
def load(sess, tags, export_dir, **saver_kwargs):
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
Args:
@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
SavedModel `save()` API.
export_dir: Directory in which the SavedModel protocol buffer and variables
to be loaded are located.
import_scope: Optional `string` -- if specified, prepend this string
followed by '/' to all loaded tensor names. This scope is applied to
tensor instances loaded into the passed session, but it is *not* written
through to the static `MetaGraphDef` protocol buffer that is returned.
**saver_kwargs: Optional keyword arguments passed through to Saver.
Returns:
@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
)
# Build a saver by importing the meta graph def to load.
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
saver = tf_saver.import_meta_graph(
meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
if saver:
# Build the checkpoint path where the variables are located.
@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
"checkpoints were restored.")
# Get asset tensors, if any.
asset_tensors_dictionary = _get_asset_tensors(export_dir,
meta_graph_def_to_load)
asset_tensors_dictionary = _get_asset_tensors(
export_dir, meta_graph_def_to_load, import_scope=import_scope)
main_op_tensor = (
_get_main_op_tensor(meta_graph_def_to_load) or

View File

@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase):
_validate_custom_saver("tag_1", "save_1/restore_all")
_validate_custom_saver("tag_2", "save_2/restore_all")
def testImportScope(self):
export_dir = self._get_export_dir("test_scoped_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Build a SavedModel with a variable, an asset, and a constant tensor.
with self.test_session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
"asset_file_tensor")
constant_op.constant("constant value", name="constant_tensor_name")
builder.add_meta_graph_and_variables(
sess, ["tag_name"], assets_collection=asset_collection)
# Save the asset file path for later comparison.
asset_file_path = asset_collection[0].eval()
# Save the SavedModel to disk.
builder.save()
with self.test_session(graph=ops.Graph()) as sess:
# Restore the SavedModel under an import_scope in a new graph/session.
graph_proto = loader.load(
sess, ["tag_name"], export_dir, import_scope="scope_name")
# The loaded variable tensor should be scoped, but its contents should be
# unchanged.
self.assertEqual(
"scope_name/v:0",
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
self.assertEqual(
42,
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
# The loaded asset tensor should be scoped, but the asset file path and
# contents should be unchanged.
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
self.assertEqual(1, len(asset_collection))
self.assertEqual(asset_file_path, asset_collection[0].eval())
self.assertEqual("scope_name/asset_file_tensor:0",
asset_collection[0].name)
# The static asset data inside graph_proto.collection_def should not be
# scoped.
self._validate_asset_collection(export_dir, graph_proto.collection_def,
"foo.txt", "content_foo",
"asset_file_tensor:0")
# The constant tensor should be scoped, but its contents should be
# unchanged.
self.assertEqual(
compat.as_bytes("constant value"),
ops.get_default_graph().get_tensor_by_name(
"scope_name/constant_tensor_name:0").eval())
def testClearDevices(self):
export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)

View File

@ -1970,7 +1970,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
else:
if variables._all_saveable_objects(): # pylint: disable=protected-access
if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
# Return the default saver instance for all graph variables.
return Saver()
else:

View File

@ -2339,6 +2339,46 @@ class MetaGraphTest(test.TestCase):
10, size=[1, 10])
})
def testImportIntoNamescopeWithoutVariables(self):
# Save a simple graph that contains no variables into a checkpoint.
test_dir = self._get_test_dir("no_vars_graph")
filename = os.path.join(test_dir, "ckpt")
graph_1 = ops_lib.Graph()
with session.Session(graph=graph_1) as sess:
constant_op.constant([1, 2, 3], name="x")
constant_op.constant([1, 2, 3], name="y")
saver = saver_module.Saver(allow_empty=True)
saver.save(sess, filename)
# Create a fresh graph.
graph_2 = ops_lib.Graph()
with session.Session(graph=graph_2) as sess:
# Restore the above checkpoint under scope "subgraph_1".
new_saver_1 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="subgraph_1")
# There are no variables to restore, so import_meta_graph should not
# return a Saver.
self.assertIsNone(new_saver_1)
# Create a variable in graph_2 under scope "my_scope".
variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
sess.run(variables.global_variables_initializer())
# Restore the checkpoint into a different scope "subgraph_2".
new_saver_2 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="subgraph_2")
# Because the variable does not live in scope "subgraph_2",
# import_meta_graph should not attempt to restore the variable. So,
# import_meta_graph still won't return a Saver instance.
self.assertIsNone(new_saver_2)
# However, if we restore the checkpoint under scope "my_scope",
# import_meta_graph will detect the variable and return a Saver for
# restoring it. This should happen even when the variable does not
# originate from graph_1.
new_saver_3 = saver_module.import_meta_graph(
filename + ".meta", graph=graph_2, import_scope="my_scope")
self.assertIsInstance(new_saver_3, saver_module.Saver)
def testImportIntoImplicitNamescope(self):
# Test that we can import a meta graph into an implicit namescope.
test_dir = self._get_test_dir("import_into_namescope")

View File

@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
tf_module {
member_method {
name: "load"
argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
}
member_method {
name: "maybe_saved_model_directory"