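Don't automatically add control dependencies to collective ops.
Collective ops (CollectiveReduce, CollectiveBcastSend, CollectiveBcastRecv) need to run asynchronously to avoid deadlock, so they are exempt from the program-order control dependencies that are otherwise added automatically to stateful ops.
PiperOrigin-RevId: 226397820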
This commit is contained in:
parent
d5216948d1
commit
5ff27167b2
@ -29,13 +29,22 @@ from tensorflow.python.ops import tensor_array_ops
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
|
# Op types that should not run in program order, e.g. because they need to run
|
||||||
|
# asynchronously to avoid deadlock.
|
||||||
|
ASYNC_STATEFUL_OPS = [
|
||||||
|
"CollectiveReduce",
|
||||||
|
"CollectiveBcastSend",
|
||||||
|
"CollectiveBcastRecv",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AutomaticControlDependencies(object):
|
class AutomaticControlDependencies(object):
|
||||||
"""Context manager to automatically add control dependencies.
|
"""Context manager to automatically add control dependencies.
|
||||||
|
|
||||||
Code under this context manager will act as if a sensible set of control
|
Code under this context manager will act as if a sensible set of control
|
||||||
dependencies were present. More specifically:
|
dependencies were present. More specifically:
|
||||||
1. All stateful ops in the scope will execute
|
1. All stateful ops in the scope will execute (with the exception of ops in
|
||||||
|
ASYNC_STATEFUL_OPS)
|
||||||
2. Stateful ops which modify the same resource will execute in program order
|
2. Stateful ops which modify the same resource will execute in program order
|
||||||
|
|
||||||
Note: creating variables in an automatic control dependencies context is not
|
Note: creating variables in an automatic control dependencies context is not
|
||||||
@ -223,7 +232,8 @@ class AutomaticControlDependencies(object):
|
|||||||
control_inputs = set()
|
control_inputs = set()
|
||||||
# Ensure stateful ops run
|
# Ensure stateful ops run
|
||||||
if (op.type not in self._graph._registered_ops # pylint: disable=protected-access
|
if (op.type not in self._graph._registered_ops # pylint: disable=protected-access
|
||||||
or self._graph._registered_ops[op.type].is_stateful): # pylint: disable=protected-access
|
or (self._graph._registered_ops[op.type].is_stateful # pylint: disable=protected-access
|
||||||
|
and op.type not in ASYNC_STATEFUL_OPS)):
|
||||||
ops_which_must_run.add(op)
|
ops_which_must_run.add(op)
|
||||||
# Ignore switches (they're handled separately)
|
# Ignore switches (they're handled separately)
|
||||||
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
|
if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
|
||||||
@ -255,8 +265,8 @@ class AutomaticControlDependencies(object):
|
|||||||
if inp in merge_for_resource:
|
if inp in merge_for_resource:
|
||||||
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
|
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
|
||||||
last_op_using_resource_tensor[inp] = op
|
last_op_using_resource_tensor[inp] = op
|
||||||
if (op.op_def.is_stateful and not found_resource
|
if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
|
||||||
and op._control_flow_context is None): # pylint: disable=protected-access
|
and not found_resource and op._control_flow_context is None): # pylint: disable=protected-access
|
||||||
if None in last_op_using_resource_tensor:
|
if None in last_op_using_resource_tensor:
|
||||||
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
|
op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
|
||||||
last_op_using_resource_tensor[None] = op
|
last_op_using_resource_tensor[None] = op
|
||||||
|
Loading…
Reference in New Issue
Block a user