From 2d2b11251752900f70b5a8f79b0a2501b25f3a67 Mon Sep 17 00:00:00 2001
From: Saurabh Saxena
Date: Thu, 15 Oct 2020 15:39:04 -0700
Subject: [PATCH] Ensure that if an op is outside compiled, its gradient gets
 outside compiled as well in TF2 (similar to TF1) when
 colocate_gradients_with_ops=True.

A couple of changes were needed for this:
1. EnterGradientColocation needs to be called on the enclosing XlaContext in
   the parent graph hierarchy, since FuncGraphs do not inherit control flow
   contexts.
2. Make the behavior of the gradients codepath in outside compilation
   consistent with tf.tpu.outside_compilation by directly executing the
   gradient ops under an OutsideCompilationV2Context.

PiperOrigin-RevId: 337397123
Change-Id: Ia47298da90d7094bee920398e6a0b4342469c9d8
---
 tensorflow/python/framework/ops.py        | 26 +++++++++++++---
 tensorflow/python/tpu/tpu.py              | 27 ++++++++++++++++
 .../tpu/tpu_outside_compilation_test.py   | 31 +++++++++++++++++++
 3 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ccc1daf721c..47561b2c115 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4332,12 +4332,16 @@ class Graph(object):
   def _colocate_with_for_gradient(self, op, gradient_uid,
                                   ignore_existing=False):
     with self.colocate_with(op, ignore_existing):
-      if gradient_uid is not None and self._control_flow_context is not None:
-        self._control_flow_context.EnterGradientColocation(op, gradient_uid)
-        try:
+      if gradient_uid is not None:
+        ctx = _get_enclosing_context(self)
+        if ctx is not None:
+          ctx.EnterGradientColocation(op, gradient_uid)
+          try:
+            yield
+          finally:
+            ctx.ExitGradientColocation(op, gradient_uid)
+        else:
           yield
-        finally:
-          self._control_flow_context.ExitGradientColocation(op, gradient_uid)
       else:
         yield
 
@@ -6955,3 +6959,15 @@ def set_int_list_attr(op, attr_name, ints):
   """TF internal method used to set a list(int) attribute in the node_def."""
   ints_list = attr_value_pb2.AttrValue.ListValue(i=ints)
   op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list))  # pylint:disable=protected-access
+
+
+def _get_enclosing_context(graph):
+  # pylint: disable=protected-access
+  if graph is None:
+    return None
+
+  if graph._control_flow_context is not None:
+    return graph._control_flow_context
+
+  if graph.building_function and hasattr(graph, "outer_graph"):
+    return _get_enclosing_context(graph.outer_graph)
diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py
index c6b5a256b42..084ec1f3dba 100644
--- a/tensorflow/python/tpu/tpu.py
+++ b/tensorflow/python/tpu/tpu.py
@@ -288,6 +288,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     self._outer_device_function_stack = None
     self._oc_dev_fn_stack = None
     self._outside_compilation_cluster = None
+    self._outside_compilation_v2_context = None
     self._outside_compilation_counter = 0
     self._in_gradient_colocation = None
     self._gradient_colocation_stack = []
@@ -379,6 +380,21 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def EnterGradientColocation(self, op, gradient_uid):
     if op is not None:
+      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
+        # If we are in TF 2 functions (control flow V2 functions, or
+        # tf.function()), we need to attach _xla_outside_compilation attribute
+        # directly because we are not in TPUReplicateContext.
+        try:
+          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR).decode("ascii")
+        except ValueError:
+          # The attr was not present: do nothing.
+          return
+        parts = outside_attr.split(".")
+        cluster = parts[0] + "." + gradient_uid
+        self._outside_compilation_v2_context = OutsideCompilationV2Context(
+            cluster)
+        self._outside_compilation_v2_context.Enter()
+        return
       self._gradient_colocation_stack.append(op)
       if not self._outside_compilation_cluster:
         try:
@@ -418,6 +434,17 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
 
   def ExitGradientColocation(self, op, gradient_uid):
     if op is not None:
+      if ops.get_default_graph()._control_flow_context is None:  # pylint: disable=protected-access
+        # Inside a TF2 tf.function or control flow graph and `op` was not
+        # marked to be outside compiled.
+        assert self._outside_compilation_v2_context is None
+        return
+      if self._outside_compilation_v2_context is not None:
+        # Inside a TF2 tf.function or control flow graph and `op` was
+        # marked to be outside compiled.
+        self._outside_compilation_v2_context.Exit()
+        self._outside_compilation_v2_context = None
+        return
       if not self._gradient_colocation_stack:
         raise errors.InternalError(
             op.node_def, op,
diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py
index 4eb6429f3c8..30bfabdff7c 100644
--- a/tensorflow/python/tpu/tpu_outside_compilation_test.py
+++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.eager import remote
 from tensorflow.python.eager import test
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import tf_record
 from tensorflow.python.ops import array_ops
@@ -450,6 +451,36 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
         strategy.experimental_local_results(train_step()),
         constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
 
+  def testColocateGradientWithOutsideCompiledOp(self):
+    strategy = get_tpu_strategy()
+
+    @def_function.function
+    def train_step():
+
+      @def_function.function
+      def tpu_fn(x):
+        x1 = tpu.outside_compilation(math_ops.sqrt, x)
+        grad = gradients_impl.gradients([x1], [x],
+                                        colocate_gradients_with_ops=True)[0]
+        sqrt = [
+            op for op in ops.get_default_graph().get_operations()
+            if op.type == "Sqrt"
+        ][0]
+        sqrt_grad = [
+            op for op in ops.get_default_graph().get_operations()
+            if op.type == "SqrtGrad"
+        ][0]
+        assert sqrt.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) == b"0"
+        assert (sqrt_grad.get_attr(
+            tpu._OUTSIDE_COMPILATION_ATTR) == b"0.gradients/uid")
+        return grad
+
+      return strategy.run(tpu_fn, args=(25.0,))
+
+    self.assertAllEqual(
+        strategy.experimental_local_results(train_step()),
+        constant_op.constant(.1, shape=(strategy.num_replicas_in_sync)))
+
 
 class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
                                             parameterized.TestCase):
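
For reference, a minimal usage sketch of the behavior this patch enables. It is an illustration under assumptions, not code from the commit: the empty resolver address is a placeholder (valid only on a Cloud TPU VM), and it uses tf.compat.v1.gradients because colocate_gradients_with_ops is not exposed on the TF2 tf.gradients endpoint.

    import tensorflow as tf

    # Placeholder TPU setup; tpu="" assumes a Cloud TPU VM. Substitute a real
    # worker address otherwise.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    @tf.function
    def train_step():

      def tpu_fn(x):
        # Sqrt carries the _xla_outside_compilation attribute and runs on the
        # host instead of being compiled by XLA.
        y = tf.tpu.outside_compilation(tf.math.sqrt, x)
        # With this patch, SqrtGrad inherits the attribute as well (original
        # cluster name plus the gradient UID), matching the TF1 behavior of
        # colocate_gradients_with_ops=True.
        return tf.compat.v1.gradients([y], [x],
                                      colocate_gradients_with_ops=True)[0]

      return strategy.run(tpu_fn, args=(tf.constant(25.0),))

    # d(sqrt(x))/dx at x = 25 is 1 / (2 * 5) = 0.1 on each replica.
    print(strategy.experimental_local_results(train_step()))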