Re-track resource children when loading.

Also updated the comment in saved_object_graph.proto to note that resources can also have "children", as the saving code does not explicitly filter for user_object.

PiperOrigin-RevId: 360768531
Change-Id: Ib9e708d2a112fb8e89be8a4570053a12368cd3a3
This commit is contained in:
Katherine Wu 2021-03-03 15:18:48 -08:00 committed by TensorFlower Gardener
parent aa4c75e42a
commit 149691c4b1
3 changed files with 38 additions and 2 deletions
tensorflow
core/protobuf
python/saved_model

View File

@ -37,7 +37,7 @@ message SavedObject {
// Objects which this object depends on: named edges in the dependency
// graph.
//
// Note: currently only valid if kind == "user_object".
// Note: currently only valid if kind == "user_object" or "resource".
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
// Removed when forking SavedObject from TrackableObjectGraph.

View File

@ -628,7 +628,7 @@ class Loader(object):
return imported_constant, setattr
def _recreate_resource(self, proto):
return _RestoredResource(device=proto.device), setattr
return _RestoredResource(device=proto.device), _setattr_and_track
# TODO(b/124205571,b/124092991): Solve destruction of resources.
@ -672,6 +672,13 @@ def _call_attribute(instance, *args, **kwargs):
return instance.__call__(*args, **kwargs)
def _setattr_and_track(obj, name, value):
"""Sets new attribute and marks it as a dependency if Trackable."""
setattr(obj, name, value)
if isinstance(value, base.Trackable):
obj._track_trackable(value, name) # pylint:disable=protected-access
@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
"""Partially load a SavedModel (saved from V2).

View File

@ -1968,6 +1968,35 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
def test_load_resource_with_dependency(self, cycles):
# Test with StaticHashTable, which has a _initializer attribute that tracks
# the Asset vocab table.
class MyLookupModel(tracking.AutoTrackable):
def __init__(self, vocab_file):
vocab_initializer = lookup_ops.TextFileInitializer(
vocab_file,
key_dtype=dtypes.string,
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer,
default_value=-1)
@def_function.function(input_signature=[
tensor_spec.TensorSpec((None,), dtypes.string)])
def __call__(self, inputs):
return self._vocab_table.lookup(inputs)
vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"]))
root = MyLookupModel(vocab_file)
imported = cycle(root, cycles)
file_io.delete_file(vocab_file)
self.assertAllEqual(imported(constant_op.constant(["d", "b"])),
[3, 1])
class SingleCycleTests(test.TestCase, parameterized.TestCase):