Add tf.saved_model.Asset public symbol.

`Asset` is the mechanism that allows to make hermetic SavedModels
that depend on files. It replaces functionality that on TF-1.x was
typically provided by the ASSET_FILEPATHS collection.

PiperOrigin-RevId: 267534289
This commit is contained in:
Andr? Susano Pinto 2019-09-05 23:30:29 -07:00 committed by Goldie Gadde
parent 14f20c9ff6
commit f185080c09
12 changed files with 92 additions and 19 deletions

View File

@ -750,7 +750,7 @@ class _TRTEngineResource(tracking.TrackableResource):
self._resource_name = resource_name self._resource_name = resource_name
# Track the serialized engine file in the SavedModel. # Track the serialized engine file in the SavedModel.
self._filename = self._track_trackable( self._filename = self._track_trackable(
tracking.TrackableAsset(filename), "_serialized_trt_resource_filename") tracking.Asset(filename), "_serialized_trt_resource_filename")
self._maximum_cached_engines = maximum_cached_engines self._maximum_cached_engines = maximum_cached_engines
def _create_resource(self): def _create_resource(self):

View File

@ -627,7 +627,7 @@ class TextFileInitializer(TableInitializerBase):
self._delimiter = delimiter self._delimiter = delimiter
self._name = name self._name = name
self._filename = self._track_trackable( self._filename = self._track_trackable(
trackable.TrackableAsset(filename), "_filename") trackable.Asset(filename), "_filename")
super(TextFileInitializer, self).__init__(key_dtype, value_dtype) super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

View File

@ -206,7 +206,7 @@ class Loader(object):
return obj return obj
elif resource_variable_ops.is_resource_variable(obj): elif resource_variable_ops.is_resource_variable(obj):
return obj.handle return obj.handle
elif isinstance(obj, tracking.TrackableAsset): elif isinstance(obj, tracking.Asset):
return obj.asset_path return obj.asset_path
elif tensor_util.is_tensor(obj): elif tensor_util.is_tensor(obj):
return obj return obj
@ -343,7 +343,7 @@ class Loader(object):
filename = os.path.join( filename = os.path.join(
saved_model_utils.get_assets_dir(self._export_dir), saved_model_utils.get_assets_dir(self._export_dir),
self._asset_file_def[proto.asset_file_def_index].filename) self._asset_file_def[proto.asset_file_def_index].filename)
return tracking.TrackableAsset(filename), setattr return tracking.Asset(filename), setattr
def _recreate_function(self, proto): def _recreate_function(self, proto):
return function_deserialization.recreate_function( return function_deserialization.recreate_function(

View File

@ -204,8 +204,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
file2 = self._make_asset("contents 2") file2 = self._make_asset("contents 2")
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(file1) root.asset1 = tracking.Asset(file1)
root.asset2 = tracking.TrackableAsset(file2) root.asset2 = tracking.Asset(file2)
save_dir = os.path.join(self.get_temp_dir(), "save_dir") save_dir = os.path.join(self.get_temp_dir(), "save_dir")
save.save(root, save_dir) save.save(root, save_dir)
@ -253,7 +253,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_capture_assets(self, cycles): def test_capture_assets(self, cycles):
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents")) root.vocab = tracking.Asset(self._make_asset("contents"))
root.f = def_function.function( root.f = def_function.function(
lambda: root.vocab.asset_path, lambda: root.vocab.asset_path,
input_signature=[]) input_signature=[])
@ -266,7 +266,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_capture_assets_in_graph(self, cycles): def test_capture_assets_in_graph(self, cycles):
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.vocab = tracking.TrackableAsset(self._make_asset("contents")) root.vocab = tracking.Asset(self._make_asset("contents"))
root.f = def_function.function( root.f = def_function.function(
lambda: root.vocab.asset_path, lambda: root.vocab.asset_path,
input_signature=[]) input_signature=[])
@ -290,8 +290,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
def test_dedup_assets(self, cycles): def test_dedup_assets(self, cycles):
vocab = self._make_asset("contents") vocab = self._make_asset("contents")
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.asset1 = tracking.TrackableAsset(vocab) root.asset1 = tracking.Asset(vocab)
root.asset2 = tracking.TrackableAsset(vocab) root.asset2 = tracking.Asset(vocab)
imported = cycle(root, cycles) imported = cycle(root, cycles)
self.assertEqual(imported.asset1.asset_path.numpy(), self.assertEqual(imported.asset1.asset_path.numpy(),
imported.asset2.asset_path.numpy()) imported.asset2.asset_path.numpy())

View File

@ -201,7 +201,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
for tensor_name, value in loader_impl.get_asset_tensors( for tensor_name, value in loader_impl.get_asset_tensors(
self._export_dir, meta_graph_def).items(): self._export_dir, meta_graph_def).items():
asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name)) asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
asset_paths.append(tracking.TrackableAsset(value)) asset_paths.append(tracking.Asset(value))
init_fn = wrapped.prune( init_fn = wrapped.prune(
feeds=asset_feed_tensors, feeds=asset_feed_tensors,
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)]) fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])

View File

@ -270,7 +270,7 @@ class _SaveableView(object):
object_map[obj] = new_variable object_map[obj] = new_variable
resource_map[obj.handle] = new_variable.handle resource_map[obj.handle] = new_variable.handle
self.captured_tensor_node_ids[obj.handle] = node_id self.captured_tensor_node_ids[obj.handle] = node_id
elif isinstance(obj, tracking.TrackableAsset): elif isinstance(obj, tracking.Asset):
_process_asset(obj, asset_info, resource_map) _process_asset(obj, asset_info, resource_map)
self.captured_tensor_node_ids[obj.asset_path] = node_id self.captured_tensor_node_ids[obj.asset_path] = node_id
@ -498,7 +498,7 @@ _AssetInfo = collections.namedtuple(
"asset_initializers_by_resource", "asset_initializers_by_resource",
# Map from base asset filenames to full paths # Map from base asset filenames to full paths
"asset_filename_map", "asset_filename_map",
# Map from TrackableAsset to index of corresponding AssetFileDef # Map from Asset to index of corresponding AssetFileDef
"asset_index"]) "asset_index"])
@ -662,7 +662,7 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
def _write_object_proto(obj, proto, asset_file_def_index): def _write_object_proto(obj, proto, asset_file_def_index):
"""Saves an object into SavedObject proto.""" """Saves an object into SavedObject proto."""
if isinstance(obj, tracking.TrackableAsset): if isinstance(obj, tracking.Asset):
proto.asset.SetInParent() proto.asset.SetInParent()
proto.asset.asset_file_def_index = asset_file_def_index[obj] proto.asset.asset_file_def_index = asset_file_def_index[obj]
elif resource_variable_ops.is_resource_variable(obj): elif resource_variable_ops.is_resource_variable(obj):

View File

@ -426,7 +426,7 @@ class AssetTests(test.TestCase):
def test_asset_path_returned(self): def test_asset_path_returned(self):
root = tracking.AutoTrackable() root = tracking.AutoTrackable()
root.path = tracking.TrackableAsset(self._vocab_path) root.path = tracking.Asset(self._vocab_path)
save_dir = os.path.join(self.get_temp_dir(), "saved_model") save_dir = os.path.join(self.get_temp_dir(), "saved_model")
root.get_asset = def_function.function(lambda: root.path.asset_path) root.get_asset = def_function.function(lambda: root.path.asset_path)
save.save(root, save_dir, signatures=root.get_asset.get_concrete_function()) save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
@ -469,7 +469,7 @@ class AssetTests(test.TestCase):
root.f = def_function.function( root.f = def_function.function(
lambda x: 2. * x, lambda x: 2. * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
root.asset = tracking.TrackableAsset(self._vocab_path) root.asset = tracking.Asset(self._vocab_path)
export_dir = os.path.join(self.get_temp_dir(), "save_dir") export_dir = os.path.join(self.get_temp_dir(), "save_dir")
save.save(root, export_dir) save.save(root, export_dir)

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
# global _RESOURCE_TRACKER_STACK # global _RESOURCE_TRACKER_STACK
@ -275,8 +276,44 @@ class TrackableResource(CapturableResource):
super(TrackableResource, self).__init__(device=device, deleter=deleter) super(TrackableResource, self).__init__(device=device, deleter=deleter)
class TrackableAsset(base.Trackable): @tf_export("saved_model.Asset")
"""Base class for asset files which need to be tracked.""" class Asset(base.Trackable):
"""Represents a file asset to hermetically include in a SavedModel.
A SavedModel can include arbitrary files, called assets, that are needed
for its use. For example a vocabulary file used initialize a lookup table.
When a trackable object is exported via `tf.saved_model.save()`, all the
`Asset`s reachable from it are copied into the SavedModel assets directory.
Upon loading, the assets and the serialized functions that depend on them
will refer to the correct filepaths inside the SavedModel directory.
Example:
```
filename = tf.saved_model.Asset("file.txt")
@tf.function(input_signature=[])
def func():
return tf.io.read_file(filename)
trackable_obj = tf.train.Checkpoint()
trackable_obj.func = func
trackable_obj.filename = filename
tf.saved_model.save(trackable_obj, "/tmp/saved_model")
# The created SavedModel is hermetic, it does not depend on
# the original file and can be moved to another path.
tf.io.gfile.remove("file.txt")
tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")
reloaded_obj = tf.saved_model.load("/tmp/new_location")
print(reloaded_obj.func())
```
Attributes:
asset_path: A 0-D `tf.string` tensor with path to the asset.
"""
def __init__(self, path): def __init__(self, path):
"""Record the full path to the asset.""" """Record the full path to the asset."""
@ -389,5 +426,5 @@ def cached_per_instance(f):
ops.register_tensor_conversion_function( ops.register_tensor_conversion_function(
TrackableAsset, Asset,
lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw)) lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))

View File

@ -0,0 +1,14 @@
path: "tensorflow.saved_model.Asset"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "asset_path"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -8,6 +8,10 @@ tf_module {
name: "ASSETS_KEY" name: "ASSETS_KEY"
mtype: "<type \'str\'>" mtype: "<type \'str\'>"
} }
member {
name: "Asset"
mtype: "<type \'type\'>"
}
member { member {
name: "Builder" name: "Builder"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -0,0 +1,14 @@
path: "tensorflow.saved_model.Asset"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "asset_path"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -8,6 +8,10 @@ tf_module {
name: "ASSETS_KEY" name: "ASSETS_KEY"
mtype: "<type \'str\'>" mtype: "<type \'str\'>"
} }
member {
name: "Asset"
mtype: "<type \'type\'>"
}
member { member {
name: "CLASSIFY_INPUTS" name: "CLASSIFY_INPUTS"
mtype: "<type \'str\'>" mtype: "<type \'str\'>"