Don't automatically add control deps to collective ops.
These ops need to run asynchronously to avoid deadlock.

PiperOrigin-RevId: 226397820
This commit is contained in:
parent
d5216948d1
commit
5ff27167b2
@ -29,13 +29,22 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
# Op types that should not run in program order, e.g. because they need to run
|
||||
# asynchronously to avoid deadlock.
|
||||
ASYNC_STATEFUL_OPS = [
|
||||
"CollectiveReduce",
|
||||
"CollectiveBcastSend",
|
||||
"CollectiveBcastRecv",
|
||||
]
|
||||
|
||||
|
||||
class AutomaticControlDependencies(object):
|
||||
"""Context manager to automatically add control dependencies.
|
||||
|
||||
Code under this context manager will act as if a sensible set of control
|
||||
dependencies were present. More specifically:
|
||||
1. All stateful ops in the scope will execute
|
||||
1. All stateful ops in the scope will execute (with the exception of ops in
|
||||
ASYNC_STATEFUL_OPS)
|
||||
2. Stateful ops which modify the same resource will execute in program order
|
||||
|
||||
Note: creating variables in an automatic control dependencies context is not
|
||||
@ -223,7 +232,8 @@ class AutomaticControlDependencies(object):
|
||||
control_inputs = set()
|
||||
# Ensure stateful ops run
|
||||
if (op.type not in self._graph._registered_ops # pylint: disable=protected-access
|
||||
or self._graph._registered_ops[op.type].is_stateful): # pylint: disable=protected-access
|
||||
or (self._graph._registered_ops[op.type].is_stateful # pylint: disable=protected-access
|
||||
and op.type not in ASYNC_STATEFUL_OPS)):
|
||||
ops_which_must_run.add(op)
|
||||
# Ignore switches (they're handled separately)
|
||||
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
|
||||
@ -255,8 +265,8 @@ class AutomaticControlDependencies(object):
|
||||
if inp in merge_for_resource:
|
||||
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
|
||||
last_op_using_resource_tensor[inp] = op
|
||||
if (op.op_def.is_stateful and not found_resource
|
||||
and op._control_flow_context is None): # pylint: disable=protected-access
|
||||
if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
|
||||
and not found_resource and op._control_flow_context is None): # pylint: disable=protected-access
|
||||
if None in last_op_using_resource_tensor:
|
||||
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
|
||||
last_op_using_resource_tensor[None] = op
|
||||
|
Loading…
Reference in New Issue
Block a user