Factor out scope tracking from TrackableResource

SavedModel will need to track basically every Tensor someone might keep around. We don't really want to collect all of those with resource_tracker_scope.

Removes an incorrect warning from dataset usage.

SavedModel will still re-create everything as TrackableResources for now; we could make this distinction in the saved format if it's helpful, although I don't see a good reason for now.

PiperOrigin-RevId: 246204427
This commit is contained in:
Allen Lavoie 2019-05-01 14:54:28 -07:00 committed by TensorFlower Gardener
parent f9a682e70f
commit 2ca698bc9a
12 changed files with 51 additions and 16 deletions

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
from absl.testing import parameterized
import numpy as np
@ -207,6 +209,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset_fn = self.make_interleave_fn(*interleave_fn_args)
self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
def testNoWarnings(self):
with test.mock.patch.object(warnings, "warn") as mock_log:
dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
dataset_fn(dataset_ops.Dataset.range(10))
self.assertEmpty(mock_log.call_args_list)
@parameterized.named_parameters(
("Concatenate", lambda x, y: x.concatenate(y),
lambda: dataset_ops.Dataset.range(0),

View File

@ -2878,7 +2878,7 @@ class BatchDataset(UnaryDataset):
return self._structure
class _VariantTracker(tracking.TrackableResource):
class _VariantTracker(tracking.CapturableResource):
"""Allows export of functions capturing a Dataset in SavedModels.
When saving a SavedModel, `tf.saved_model.save` traverses the object

View File

@ -66,7 +66,7 @@ class _Loader(object):
self._restore_checkpoint()
for node in self._nodes:
if isinstance(node, tracking.TrackableResource):
if isinstance(node, tracking.CapturableResource):
init_op = node._initialize() # pylint: disable=protected-access
if not context.executing_eagerly():
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
@ -122,8 +122,8 @@ class _Loader(object):
return obj.asset_path
elif tensor_util.is_tensor(obj):
return obj
elif isinstance(obj, tracking.TrackableResource):
# Note: this executes restored functions in the TrackableResource.
elif isinstance(obj, tracking.CapturableResource):
# Note: this executes restored functions in the CapturableResource.
return obj.resource_handle
raise ValueError("Can't convert node %s to tensor" % (type(obj)))

View File

@ -30,7 +30,7 @@ from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training.tracking import tracking
class _Initializer(tracking.TrackableResource):
class _Initializer(tracking.CapturableResource):
"""Represents an initialization operation restored from a SavedModel.
Without this object re-export of imported 1.x SavedModels would omit the

View File

@ -220,7 +220,7 @@ class _SaveableView(object):
asset_filename_map={},
asset_index={})
for node_id, obj in enumerate(self.nodes):
if isinstance(obj, tracking.TrackableResource):
if isinstance(obj, tracking.CapturableResource):
# pylint: disable=protected-access
with ops.device(obj._resource_device):
new_resource = obj._create_resource()
@ -424,7 +424,7 @@ def _generate_signatures(signature_functions, resource_map):
def _trace_resource_initializers(accessible_objects):
"""Create concrete functions from `TrackableResource` objects."""
"""Create concrete functions from `CapturableResource` objects."""
resource_initializers = []
def _wrap_initializer(obj):
@ -435,7 +435,7 @@ def _trace_resource_initializers(accessible_objects):
return lambda: _wrap_initializer(obj)
for obj in accessible_objects:
if isinstance(obj, tracking.TrackableResource):
if isinstance(obj, tracking.CapturableResource):
resource_initializers.append(def_function.function(
_wrap_obj_initializer(obj),
# All inputs are captures.
@ -608,7 +608,7 @@ def _write_object_proto(obj, proto, asset_file_def_index):
function_serialization.serialize_bare_concrete_function(obj))
elif isinstance(obj, _CapturedConstant):
proto.constant.operation = obj.graph_tensor.op.name
elif isinstance(obj, tracking.TrackableResource):
elif isinstance(obj, tracking.CapturableResource):
proto.resource.device = obj._resource_device # pylint: disable=protected-access
else:
registered_type_proto = revived_types.serialize(obj)

View File

@ -150,11 +150,18 @@ def resource_tracker_scope(resource_tracker):
_RESOURCE_TRACKER_STACK = old
class TrackableResource(base.Trackable):
"""Base class for all resources that need to be tracked."""
class CapturableResource(base.Trackable):
"""Holds a Tensor which a tf.function can capture.
`CapturableResource`s are discovered by traversing the graph of object
attributes, e.g. during `tf.saved_model.save`. They are excluded from the
scope-based tracking of `TrackableResource`; generally things that require
initialization should inherit from `TrackableResource` instead of
`CapturableResource` directly.
"""
def __init__(self, device=""):
"""Initialize the `TrackableResource`.
"""Initialize the `CapturableResource`.
Args:
device: A string indicating a required placement for this resource,
@ -162,10 +169,6 @@ class TrackableResource(base.Trackable):
device allows the user to place resource creation, so generally this
should be blank unless the resource only makes sense on one device.
"""
global _RESOURCE_TRACKER_STACK
for resource_tracker in _RESOURCE_TRACKER_STACK:
resource_tracker.add_resource(self)
self._resource_handle = None
self._resource_device = device
@ -203,6 +206,24 @@ class TrackableResource(base.Trackable):
}
class TrackableResource(CapturableResource):
"""Adds scope tracking to CapturableResource."""
def __init__(self, device=""):
"""Initialize the `TrackableResource`.
Args:
device: A string indicating a required placement for this resource,
e.g. "CPU" if this resource must be created on a CPU device. A blank
device allows the user to place resource creation, so generally this
should be blank unless the resource only makes sense on one device.
"""
global _RESOURCE_TRACKER_STACK
for resource_tracker in _RESOURCE_TRACKER_STACK:
resource_tracker.add_resource(self)
super(TrackableResource, self).__init__(device=device)
class TrackableAsset(base.Trackable):
"""Base class for asset files which need to be tracked."""

View File

@ -5,6 +5,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {

View File

@ -4,6 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {

View File

@ -3,6 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {

View File

@ -4,6 +4,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.InitializableLookupTableBase\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {

View File

@ -3,6 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.StaticVocabularyTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {

View File

@ -3,6 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DenseHashTable\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.LookupInterface\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.TrackableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.CapturableResource\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {