Apply import_scope to asset and variable tensors during tf.saved_model.loader.load
This change explicitly declares import_scope as a kwarg for tf.saved_model.loader.load. Previously, tf.saved_model.loader.load implicitly accepted import_scope and passed it through to import_meta_graph through **saver_kwargs. PiperOrigin-RevId: 200249417
This commit is contained in:
parent
ba9422a8ad
commit
dc7821ccf4
@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
|
||||
constants.SAVED_MODEL_FILENAME_PB))
|
||||
|
||||
|
||||
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
|
||||
def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
|
||||
"""Gets the asset tensors, if defined in the meta graph def to load.
|
||||
|
||||
Args:
|
||||
export_dir: Directory where the SavedModel is located.
|
||||
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
|
||||
import_scope: Optional `string` -- if specified, prepend this followed by
|
||||
'/' to all returned asset tensor names.
|
||||
|
||||
Returns:
|
||||
A dictionary of asset tensors, keyed by the name of the asset tensor. The
|
||||
@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
|
||||
for asset_any_proto in assets_any_proto:
|
||||
asset_proto = meta_graph_pb2.AssetFileDef()
|
||||
asset_any_proto.Unpack(asset_proto)
|
||||
asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
|
||||
tensor_name = asset_proto.tensor_info.name
|
||||
if import_scope:
|
||||
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
||||
asset_tensor_dict[tensor_name] = os.path.join(
|
||||
compat.as_bytes(assets_directory),
|
||||
compat.as_bytes(asset_proto.filename))
|
||||
return asset_tensor_dict
|
||||
@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):
|
||||
|
||||
|
||||
@tf_export("saved_model.loader.load")
|
||||
def load(sess, tags, export_dir, **saver_kwargs):
|
||||
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
|
||||
"""Loads the model from a SavedModel as specified by tags.
|
||||
|
||||
Args:
|
||||
@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
SavedModel `save()` API.
|
||||
export_dir: Directory in which the SavedModel protocol buffer and variables
|
||||
to be loaded are located.
|
||||
import_scope: Optional `string` -- if specified, prepend this string
|
||||
followed by '/' to all loaded tensor names. This scope is applied to
|
||||
tensor instances loaded into the passed session, but it is *not* written
|
||||
through to the static `MetaGraphDef` protocol buffer that is returned.
|
||||
**saver_kwargs: Optional keyword arguments passed through to Saver.
|
||||
|
||||
Returns:
|
||||
@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
)
|
||||
|
||||
# Build a saver by importing the meta graph def to load.
|
||||
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
|
||||
saver = tf_saver.import_meta_graph(
|
||||
meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
|
||||
|
||||
if saver:
|
||||
# Build the checkpoint path where the variables are located.
|
||||
@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
"checkpoints were restored.")
|
||||
|
||||
# Get asset tensors, if any.
|
||||
asset_tensors_dictionary = _get_asset_tensors(export_dir,
|
||||
meta_graph_def_to_load)
|
||||
asset_tensors_dictionary = _get_asset_tensors(
|
||||
export_dir, meta_graph_def_to_load, import_scope=import_scope)
|
||||
|
||||
main_op_tensor = (
|
||||
_get_main_op_tensor(meta_graph_def_to_load) or
|
||||
|
@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase):
|
||||
_validate_custom_saver("tag_1", "save_1/restore_all")
|
||||
_validate_custom_saver("tag_2", "save_2/restore_all")
|
||||
|
||||
def testImportScope(self):
|
||||
export_dir = self._get_export_dir("test_scoped_assets")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
# Build a SavedModel with a variable, an asset, and a constant tensor.
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
constant_op.constant("constant value", name="constant_tensor_name")
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["tag_name"], assets_collection=asset_collection)
|
||||
|
||||
# Save the asset file path for later comparison.
|
||||
asset_file_path = asset_collection[0].eval()
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
# Restore the SavedModel under an import_scope in a new graph/session.
|
||||
graph_proto = loader.load(
|
||||
sess, ["tag_name"], export_dir, import_scope="scope_name")
|
||||
|
||||
# The loaded variable tensor should be scoped, but its contents should be
|
||||
# unchanged.
|
||||
self.assertEqual(
|
||||
"scope_name/v:0",
|
||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
|
||||
self.assertEqual(
|
||||
42,
|
||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
||||
|
||||
# The loaded asset tensor should be scoped, but the asset file path and
|
||||
# contents should be unchanged.
|
||||
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||
self.assertEqual(1, len(asset_collection))
|
||||
self.assertEqual(asset_file_path, asset_collection[0].eval())
|
||||
self.assertEqual("scope_name/asset_file_tensor:0",
|
||||
asset_collection[0].name)
|
||||
# The static asset data inside graph_proto.collection_def should not be
|
||||
# scoped.
|
||||
self._validate_asset_collection(export_dir, graph_proto.collection_def,
|
||||
"foo.txt", "content_foo",
|
||||
"asset_file_tensor:0")
|
||||
|
||||
# The constant tensor should be scoped, but its contents should be
|
||||
# unchanged.
|
||||
self.assertEqual(
|
||||
compat.as_bytes("constant value"),
|
||||
ops.get_default_graph().get_tensor_by_name(
|
||||
"scope_name/constant_tensor_name:0").eval())
|
||||
|
||||
def testClearDevices(self):
|
||||
export_dir = self._get_export_dir("test_clear_devices")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
@ -1970,7 +1970,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
|
||||
|
||||
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
|
||||
else:
|
||||
if variables._all_saveable_objects(): # pylint: disable=protected-access
|
||||
if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
|
||||
# Return the default saver instance for all graph variables.
|
||||
return Saver()
|
||||
else:
|
||||
|
@ -2339,6 +2339,46 @@ class MetaGraphTest(test.TestCase):
|
||||
10, size=[1, 10])
|
||||
})
|
||||
|
||||
def testImportIntoNamescopeWithoutVariables(self):
|
||||
# Save a simple graph that contains no variables into a checkpoint.
|
||||
test_dir = self._get_test_dir("no_vars_graph")
|
||||
filename = os.path.join(test_dir, "ckpt")
|
||||
graph_1 = ops_lib.Graph()
|
||||
with session.Session(graph=graph_1) as sess:
|
||||
constant_op.constant([1, 2, 3], name="x")
|
||||
constant_op.constant([1, 2, 3], name="y")
|
||||
saver = saver_module.Saver(allow_empty=True)
|
||||
saver.save(sess, filename)
|
||||
|
||||
# Create a fresh graph.
|
||||
graph_2 = ops_lib.Graph()
|
||||
with session.Session(graph=graph_2) as sess:
|
||||
# Restore the above checkpoint under scope "subgraph_1".
|
||||
new_saver_1 = saver_module.import_meta_graph(
|
||||
filename + ".meta", graph=graph_2, import_scope="subgraph_1")
|
||||
# There are no variables to restore, so import_meta_graph should not
|
||||
# return a Saver.
|
||||
self.assertIsNone(new_saver_1)
|
||||
|
||||
# Create a variable in graph_2 under scope "my_scope".
|
||||
variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
|
||||
sess.run(variables.global_variables_initializer())
|
||||
# Restore the checkpoint into a different scope "subgraph_2".
|
||||
new_saver_2 = saver_module.import_meta_graph(
|
||||
filename + ".meta", graph=graph_2, import_scope="subgraph_2")
|
||||
# Because the variable does not live in scope "subgraph_2",
|
||||
# import_meta_graph should not attempt to restore the variable. So,
|
||||
# import_meta_graph still won't return a Saver instance.
|
||||
self.assertIsNone(new_saver_2)
|
||||
|
||||
# However, if we restore the checkpoint under scope "my_scope",
|
||||
# import_meta_graph will detect the variable and return a Saver for
|
||||
# restoring it. This should happen even when the variable does not
|
||||
# originate from graph_1.
|
||||
new_saver_3 = saver_module.import_meta_graph(
|
||||
filename + ".meta", graph=graph_2, import_scope="my_scope")
|
||||
self.assertIsInstance(new_saver_3, saver_module.Saver)
|
||||
|
||||
def testImportIntoImplicitNamescope(self):
|
||||
# Test that we can import a meta graph into an implicit namescope.
|
||||
test_dir = self._get_test_dir("import_into_namescope")
|
||||
|
@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "load"
|
||||
argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
|
||||
argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "maybe_saved_model_directory"
|
||||
|
Loading…
Reference in New Issue
Block a user