From 5ff27167b274a7471b35ba80491004093a3f6133 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 20 Dec 2018 14:53:11 -0800 Subject: [PATCH] Don't automatically add control deps to collective ops. These ops need to run asynchronously to avoid deadlock. PiperOrigin-RevId: 226397820 --- .../python/framework/auto_control_deps.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index a72ded11314..a7d61417bf6 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -29,13 +29,22 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator +# Op types that should not run in program order, e.g. because they need to run +# asynchronously to avoid deadlock. +ASYNC_STATEFUL_OPS = [ + "CollectiveReduce", + "CollectiveBcastSend", + "CollectiveBcastRecv", +] + class AutomaticControlDependencies(object): """Context manager to automatically add control dependencies. Code under this context manager will act as if a sensible set of control dependencies were present. More specifically: - 1. All stateful ops in the scope will execute + 1. All stateful ops in the scope will execute (with the exception of ops in + ASYNC_STATEFUL_OPS) 2. Stateful ops which modify the same resource will execute in program order Note: creating variables in an automatic control dependencies context is not @@ -223,7 +232,8 @@ class AutomaticControlDependencies(object): control_inputs = set() # Ensure stateful ops run if (op.type not in self._graph._registered_ops # pylint: disable=protected-access - or self._graph._registered_ops[op.type].is_stateful): # pylint: disable=protected-access + or (self._graph._registered_ops[op.type].is_stateful # pylint: disable=protected-access + and op.type not in ASYNC_STATEFUL_OPS)): ops_which_must_run.add(op) # Ignore switches (they're handled separately) if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: @@ -255,8 +265,8 @@ class AutomaticControlDependencies(object): if inp in merge_for_resource: merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access last_op_using_resource_tensor[inp] = op - if (op.op_def.is_stateful and not found_resource - and op._control_flow_context is None): # pylint: disable=protected-access + if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS + and not found_resource and op._control_flow_context is None): # pylint: disable=protected-access if None in last_op_using_resource_tensor: op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access last_op_using_resource_tensor[None] = op