Provide mechanism for registering custom resource tensor resolvers for ACD.
ACD only looks at the direct resource inputs of stateful ops. This doesn't work for cases where ops access resources indirectly e.g. consumers of TPUReplicatedInput and in tf.data where the MapDatasetOp may be touching a resource but we need to add control dep from the ReduceDatasetOp. This mechanism will provide a way to notify ACD of the indirect resource accesses of an op. PiperOrigin-RevId: 290063112 Change-Id: I329007eb99fce2dee9dda03593651992086d0b18
This commit is contained in:
parent
19c96a602f
commit
f72e3a7ce8
@ -22,6 +22,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes as dtypes_module
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import registry
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -321,10 +322,7 @@ class AutomaticControlDependencies(object):
|
||||
resource_inputs = set()
|
||||
# Check for any resource inputs. If we find any, we update control_inputs
|
||||
# and last_op_using_resource_tensor.
|
||||
for inp in op.inputs:
|
||||
if inp.dtype != dtypes_module.resource:
|
||||
continue
|
||||
|
||||
for inp in _get_resource_inputs(op):
|
||||
input_id = ops.tensor_id(inp)
|
||||
|
||||
# If the op receives the same resource tensor twice as an input, we skip
|
||||
@ -338,9 +336,11 @@ class AutomaticControlDependencies(object):
|
||||
self._process_switch(inp.op, ops_which_must_run,
|
||||
last_op_using_resource_tensor,
|
||||
merge_for_resource)
|
||||
is_building_function = op.graph.building_function
|
||||
# Ensure uses of resources are serialized
|
||||
if input_id in last_op_using_resource_tensor:
|
||||
if (last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access
|
||||
if is_building_function or (
|
||||
last_op_using_resource_tensor[input_id]._control_flow_context # pylint: disable=protected-access
|
||||
is op._control_flow_context): # pylint: disable=protected-access
|
||||
control_inputs.add(last_op_using_resource_tensor[input_id])
|
||||
# Ensure merges happen after the closing of a cond block
|
||||
@ -353,8 +353,9 @@ class AutomaticControlDependencies(object):
|
||||
if None in last_op_using_resource_tensor:
|
||||
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
|
||||
last_op_using_resource_tensor[None] = op
|
||||
control_inputs = [c for c in control_inputs
|
||||
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
|
||||
control_inputs = [
|
||||
c for c in control_inputs if is_building_function or
|
||||
(c._control_flow_context is op._control_flow_context)] # pylint: disable=protected-access
|
||||
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
|
||||
|
||||
# Ensure all ops which must run do run
|
||||
@ -369,6 +370,60 @@ class AutomaticControlDependencies(object):
|
||||
])
|
||||
|
||||
|
||||
_acd_resource_resolvers_registry = registry.Registry("acd_resouce_resolvers")
|
||||
|
||||
|
||||
def register_acd_resource_resolver(f):
|
||||
"""Register a function for resolving resources touched by an op.
|
||||
|
||||
Example:
|
||||
@register_acd_resource_resolver
|
||||
def ResolveIdentity(op, resource_inputs):
|
||||
# op: The `Operation` being processed by ACD currently.
|
||||
# resource_inputs: An `ObjectIdentitySet` that can be updated in-place.
|
||||
if not resource_inputs:
|
||||
return False
|
||||
to_add = []
|
||||
to_remove = []
|
||||
for t in resource_inputs:
|
||||
if t.op.type == "Identity":
|
||||
to_remove.append(t)
|
||||
to_add.append(t.op.inputs[0])
|
||||
if not to_add and not to_remove:
|
||||
return False
|
||||
for t in to_remove:
|
||||
resource_inputs.discard(t)
|
||||
resource_inputs.update(to_add)
|
||||
return True # `resource_inputs` was updated.
|
||||
|
||||
Args:
|
||||
f: Python function
|
||||
|
||||
Returns:
|
||||
The function `f` after adding it to the registry.
|
||||
"""
|
||||
_acd_resource_resolvers_registry.register(f)
|
||||
return f
|
||||
|
||||
|
||||
def _get_resource_inputs(op):
|
||||
"""Returns an iterable of resources touched by this `op`."""
|
||||
resource_inputs = object_identity.ObjectIdentitySet(
|
||||
t for t in op.inputs if t.dtype == dtypes_module.resource)
|
||||
saturated = False
|
||||
while not saturated:
|
||||
saturated = True
|
||||
for key in _acd_resource_resolvers_registry.list():
|
||||
# Resolvers should return true if they are updating the list of
|
||||
# resource_inputs.
|
||||
# TODO(srbs): An alternate would be to just compare the old and new set
|
||||
# but that may not be as fast.
|
||||
updated = _acd_resource_resolvers_registry.lookup(key)(op,
|
||||
resource_inputs)
|
||||
saturated = saturated and not updated
|
||||
return resource_inputs
|
||||
|
||||
|
||||
def automatic_control_dependencies(f):
|
||||
"""Wraps f to automatically insert control dependencies.
|
||||
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.framework import auto_control_deps
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -204,6 +205,35 @@ def _enclosing_tpu_device_assignment():
|
||||
return strategy.extended._device_assignment # pylint: disable=protected-access
|
||||
|
||||
|
||||
@auto_control_deps.register_acd_resource_resolver
|
||||
def tpu_replicated_input_resolver(op, resource_inputs):
|
||||
"""Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
|
||||
# Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
|
||||
# control deps on the replicated inputs.
|
||||
if op.type == "TPUReplicatedInput":
|
||||
if resource_inputs:
|
||||
resource_inputs.clear()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
# Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
|
||||
# with the actual replicated inputs. This allows ACD to correct add control
|
||||
# deps when there are multiple calls to `experimental_run_v2` in a
|
||||
# `tf.function`.
|
||||
to_remove = []
|
||||
to_add = []
|
||||
for resource in resource_inputs:
|
||||
if resource.op.type == "TPUReplicatedInput":
|
||||
to_remove.append(resource)
|
||||
to_add.extend(resource.op.inputs)
|
||||
if not to_add and not to_remove:
|
||||
return False
|
||||
for t in to_remove:
|
||||
resource_inputs.discard(t)
|
||||
resource_inputs.update(to_add)
|
||||
return True
|
||||
|
||||
|
||||
class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
|
||||
"""A `ControlFlowContext` for nodes inside a TPU computation.
|
||||
|
||||
|
@ -195,6 +195,9 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
||||
def update(self, items):
|
||||
self._storage.update([self._wrap_key(item) for item in items])
|
||||
|
||||
def clear(self):
|
||||
self._storage.clear()
|
||||
|
||||
def intersection(self, items):
|
||||
return self._storage.intersection([self._wrap_key(item) for item in items])
|
||||
|
||||
|
@ -85,6 +85,21 @@ class ObjectIdentitySetTest(test.TestCase):
|
||||
self.assertNotIn(b, diff_set)
|
||||
self.assertNotIn(c, diff_set)
|
||||
|
||||
def testDiscard(self):
|
||||
a = object()
|
||||
b = object()
|
||||
set1 = object_identity.ObjectIdentitySet([a, b])
|
||||
set1.discard(a)
|
||||
self.assertIn(b, set1)
|
||||
self.assertNotIn(a, set1)
|
||||
|
||||
def testClear(self):
|
||||
a = object()
|
||||
b = object()
|
||||
set1 = object_identity.ObjectIdentitySet([a, b])
|
||||
set1.clear()
|
||||
self.assertLen(set1, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user