Add tf.saved_model.Asset public symbol.
`Asset` is the mechanism that allows to make hermetic SavedModels that depend on files. It replaces functionality that on TF-1.x was typically provided by the ASSET_FILEPATHS collection. PiperOrigin-RevId: 267534289
This commit is contained in:
parent
14f20c9ff6
commit
f185080c09
@ -750,7 +750,7 @@ class _TRTEngineResource(tracking.TrackableResource):
|
||||
self._resource_name = resource_name
|
||||
# Track the serialized engine file in the SavedModel.
|
||||
self._filename = self._track_trackable(
|
||||
tracking.TrackableAsset(filename), "_serialized_trt_resource_filename")
|
||||
tracking.Asset(filename), "_serialized_trt_resource_filename")
|
||||
self._maximum_cached_engines = maximum_cached_engines
|
||||
|
||||
def _create_resource(self):
|
||||
|
@ -627,7 +627,7 @@ class TextFileInitializer(TableInitializerBase):
|
||||
self._delimiter = delimiter
|
||||
self._name = name
|
||||
self._filename = self._track_trackable(
|
||||
trackable.TrackableAsset(filename), "_filename")
|
||||
trackable.Asset(filename), "_filename")
|
||||
|
||||
super(TextFileInitializer, self).__init__(key_dtype, value_dtype)
|
||||
|
||||
|
@ -206,7 +206,7 @@ class Loader(object):
|
||||
return obj
|
||||
elif resource_variable_ops.is_resource_variable(obj):
|
||||
return obj.handle
|
||||
elif isinstance(obj, tracking.TrackableAsset):
|
||||
elif isinstance(obj, tracking.Asset):
|
||||
return obj.asset_path
|
||||
elif tensor_util.is_tensor(obj):
|
||||
return obj
|
||||
@ -343,7 +343,7 @@ class Loader(object):
|
||||
filename = os.path.join(
|
||||
saved_model_utils.get_assets_dir(self._export_dir),
|
||||
self._asset_file_def[proto.asset_file_def_index].filename)
|
||||
return tracking.TrackableAsset(filename), setattr
|
||||
return tracking.Asset(filename), setattr
|
||||
|
||||
def _recreate_function(self, proto):
|
||||
return function_deserialization.recreate_function(
|
||||
|
@ -204,8 +204,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
file2 = self._make_asset("contents 2")
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
root.asset1 = tracking.TrackableAsset(file1)
|
||||
root.asset2 = tracking.TrackableAsset(file2)
|
||||
root.asset1 = tracking.Asset(file1)
|
||||
root.asset2 = tracking.Asset(file2)
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), "save_dir")
|
||||
save.save(root, save_dir)
|
||||
@ -253,7 +253,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_capture_assets(self, cycles):
|
||||
root = tracking.AutoTrackable()
|
||||
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
|
||||
root.vocab = tracking.Asset(self._make_asset("contents"))
|
||||
root.f = def_function.function(
|
||||
lambda: root.vocab.asset_path,
|
||||
input_signature=[])
|
||||
@ -266,7 +266,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_capture_assets_in_graph(self, cycles):
|
||||
root = tracking.AutoTrackable()
|
||||
root.vocab = tracking.TrackableAsset(self._make_asset("contents"))
|
||||
root.vocab = tracking.Asset(self._make_asset("contents"))
|
||||
root.f = def_function.function(
|
||||
lambda: root.vocab.asset_path,
|
||||
input_signature=[])
|
||||
@ -290,8 +290,8 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
def test_dedup_assets(self, cycles):
|
||||
vocab = self._make_asset("contents")
|
||||
root = tracking.AutoTrackable()
|
||||
root.asset1 = tracking.TrackableAsset(vocab)
|
||||
root.asset2 = tracking.TrackableAsset(vocab)
|
||||
root.asset1 = tracking.Asset(vocab)
|
||||
root.asset2 = tracking.Asset(vocab)
|
||||
imported = cycle(root, cycles)
|
||||
self.assertEqual(imported.asset1.asset_path.numpy(),
|
||||
imported.asset2.asset_path.numpy())
|
||||
|
@ -201,7 +201,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
for tensor_name, value in loader_impl.get_asset_tensors(
|
||||
self._export_dir, meta_graph_def).items():
|
||||
asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
|
||||
asset_paths.append(tracking.TrackableAsset(value))
|
||||
asset_paths.append(tracking.Asset(value))
|
||||
init_fn = wrapped.prune(
|
||||
feeds=asset_feed_tensors,
|
||||
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
|
||||
|
@ -270,7 +270,7 @@ class _SaveableView(object):
|
||||
object_map[obj] = new_variable
|
||||
resource_map[obj.handle] = new_variable.handle
|
||||
self.captured_tensor_node_ids[obj.handle] = node_id
|
||||
elif isinstance(obj, tracking.TrackableAsset):
|
||||
elif isinstance(obj, tracking.Asset):
|
||||
_process_asset(obj, asset_info, resource_map)
|
||||
self.captured_tensor_node_ids[obj.asset_path] = node_id
|
||||
|
||||
@ -498,7 +498,7 @@ _AssetInfo = collections.namedtuple(
|
||||
"asset_initializers_by_resource",
|
||||
# Map from base asset filenames to full paths
|
||||
"asset_filename_map",
|
||||
# Map from TrackableAsset to index of corresponding AssetFileDef
|
||||
# Map from Asset to index of corresponding AssetFileDef
|
||||
"asset_index"])
|
||||
|
||||
|
||||
@ -662,7 +662,7 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
|
||||
|
||||
def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
"""Saves an object into SavedObject proto."""
|
||||
if isinstance(obj, tracking.TrackableAsset):
|
||||
if isinstance(obj, tracking.Asset):
|
||||
proto.asset.SetInParent()
|
||||
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
||||
elif resource_variable_ops.is_resource_variable(obj):
|
||||
|
@ -426,7 +426,7 @@ class AssetTests(test.TestCase):
|
||||
|
||||
def test_asset_path_returned(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.path = tracking.TrackableAsset(self._vocab_path)
|
||||
root.path = tracking.Asset(self._vocab_path)
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
root.get_asset = def_function.function(lambda: root.path.asset_path)
|
||||
save.save(root, save_dir, signatures=root.get_asset.get_concrete_function())
|
||||
@ -469,7 +469,7 @@ class AssetTests(test.TestCase):
|
||||
root.f = def_function.function(
|
||||
lambda x: 2. * x,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
root.asset = tracking.TrackableAsset(self._vocab_path)
|
||||
root.asset = tracking.Asset(self._vocab_path)
|
||||
|
||||
export_dir = os.path.join(self.get_temp_dir(), "save_dir")
|
||||
save.save(root, export_dir)
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# global _RESOURCE_TRACKER_STACK
|
||||
@ -275,8 +276,44 @@ class TrackableResource(CapturableResource):
|
||||
super(TrackableResource, self).__init__(device=device, deleter=deleter)
|
||||
|
||||
|
||||
class TrackableAsset(base.Trackable):
|
||||
"""Base class for asset files which need to be tracked."""
|
||||
@tf_export("saved_model.Asset")
|
||||
class Asset(base.Trackable):
|
||||
"""Represents a file asset to hermetically include in a SavedModel.
|
||||
|
||||
A SavedModel can include arbitrary files, called assets, that are needed
|
||||
for its use. For example a vocabulary file used initialize a lookup table.
|
||||
|
||||
When a trackable object is exported via `tf.saved_model.save()`, all the
|
||||
`Asset`s reachable from it are copied into the SavedModel assets directory.
|
||||
Upon loading, the assets and the serialized functions that depend on them
|
||||
will refer to the correct filepaths inside the SavedModel directory.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
filename = tf.saved_model.Asset("file.txt")
|
||||
|
||||
@tf.function(input_signature=[])
|
||||
def func():
|
||||
return tf.io.read_file(filename)
|
||||
|
||||
trackable_obj = tf.train.Checkpoint()
|
||||
trackable_obj.func = func
|
||||
trackable_obj.filename = filename
|
||||
tf.saved_model.save(trackable_obj, "/tmp/saved_model")
|
||||
|
||||
# The created SavedModel is hermetic, it does not depend on
|
||||
# the original file and can be moved to another path.
|
||||
tf.io.gfile.remove("file.txt")
|
||||
tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location")
|
||||
|
||||
reloaded_obj = tf.saved_model.load("/tmp/new_location")
|
||||
print(reloaded_obj.func())
|
||||
```
|
||||
|
||||
Attributes:
|
||||
asset_path: A 0-D `tf.string` tensor with path to the asset.
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
"""Record the full path to the asset."""
|
||||
@ -389,5 +426,5 @@ def cached_per_instance(f):
|
||||
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
TrackableAsset,
|
||||
Asset,
|
||||
lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
|
||||
|
@ -0,0 +1,14 @@
|
||||
path: "tensorflow.saved_model.Asset"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "asset_path"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "ASSETS_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "Asset"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Builder"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,14 @@
|
||||
path: "tensorflow.saved_model.Asset"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "asset_path"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "ASSETS_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "Asset"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "CLASSIFY_INPUTS"
|
||||
mtype: "<type \'str\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user