Don't automatically add control deps to collective ops.

These ops need to run asynchronously to avoid deadlock.

PiperOrigin-RevId: 226397820
This commit is contained in:
Skye Wanderman-Milne 2018-12-20 14:53:11 -08:00 committed by TensorFlower Gardener
parent d5216948d1
commit 5ff27167b2

View File

@ -29,13 +29,22 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = [
"CollectiveReduce",
"CollectiveBcastSend",
"CollectiveBcastRecv",
]
class AutomaticControlDependencies(object):
"""Context manager to automatically add control dependencies.
Code under this context manager will act as if a sensible set of control
dependencies were present. More specifically:
1. All stateful ops in the scope will execute
1. All stateful ops in the scope will execute (with the exception of ops in
ASYNC_STATEFUL_OPS)
2. Stateful ops which modify the same resource will execute in program order
Note: creating variables in an automatic control dependencies context is not
@ -223,7 +232,8 @@ class AutomaticControlDependencies(object):
control_inputs = set()
# Ensure stateful ops run
if (op.type not in self._graph._registered_ops # pylint: disable=protected-access
or self._graph._registered_ops[op.type].is_stateful): # pylint: disable=protected-access
or (self._graph._registered_ops[op.type].is_stateful # pylint: disable=protected-access
and op.type not in ASYNC_STATEFUL_OPS)):
ops_which_must_run.add(op)
# Ignore switches (they're handled separately)
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
@ -255,8 +265,8 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
if (op.op_def.is_stateful and not found_resource
and op._control_flow_context is None): # pylint: disable=protected-access
if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
and not found_resource and op._control_flow_context is None): # pylint: disable=protected-access
if None in last_op_using_resource_tensor:
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
last_op_using_resource_tensor[None] = op