Add tf.saved_model.Asset public symbol.

`Asset` is the mechanism that allows to make hermetic SavedModels
that depend on files. It replaces functionality that on TF-1.x was
typically provided by the ASSET_FILEPATHS collection.

PiperOrigin-RevId: 267534289
This commit is contained in:
Andr? Susano Pinto 2019-09-05 23:30:29 -07:00 committed by Goldie Gadde
parent 14f20c9ff6
commit f185080c09
12 changed files with 92 additions and 19 deletions

View File

@ -750,7 +750,7 @@ class _TRTEngineResource(tracking.TrackableResource):
self._resource_name = resource_name
# Track the serialized engine file in the SavedModel.
self._filename = self._track_trackable(
tracking.TrackableAsset(filename), "_serialized_trt_resource_filename")
tracking.Asset(filename), "_serialized_trt_resource_filename")
self._maximum_cached_engines = maximum_cached_engines
def _create_resource(self):

View File

@ -627,7 +627,7 @@ class TextFileInitializer(TableInitializerBase):
self._delimiter = delimiter
self._name = name
self._filename = self._track_trackable(
trackable.TrackableAsset(filename), "_filename")
trackable.Asset(filename), "_filename")
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

View File

@ -206,7 +206,7 @@ class Loader(object):
return obj
elif resource_variable_ops.is_resource_variable(obj):
return obj.handle
elif isinstance(obj, tracking.TrackableAsset):
elif isinstance(obj, tracking.Asset):
return obj.asset_path
elif tensor_util.is_tensor(obj):
return obj
@ -343,7 +343,7 @@ class Loader(object):
filename = os.path.join(
saved_model_utils.get_assets_dir(self._export_dir),
self._asset_file_def[proto.asset_file_def_index].filename)
return tracking.TrackableAsset(filename), setattr
return tracking.Asset(filename), setattr
def _recreate_function(self, proto):
return function_deserialization.recreate_function(

View File

@ -204,8 +204,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
file2 = self._make_asset("contents 2")
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(file1)
root.asset2 = tracking.TrackableAsset(file2)
root.asset1 = tracking.Asset(file1)
root.asset2 = tracking.Asset(file2)
save_dir = os.path.join(self.get_temp_dir(), "save_dir")
save.save(root, save_dir)
@ -253,7 +253,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_capture_assets(self, cycles):
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.vocab = tracking.Asset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
input_signature=[])
@ -266,7 +266,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_capture_assets_in_graph(self, cycles):
root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
root.vocab = tracking.Asset(self._make_asset("contents"))
root.f = def_function.function(
lambda: root.vocab.asset_path,
input_signature=[])
@ -290,8 +290,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_dedup_assets(self, cycles):
vocab = self._make_asset("contents")
root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(vocab)
root.asset2 = tracking.TrackableAsset(vocab)
root.asset1 = tracking.Asset(vocab)
root.asset2 = tracking.Asset(vocab)
imported = cycle(root, cycles)
self.assertEqual(imported.asset1.asset_path.numpy(),
imported.asset2.asset_path.numpy())

View File

@ -201,7 +201,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
for tensor_name, value in loader_impl.get_asset_tensors(
self._export_dir, meta_graph_def).items():
asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
asset_paths.append(tracking.TrackableAsset(value))
asset_paths.append(tracking.Asset(value))
init_fn = wrapped.prune(
feeds=asset_feed_tensors,
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])

View File

@ -270,7 +270,7 @@ class _SaveableView(object):
object_map[obj] = new_variable
resource_map[obj.handle] = new_variable.handle
self.captured_tensor_node_ids[obj.handle] = node_id
elif isinstance(obj, tracking.TrackableAsset):
elif isinstance(obj, tracking.Asset):
_process_asset(obj, asset_info, resource_map)
self.captured_tensor_node_ids[obj.asset_path] = node_id
@ -498,7 +498,7 @@ _AssetInfo = collections.namedtuple(
"asset_initializers_by_resource",
# Map from base asset filenames to full paths
"asset_filename_map",
# Map from TrackableAsset to index of corresponding AssetFileDef
# Map from Asset to index of corresponding AssetFileDef
"asset_index"])
@ -662,7 +662,7 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
def _write_object_proto(obj, proto, asset_file_def_index):
"""Saves an object into SavedObject proto."""
if isinstance(obj, tracking.TrackableAsset):
if isinstance(obj, tracking.Asset):
proto.asset.SetInParent()
proto.asset.asset_file_def_index = asset_file_def_index[obj]
elif resource_variable_ops.is_resource_variable(obj):

View File

@ -426,7 +426,7 @@ class AssetTests(test.TestCase):
def test_asset_path_returned(self):
root = tracking.AutoTrackable()
root.path = tracking.TrackableAsset(self._vocab_path)
root.path = tracking.Asset(self._vocab_path)
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
root.get_asset = def_function.function(lambda: root.path.asset_path)
save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
@ -469,7 +469,7 @@ class AssetTests(test.TestCase):
root.f = def_function.function(
lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root.asset = tracking.TrackableAsset(self._vocab_path)
root.asset = tracking.Asset(self._vocab_path)
export_dir = os.path.join(self.get_temp_dir(), "save_dir")
save.save(root, export_dir)

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
# global _RESOURCE_TRACKER_STACK
@ -275,8 +276,44 @@ class TrackableResource(CapturableResource):
super(TrackableResource, self).__init__(device=device, deleter=deleter)
class TrackableAsset(base.Trackable):
"""Base class for asset files which need to be tracked."""
@tf_export("saved_model.Asset")
class Asset(base.Trackable):
"""Represents a file asset to hermetically include in a SavedModel.
A SavedModel can include arbitrary files, called assets, that are needed
for its use. For example a vocabulary file used initialize a lookup table.
When a trackable object is exported via `tf.saved_model.save()`, all the
`Asset`s reachable from it are copied into the SavedModel assets directory.
Upon loading, the assets and the serialized functions that depend on them
will refer to the correct filepaths inside the SavedModel directory.
Example:
```
filename = tf.saved_model.Asset("file.txt")
@tf.function(input_signature=[])
def func():
return tf.io.read_file(filename)
trackable_obj = tf.train.Checkpoint()
trackable_obj.func = func
trackable_obj.filename = filename
tf.saved_model.save(trackable_obj, "/tmp/saved_model")
# The created SavedModel is hermetic, it does not depend on
# the original file and can be moved to another path.
tf.io.gfile.remove("file.txt")
tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")
reloaded_obj = tf.saved_model.load("/tmp/new_location")
print(reloaded_obj.func())
```
Attributes:
asset_path: A 0-D `tf.string` tensor with path to the asset.
"""
def __init__(self, path):
"""Record the full path to the asset."""
@ -389,5 +426,5 @@ def cached_per_instance(f):
ops.register_tensor_conversion_function(
TrackableAsset,
Asset,
lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))

View File

@ -0,0 +1,14 @@
path: "tensorflow.saved_model.Asset"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "asset_path"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -8,6 +8,10 @@ tf_module {
name: "ASSETS_KEY"
mtype: "<type \'str\'>"
}
member {
name: "Asset"
mtype: "<type \'type\'>"
}
member {
name: "Builder"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,14 @@
path: "tensorflow.saved_model.Asset"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "asset_path"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -8,6 +8,10 @@ tf_module {
name: "ASSETS_KEY"
mtype: "<type \'str\'>"
}
member {
name: "Asset"
mtype: "<type \'type\'>"
}
member {
name: "CLASSIFY_INPUTS"
mtype: "<type \'str\'>"