diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 2de271d4e85..65846044c53 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -758,7 +758,7 @@ class _TRTEngineResource(tracking.TrackableResource): self._resource_name = resource_name # Track the serialized engine file in the SavedModel. self._filename = self._track_trackable( - tracking.TrackableAsset(filename), "_serialized_trt_resource_filename") + tracking.Asset(filename), "_serialized_trt_resource_filename") self._maximum_cached_engines = maximum_cached_engines def _create_resource(self): diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index b2ea5c4ebc7..0aabb5c7ecb 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -627,7 +627,7 @@ class TextFileInitializer(TableInitializerBase): self._delimiter = delimiter self._name = name self._filename = self._track_trackable( - trackable.TrackableAsset(filename), "_filename") + trackable.Asset(filename), "_filename") super(TextFileInitializer, self).__init__(key_dtype, value_dtype) diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 88f0f819ea7..b4754d4c6e7 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -206,7 +206,7 @@ class Loader(object): return obj elif resource_variable_ops.is_resource_variable(obj): return obj.handle - elif isinstance(obj, tracking.TrackableAsset): + elif isinstance(obj, tracking.Asset): return obj.asset_path elif tensor_util.is_tensor(obj): return obj @@ -343,7 +343,7 @@ class Loader(object): filename = os.path.join( saved_model_utils.get_assets_dir(self._export_dir), self._asset_file_def[proto.asset_file_def_index].filename) - return tracking.TrackableAsset(filename), setattr + return tracking.Asset(filename), setattr def _recreate_function(self, proto): return function_deserialization.recreate_function( diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 0fb77683daa..1a08a0ad950 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -204,8 +204,8 @@ class LoadTest(test.TestCase, parameterized.TestCase): file2 = self._make_asset("contents 2") root = tracking.AutoTrackable() - root.asset1 = tracking.TrackableAsset(file1) - root.asset2 = tracking.TrackableAsset(file2) + root.asset1 = tracking.Asset(file1) + root.asset2 = tracking.Asset(file2) save_dir = os.path.join(self.get_temp_dir(), "save_dir") save.save(root, save_dir) @@ -253,7 +253,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): def test_capture_assets(self, cycles): root = tracking.AutoTrackable() - root.vocab = tracking.TrackableAsset(self._make_asset("contents")) + root.vocab = tracking.Asset(self._make_asset("contents")) root.f = def_function.function( lambda: root.vocab.asset_path, input_signature=[]) @@ -266,7 +266,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): def test_capture_assets_in_graph(self, cycles): root = tracking.AutoTrackable() - root.vocab = tracking.TrackableAsset(self._make_asset("contents")) + root.vocab = tracking.Asset(self._make_asset("contents")) root.f = def_function.function( lambda: root.vocab.asset_path, input_signature=[]) @@ -290,8 +290,8 @@ class LoadTest(test.TestCase, parameterized.TestCase): def test_dedup_assets(self, cycles): vocab = self._make_asset("contents") root = tracking.AutoTrackable() - root.asset1 = tracking.TrackableAsset(vocab) - root.asset2 = tracking.TrackableAsset(vocab) + root.asset1 = tracking.Asset(vocab) + root.asset2 = tracking.Asset(vocab) imported = cycle(root, cycles) self.assertEqual(imported.asset1.asset_path.numpy(), imported.asset2.asset_path.numpy()) diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index e2e62bae386..4ddd18bc6f3 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -201,7 +201,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): for tensor_name, value in loader_impl.get_asset_tensors( self._export_dir, meta_graph_def).items(): asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name)) - asset_paths.append(tracking.TrackableAsset(value)) + asset_paths.append(tracking.Asset(value)) init_fn = wrapped.prune( feeds=asset_feed_tensors, fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)]) diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index b1b69f1ff32..9b253435659 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -270,7 +270,7 @@ class _SaveableView(object): object_map[obj] = new_variable resource_map[obj.handle] = new_variable.handle self.captured_tensor_node_ids[obj.handle] = node_id - elif isinstance(obj, tracking.TrackableAsset): + elif isinstance(obj, tracking.Asset): _process_asset(obj, asset_info, resource_map) self.captured_tensor_node_ids[obj.asset_path] = node_id @@ -498,7 +498,7 @@ _AssetInfo = collections.namedtuple( "asset_initializers_by_resource", # Map from base asset filenames to full paths "asset_filename_map", - # Map from TrackableAsset to index of corresponding AssetFileDef + # Map from Asset to index of corresponding AssetFileDef "asset_index"]) @@ -662,7 +662,7 @@ def _serialize_object_graph(saveable_view, asset_file_def_index): def _write_object_proto(obj, proto, asset_file_def_index): """Saves an object into SavedObject proto.""" - if isinstance(obj, tracking.TrackableAsset): + if isinstance(obj, tracking.Asset): proto.asset.SetInParent() proto.asset.asset_file_def_index = asset_file_def_index[obj] elif resource_variable_ops.is_resource_variable(obj): diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 542e7130273..1b6503870a7 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -426,7 +426,7 @@ class AssetTests(test.TestCase): def test_asset_path_returned(self): root = tracking.AutoTrackable() - root.path = tracking.TrackableAsset(self._vocab_path) + root.path = tracking.Asset(self._vocab_path) save_dir = os.path.join(self.get_temp_dir(), "saved_model") root.get_asset = def_function.function(lambda: root.path.asset_path) save.save(root, save_dir, signatures=root.get_asset.get_concrete_function()) @@ -469,7 +469,7 @@ class AssetTests(test.TestCase): root.f = def_function.function( lambda x: 2. * x, input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) - root.asset = tracking.TrackableAsset(self._vocab_path) + root.asset = tracking.Asset(self._vocab_path) export_dir = os.path.join(self.get_temp_dir(), "save_dir") save.save(root, export_dir) diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py index 8b0bc6e5e3a..f27b83ccba2 100644 --- a/tensorflow/python/training/tracking/tracking.py +++ b/tensorflow/python/training/tracking/tracking.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import data_structures from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export # global _RESOURCE_TRACKER_STACK @@ -275,8 +276,44 @@ class TrackableResource(CapturableResource): super(TrackableResource, self).__init__(device=device, deleter=deleter) -class TrackableAsset(base.Trackable): - """Base class for asset files which need to be tracked.""" +@tf_export("saved_model.Asset") +class Asset(base.Trackable): + """Represents a file asset to hermetically include in a SavedModel. + + A SavedModel can include arbitrary files, called assets, that are needed + for its use. For example a vocabulary file used initialize a lookup table. + + When a trackable object is exported via `tf.saved_model.save()`, all the + `Asset`s reachable from it are copied into the SavedModel assets directory. + Upon loading, the assets and the serialized functions that depend on them + will refer to the correct filepaths inside the SavedModel directory. + + Example: + + ``` + filename = tf.saved_model.Asset("file.txt") + + @tf.function(input_signature=[]) + def func(): + return tf.io.read_file(filename) + + trackable_obj = tf.train.Checkpoint() + trackable_obj.func = func + trackable_obj.filename = filename + tf.saved_model.save(trackable_obj, "/tmp/saved_model") + + # The created SavedModel is hermetic, it does not depend on + # the original file and can be moved to another path. + tf.io.gfile.remove("file.txt") + tf.io.gfile.rename("/tmp/saved_model", "/tmp/new_location") + + reloaded_obj = tf.saved_model.load("/tmp/new_location") + print(reloaded_obj.func()) + ``` + + Attributes: + asset_path: A 0-D `tf.string` tensor with path to the asset. + """ def __init__(self, path): """Record the full path to the asset.""" @@ -389,5 +426,5 @@ def cached_per_instance(f): ops.register_tensor_conversion_function( - TrackableAsset, + Asset, lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw)) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-asset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-asset.pbtxt new file mode 100644 index 00000000000..0a20385c329 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-asset.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.saved_model.Asset" +tf_class { + is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<type \'object\'>" + member { + name: "asset_path" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt index f3558109ce8..216aa232223 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "ASSETS_KEY" mtype: "<type \'str\'>" } + member { + name: "Asset" + mtype: "<type \'type\'>" + } member { name: "Builder" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-asset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-asset.pbtxt new file mode 100644 index 00000000000..0a20385c329 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-asset.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.saved_model.Asset" +tf_class { + is_instance: "<class \'tensorflow.python.training.tracking.tracking.Asset\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<type \'object\'>" + member { + name: "asset_path" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'path\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt index 94fa0eaad53..0fef87b09c8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "ASSETS_KEY" mtype: "<type \'str\'>" } + member { + name: "Asset" + mtype: "<type \'type\'>" + } member { name: "CLASSIFY_INPUTS" mtype: "<type \'str\'>"