Re-track resource children when loading.
Also updated the comment in saved_object_graph.proto to note that resources can also have "children", as the saving code does not explicitly filter for user_object. PiperOrigin-RevId: 360768531 Change-Id: Ib9e708d2a112fb8e89be8a4570053a12368cd3a3
This commit is contained in:
parent
aa4c75e42a
commit
149691c4b1
tensorflow
@ -37,7 +37,7 @@ message SavedObject {
|
||||
// Objects which this object depends on: named edges in the dependency
|
||||
// graph.
|
||||
//
|
||||
// Note: currently only valid if kind == "user_object".
|
||||
// Note: currently only valid if kind == "user_object" or "resource".
|
||||
repeated TrackableObjectGraph.TrackableObject.ObjectReference children = 1;
|
||||
|
||||
// Removed when forking SavedObject from TrackableObjectGraph.
|
||||
|
@ -628,7 +628,7 @@ class Loader(object):
|
||||
return imported_constant, setattr
|
||||
|
||||
def _recreate_resource(self, proto):
|
||||
return _RestoredResource(device=proto.device), setattr
|
||||
return _RestoredResource(device=proto.device), _setattr_and_track
|
||||
|
||||
|
||||
# TODO(b/124205571,b/124092991): Solve destruction of resources.
|
||||
@ -672,6 +672,13 @@ def _call_attribute(instance, *args, **kwargs):
|
||||
return instance.__call__(*args, **kwargs)
|
||||
|
||||
|
||||
def _setattr_and_track(obj, name, value):
|
||||
"""Sets new attribute and marks it as a dependency if Trackable."""
|
||||
setattr(obj, name, value)
|
||||
if isinstance(value, base.Trackable):
|
||||
obj._track_trackable(value, name) # pylint:disable=protected-access
|
||||
|
||||
|
||||
@tf_export("__internal__.saved_model.load_partial", v1=[])
|
||||
def load_partial(export_dir, filters, tags=None, options=None):
|
||||
"""Partially load a SavedModel (saved from V2).
|
||||
|
@ -1968,6 +1968,35 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
|
||||
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
|
||||
|
||||
def test_load_resource_with_dependency(self, cycles):
|
||||
# Test with StaticHashTable, which has a _initializer attribute that tracks
|
||||
# the Asset vocab table.
|
||||
|
||||
class MyLookupModel(tracking.AutoTrackable):
|
||||
|
||||
def __init__(self, vocab_file):
|
||||
|
||||
vocab_initializer = lookup_ops.TextFileInitializer(
|
||||
vocab_file,
|
||||
key_dtype=dtypes.string,
|
||||
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
|
||||
value_dtype=dtypes.int64,
|
||||
value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
|
||||
self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer,
|
||||
default_value=-1)
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec((None,), dtypes.string)])
|
||||
def __call__(self, inputs):
|
||||
return self._vocab_table.lookup(inputs)
|
||||
|
||||
vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"]))
|
||||
root = MyLookupModel(vocab_file)
|
||||
imported = cycle(root, cycles)
|
||||
file_io.delete_file(vocab_file)
|
||||
self.assertAllEqual(imported(constant_op.constant(["d", "b"])),
|
||||
[3, 1])
|
||||
|
||||
|
||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user