Provide a mechanism for registering custom resource tensor resolvers for ACD.

ACD only looks at the direct resource inputs of stateful ops. This doesn't work for cases where ops access resources indirectly, e.g. consumers of TPUReplicatedInput, and in tf.data, where the MapDatasetOp may be touching a resource but we need to add a control dep from the ReduceDatasetOp. This mechanism provides a way to notify ACD of the indirect resource accesses of an op.

PiperOrigin-RevId: 290063112
Change-Id: I329007eb99fce2dee9dda03593651992086d0b18
This commit is contained in:
Saurabh Saxena 2020-01-16 07:33:00 -08:00 committed by TensorFlower Gardener
parent 19c96a602f
commit f72e3a7ce8
4 changed files with 110 additions and 7 deletions

View File

@ -22,6 +22,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -321,10 +322,7 @@ class AutomaticControlDependencies(object):
resource_inputs = set()
# Check for any resource inputs. If we find any, we update control_inputs
# and last_op_using_resource_tensor.
for inp in op.inputs:
if inp.dtype != dtypes_module.resource:
continue
for inp in _get_resource_inputs(op):
input_id = ops.tensor_id(inp)
# If the op receives the same resource tensor twice as an input, we skip
@ -338,9 +336,11 @@ class AutomaticControlDependencies(object):
self._process_switch(inp.op, ops_which_must_run,
last_op_using_resource_tensor,
merge_for_resource)
is_building_function = op.graph.building_function
# Ensure uses of resources are serialized
if input_id in last_op_using_resource_tensor:
if (last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access
if is_building_function or (
last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access
is op._control_flow_context): # pylint: disable=protected-access
control_inputs.add(last_op_using_resource_tensor[input_id])
# Ensure merges happen after the closing of a cond block
@ -353,8 +353,9 @@ class AutomaticControlDependencies(object):
if None in last_op_using_resource_tensor:
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
control_inputs = [
c for c in control_inputs if is_building_function or
(c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
# Ensure all ops which must run do run
@ -369,6 +370,60 @@ class AutomaticControlDependencies(object):
])
# Registry of resolver functions. Entries are added through
# `register_acd_resource_resolver` and are looked up and invoked by
# `_get_resource_inputs`.
# NOTE(review): the registry name below is spelled "resouce" (sic); confirm
# nothing depends on the exact string before correcting it.
_acd_resource_resolvers_registry = registry.Registry("acd_resouce_resolvers")
def register_acd_resource_resolver(f):
  """Adds `f` to the resolvers ACD consults for the resources an op touches.

  A resolver is called as `f(op, resource_inputs)`. It may mutate
  `resource_inputs` in place and must return a truthy value exactly when it
  changed the set, so that resolution can be repeated until no resolver
  reports a change. This function can be used as a decorator.

  Example:

    @register_acd_resource_resolver
    def ResolveIdentity(op, resource_inputs):
      # op: The `Operation` being processed by ACD currently.
      # resource_inputs: An `ObjectIdentitySet` that can be updated in-place.
      if not resource_inputs:
        return False
      to_add = []
      to_remove = []
      for t in resource_inputs:
        if t.op.type == "Identity":
          to_remove.append(t)
          to_add.append(t.op.inputs[0])
      if not to_add and not to_remove:
        return False
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return True  # `resource_inputs` was updated.

  Args:
    f: Python function taking `(op, resource_inputs)`.

  Returns:
    The same function `f`, unchanged, after it has been registered.
  """
  _acd_resource_resolvers_registry.register(f)
  return f
def _get_resource_inputs(op):
  """Collects the resource tensors `op` touches, directly or via resolvers.

  The result starts as the resource-dtype inputs of `op`. Every registered
  resolver is then run against that set, and the whole pass is repeated until
  a full pass completes in which no resolver reports a modification.

  Args:
    op: The `Operation` whose resource inputs are being resolved.

  Returns:
    An `ObjectIdentitySet` of the resolved resource tensors.
  """
  resource_inputs = object_identity.ObjectIdentitySet(
      tensor for tensor in op.inputs
      if tensor.dtype == dtypes_module.resource)
  while True:
    changed = False
    for resolver_name in _acd_resource_resolvers_registry.list():
      resolver = _acd_resource_resolvers_registry.lookup(resolver_name)
      # Resolvers return a truthy value when they modified `resource_inputs`.
      # Every resolver is run on every pass, even after one has already
      # reported a change.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      if resolver(op, resource_inputs):
        changed = True
    if not changed:
      break
  return resource_inputs
def automatic_control_dependencies(f):
"""Wraps f to automatically insert control dependencies.

View File

@ -29,6 +29,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.compiler.xla import xla
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import config
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@ -204,6 +205,35 @@ def _enclosing_tpu_device_assignment():
return strategy.extended._device_assignment # pylint: disable=protected-access
@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(op, resource_inputs):
  """Swaps TPUReplicatedInput outputs in `resource_inputs` for their inputs.

  Args:
    op: The `Operation` currently being processed by ACD.
    resource_inputs: Set of resource tensors for `op`; mutated in place.

  Returns:
    True if `resource_inputs` was modified, False otherwise.
  """
  # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
  # control deps on the replicated inputs.
  if op.type == "TPUReplicatedInput":
    if not resource_inputs:
      return False
    resource_inputs.clear()
    return True

  # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
  # with the actual replicated inputs. This allows ACD to correctly add control
  # deps when there are multiple calls to `experimental_run_v2` in a
  # `tf.function`.
  replicated_outputs = [
      tensor for tensor in resource_inputs
      if tensor.op.type == "TPUReplicatedInput"
  ]
  replicated_sources = []
  for tensor in replicated_outputs:
    replicated_sources.extend(tensor.op.inputs)

  if not (replicated_sources or replicated_outputs):
    return False

  for tensor in replicated_outputs:
    resource_inputs.discard(tensor)
  resource_inputs.update(replicated_sources)
  return True
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside a TPU computation.

View File

@ -195,6 +195,9 @@ class ObjectIdentitySet(collections_abc.MutableSet):
def update(self, items):
  """Adds every element of `items` to the set, wrapping each one first."""
  wrapped_items = [self._wrap_key(element) for element in items]
  self._storage.update(wrapped_items)
def clear(self):
  """Removes all elements by clearing the underlying storage in place."""
  self._storage.clear()
def intersection(self, items):
  """Intersects the underlying storage with the wrapped form of `items`.

  Returns:
    Whatever `self._storage.intersection` returns for the wrapped keys; the
    result is not re-wrapped here.
  """
  wrapped_items = [self._wrap_key(element) for element in items]
  return self._storage.intersection(wrapped_items)

View File

@ -85,6 +85,21 @@ class ObjectIdentitySetTest(test.TestCase):
self.assertNotIn(b, diff_set)
self.assertNotIn(c, diff_set)
def testDiscard(self):
  """Discarding one member removes it and leaves the other in the set."""
  removed, kept = object(), object()
  identity_set = object_identity.ObjectIdentitySet([removed, kept])
  identity_set.discard(removed)
  self.assertIn(kept, identity_set)
  self.assertNotIn(removed, identity_set)
def testClear(self):
  """Clearing a populated set leaves it with zero elements."""
  first, second = object(), object()
  identity_set = object_identity.ObjectIdentitySet([first, second])
  identity_set.clear()
  self.assertLen(identity_set, 0)
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
  test.main()