# Patch summary -- descriptive text only; everything from the first
# "diff --git" line onward is the patch itself.
#
# tensorflow/python/framework/ops.py
#   * Graph._colocate_with_for_gradient: instead of reading only
#     self._control_flow_context, it now asks the new module-level helper
#     _get_enclosing_context(self) for a context. When one is returned, the
#     `yield` is bracketed by ctx.EnterGradientColocation(...) and
#     ctx.ExitGradientColocation(...), the latter in a `finally`. When none is
#     returned (or gradient_uid is None) it just yields.
#   * _get_enclosing_context(graph): returns None for a None graph; returns
#     graph._control_flow_context when that is set; otherwise, if
#     graph.building_function is truthy and the graph has an `outer_graph`
#     attribute, recurses into graph.outer_graph. Any other case falls off the
#     end of the function, i.e. returns None implicitly.
#
# tensorflow/python/tpu/tpu.py (class TPUReplicateContext)
#   * Adds the field _outside_compilation_v2_context, initialised to None.
#   * EnterGradientColocation: when the default graph's _control_flow_context
#     is None, reads `op`'s _OUTSIDE_COMPILATION_ATTR. A ValueError from
#     get_attr is treated as "attribute absent" and the method returns.
#     Otherwise it builds the cluster name "<text before first '.'>." +
#     gradient_uid, stores an OutsideCompilationV2Context(cluster) in the new
#     field, calls Enter() on it and returns before the pre-existing
#     stack-based logic.
#   * ExitGradientColocation: when the default graph's _control_flow_context
#     is None, asserts that no V2 context is stored and returns. Otherwise, if
#     a V2 context is stored, calls Exit() on it, resets the field to None and
#     returns before the pre-existing logic.
#     NOTE(review): this pairing presumably relies on
#     OutsideCompilationV2Context.Enter() making the default graph's
#     _control_flow_context non-None -- that class is not shown here; confirm.
#
# tensorflow/python/tpu/tpu_outside_compilation_test.py
#   * Imports tensorflow.python.framework.ops.
#   * Adds testColocateGradientWithOutsideCompiledOp: takes the gradient of an
#     outside-compiled math_ops.sqrt with colocate_gradients_with_ops=True,
#     asserts the Sqrt op carries attr value b"0" and the SqrtGrad op carries
#     b"0.gradients/uid", and expects .1 per replica for the input 25.0.
#
# NOTE(review): the copy of this patch under review had its newlines and
# indentation collapsed. Line breaks, indentation and blank context lines
# below were reconstructed from the hunk headers' line counts and Python
# syntax -- confirm against the upstream commit before applying.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ccc1daf721c..47561b2c115 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4332,12 +4332,16 @@ class Graph(object):
   def _colocate_with_for_gradient(self, op, gradient_uid,
                                   ignore_existing=False):
     with self.colocate_with(op, ignore_existing):
-      if gradient_uid is not None and self._control_flow_context is not None:
-        self._control_flow_context.EnterGradientColocation(op, gradient_uid)
-        try:
+      if gradient_uid is not None:
+        ctx = _get_enclosing_context(self)
+        if ctx is not None:
+          ctx.EnterGradientColocation(op, gradient_uid)
+          try:
+            yield
+          finally:
+            ctx.ExitGradientColocation(op, gradient_uid)
+        else:
           yield
-        finally:
-          self._control_flow_context.ExitGradientColocation(op, gradient_uid)
       else:
         yield
 
@@ -6955,3 +6959,15 @@ def set_int_list_attr(op, attr_name, ints):
   """TF internal method used to set a list(int) attribute in the node_def."""
   ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
   op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list))  # pylint:disable=protected-access
+
+
+def _get_enclosing_context(graph):
+  # pylint: disable=protected-access
+  if graph is None:
+    return None
+
+  if graph._control_flow_context is not None:
+    return graph._control_flow_context
+
+  if graph.building_function and hasattr(graph, "outer_graph"):
+    return _get_enclosing_context(graph.outer_graph)
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index c6b5a256b42..084ec1f3dba 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -288,6 +288,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._outer_device_function_stack = None
     self._oc_dev_fn_stack = None
     self._outside_compilation_cluster = None
+    self._outside_compilation_v2_context = None
     self._outside_compilation_counter = 0
     self._in_gradient_colocation = None
     self._gradient_colocation_stack = []
@@ -379,6 +380,21 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def EnterGradientColocation(self, op, gradient_uid):
     if op is not None:
+      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
+        # If we are in TF 2 functions (control flow V2 functions, or
+        # tf.function()), we need to attach _xla_outside_compilation attribute
+        # directly because we are not in TPUReplicateContext.
+        try:
+          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
+        except ValueError:
+          # The attr was not present: do nothing.
+          return
+        parts = outside_attr.split(".")
+        cluster = parts[0] + "." + gradient_uid
+        self._outside_compilation_v2_context = OutsideCompilationV2Context(
+            cluster)
+        self._outside_compilation_v2_context.Enter()
+        return
       self._gradient_colocation_stack.append(op)
       if not self._outside_compilation_cluster:
         try:
@@ -418,6 +434,17 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def ExitGradientColocation(self, op, gradient_uid):
     if op is not None:
+      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
+        # Inside a TF2 tf.function or control flow graph and `op` was not
+        # marked to be outside compiled.
+        assert self._outside_compilation_v2_context is None
+        return
+      if self._outside_compilation_v2_context is not None:
+        # Inside a TF2 tf.function or control flow graph and `op` was
+        # marked to be outside compiled.
+        self._outside_compilation_v2_context.Exit()
+        self._outside_compilation_v2_context = None
+        return
       if not self._gradient_colocation_stack:
         raise errors.InternalError(
             op.node_def, op,
diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py
index 4eb6429f3c8..30bfabdff7c 100644
--- a/tensorflow/python/tpu/tpu_outside_compilation_test.py
+++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import tf_record
 from tensorflow.python.ops import array_ops
@@ -450,6 +451,36 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
         strategy.experimental_local_results(train_step()),
         constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
 
+  def testColocateGradientWithOutsideCompiledOp(self):
+    strategy = get_tpu_strategy()
+
+    @def_function.function
+    def train_step():
+
+      @def_function.function
+      def tpu_fn(x):
+        x1 = tpu.outside_compilation(math_ops.sqrt, x)
+        grad = gradients_impl.gradients([x1], [x],
+                                        colocate_gradients_with_ops=True)[0]
+        sqrt = [
+            op for op in ops.get_default_graph().get_operations()
+            if op.type == "Sqrt"
+        ][0]
+        sqrt_grad = [
+            op for op in ops.get_default_graph().get_operations()
+            if op.type == "SqrtGrad"
+        ][0]
+        assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0"
+        assert (sqrt_grad.get_attr(
+            tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid")
+        return grad
+
+      return strategy.run(tpu_fn, args=(25.0,))
+
+    self.assertAllEqual(
+        strategy.experimental_local_results(train_step()),
+        constant_op.constant(.1, shape=(strategy.num_replicas_in_sync)))
+
 
 class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
                                             parameterized.TestCase):