Factor out scope tracking from TrackableResource
SavedModel will need to track basically every Tensor someone might keep around. We don't really want to collect all of those with resource_tracker_scope. Removes an incorrect warning from dataset usage. SavedModel will still re-create everything as TrackableResources for now; we could make this distinction in the saved format if it's helpful, although I don't see a good reason for now. PiperOrigin-RevId: 246204427
This commit is contained in:
parent
f9a682e70f
commit
2ca698bc9a
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
@ -207,6 +209,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset_fn = self.make_interleave_fn(*interleave_fn_args)
|
||||
self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
|
||||
|
||||
def testNoWarnings(self):
|
||||
with test.mock.patch.object(warnings, "warn") as mock_log:
|
||||
dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
|
||||
dataset_fn(dataset_ops.Dataset.range(10))
|
||||
self.assertEmpty(mock_log.call_args_list)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Concatenate", lambda x, y: x.concatenate(y),
|
||||
lambda: dataset_ops.Dataset.range(0),
|
||||
|
@ -2878,7 +2878,7 @@ class BatchDataset(UnaryDataset):
|
||||
return self._structure
|
||||
|
||||
|
||||
class _VariantTracker(tracking.TrackableResource):
|
||||
class _VariantTracker(tracking.CapturableResource):
|
||||
"""Allows export of functions capturing a Dataset in SavedModels.
|
||||
|
||||
When saving a SavedModel, `tf.saved_model.save` traverses the object
|
||||
|
@ -66,7 +66,7 @@ class _Loader(object):
|
||||
self._restore_checkpoint()
|
||||
|
||||
for node in self._nodes:
|
||||
if isinstance(node, tracking.TrackableResource):
|
||||
if isinstance(node, tracking.CapturableResource):
|
||||
init_op = node._initialize() # pylint: disable=protected-access
|
||||
if not context.executing_eagerly():
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
|
||||
@ -122,8 +122,8 @@ class _Loader(object):
|
||||
return obj.asset_path
|
||||
elif tensor_util.is_tensor(obj):
|
||||
return obj
|
||||
elif isinstance(obj, tracking.TrackableResource):
|
||||
# Note: this executes restored functions in the TrackableResource.
|
||||
elif isinstance(obj, tracking.CapturableResource):
|
||||
# Note: this executes restored functions in the CapturableResource.
|
||||
return obj.resource_handle
|
||||
raise ValueError("Can't convert node %s to tensor" % (type(obj)))
|
||||
|
||||
|
@ -30,7 +30,7 @@ from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
||||
|
||||
class _Initializer(tracking.TrackableResource):
|
||||
class _Initializer(tracking.CapturableResource):
|
||||
"""Represents an initialization operation restored from a SavedModel.
|
||||
|
||||
Without this object re-export of imported 1.x SavedModels would omit the
|
||||
|
@ -220,7 +220,7 @@ class _SaveableView(object):
|
||||
asset_filename_map={},
|
||||
asset_index={})
|
||||
for node_id, obj in enumerate(self.nodes):
|
||||
if isinstance(obj, tracking.TrackableResource):
|
||||
if isinstance(obj, tracking.CapturableResource):
|
||||
# pylint: disable=protected-access
|
||||
with ops.device(obj._resource_device):
|
||||
new_resource = obj._create_resource()
|
||||
@ -424,7 +424,7 @@ def _generate_signatures(signature_functions, resource_map):
|
||||
|
||||
|
||||
def _trace_resource_initializers(accessible_objects):
|
||||
"""Create concrete functions from `TrackableResource` objects."""
|
||||
"""Create concrete functions from `CapturableResource` objects."""
|
||||
resource_initializers = []
|
||||
|
||||
def _wrap_initializer(obj):
|
||||
@ -435,7 +435,7 @@ def _trace_resource_initializers(accessible_objects):
|
||||
return lambda: _wrap_initializer(obj)
|
||||
|
||||
for obj in accessible_objects:
|
||||
if isinstance(obj, tracking.TrackableResource):
|
||||
if isinstance(obj, tracking.CapturableResource):
|
||||
resource_initializers.append(def_function.function(
|
||||
_wrap_obj_initializer(obj),
|
||||
# All inputs are captures.
|
||||
@ -608,7 +608,7 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
function_serialization.serialize_bare_concrete_function(obj))
|
||||
elif isinstance(obj, _CapturedConstant):
|
||||
proto.constant.operation = obj.graph_tensor.op.name
|
||||
elif isinstance(obj, tracking.TrackableResource):
|
||||
elif isinstance(obj, tracking.CapturableResource):
|
||||
proto.resource.device = obj._resource_device # pylint: disable=protected-access
|
||||
else:
|
||||
registered_type_proto = revived_types.serialize(obj)
|
||||
|
@ -150,11 +150,18 @@ def resource_tracker_scope(resource_tracker):
|
||||
_RESOURCE_TRACKER_STACK = old
|
||||
|
||||
|
||||
class TrackableResource(base.Trackable):
|
||||
"""Base class for all resources that need to be tracked."""
|
||||
class CapturableResource(base.Trackable):
|
||||
"""Holds a Tensor which a tf.function can capture.
|
||||
|
||||
`CapturableResource`s are discovered by traversing the graph of object
|
||||
attributes, e.g. during `tf.saved_model.save`. They are excluded from the
|
||||
scope-based tracking of `TrackableResource`; generally things that require
|
||||
initialization should inherit from `TrackableResource` instead of
|
||||
`CapturableResource` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, device=""):
|
||||
"""Initialize the `TrackableResource`.
|
||||
"""Initialize the `CapturableResource`.
|
||||
|
||||
Args:
|
||||
device: A string indicating a required placement for this resource,
|
||||
@ -162,10 +169,6 @@ class TrackableResource(base.Trackable):
|
||||
device allows the user to place resource creation, so generally this
|
||||
should be blank unless the resource only makes sense on one device.
|
||||
"""
|
||||
global _RESOURCE_TRACKER_STACK
|
||||
for resource_tracker in _RESOURCE_TRACKER_STACK:
|
||||
resource_tracker.add_resource(self)
|
||||
|
||||
self._resource_handle = None
|
||||
self._resource_device = device
|
||||
|
||||
@ -203,6 +206,24 @@ class TrackableResource(base.Trackable):
|
||||
}
|
||||
|
||||
|
||||
class TrackableResource(CapturableResource):
|
||||
"""Adds scope tracking to CapturableResource."""
|
||||
|
||||
def __init__(self, device=""):
|
||||
"""Initialize the `TrackableResource`.
|
||||
|
||||
Args:
|
||||
device: A string indicating a required placement for this resource,
|
||||
e.g. "CPU" if this resource must be created on a CPU device. A blank
|
||||
device allows the user to place resource creation, so generally this
|
||||
should be blank unless the resource only makes sense on one device.
|
||||
"""
|
||||
global _RESOURCE_TRACKER_STACK
|
||||
for resource_tracker in _RESOURCE_TRACKER_STACK:
|
||||
resource_tracker.add_resource(self)
|
||||
super(TrackableResource, self).__init__(device=device)
|
||||
|
||||
|
||||
class TrackableAsset(base.Trackable):
|
||||
"""Base class for asset files which need to be tracked."""
|
||||
|
||||
|
@ -5,6 +5,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -4,6 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -3,6 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -4,6 +4,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -3,6 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
@ -3,6 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
|
||||
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
|
Loading…
x
Reference in New Issue
Block a user