diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 6cdd88630f4..265cc08cfdb 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1384,6 +1384,37 @@ class ControlFlowTest(test.TestCase):
         lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
     self.assertEqual(1, self.evaluate(r))
 
+  @test_util.run_v1_only("b/120545219")
+  def testXLAGradInLoop(self):
+    # We have an optimization that moves certain reduction ops; this test
+    # makes sure we don't do that for XLA ops.
+
+    # Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs"
+    # and "Shape" ops.
+    input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+    input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+
+    def cond(i1, i2):
+      return False
+
+    def body(i1, i2):
+      return math_ops.add(i1, i2), math_ops.add(i1, i2)
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+
+    out1, _ = control_flow_ops.while_loop(
+        cond, body, (input1, input2), maximum_iterations=2)
+    g = gradients_impl.gradients(out1, [input1])
+
+    for op in out1.graph.get_operations():
+      # Test that "Shape" is passed directly to BroadcastGradientArgs
+      # instead of being pushed onto the stack.
+      if op.type == "BroadcastGradientArgs":
+        self.assertEqual(op.inputs[0].op.type, "Shape")
+        self.assertEqual(op.inputs[1].op.type, "Shape")
+    xla_context.Exit()
+
   @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
   @test_util.run_v1_only("b/120545219")
   def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 32a5db2c1ae..fbb41fc5843 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2455,7 +2455,12 @@ class WhileContext(ControlFlowContext):
     # store the tensor after the reduction as opposed to the tensor before
     # reduction, and therefore could significantly reduce memory consumption.
     # For now, we do this only for a few ops.
-    if op.type in {"Shape", "Size", "Rank"}:
+    #
+    # In an XLA context, do not move constant ops to the forward pass: pushing
+    # to and popping from a stack removes the constant property of an op and
+    # breaks XLA compilation, which requires certain inputs of certain ops to
+    # be compile-time constants.
+    if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
       grad_ctxt = ops.get_default_graph()._get_control_flow_context()
       if grad_ctxt:
         grad_ctxt = grad_ctxt.GetWhileContext()
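
The heart of the change is the new guard on the reduction-op rewrite. Below is a minimal, self-contained sketch of that predicate, not TensorFlow source: FakeOp and may_move_to_forward are hypothetical names for illustration, and in_xla_context stands in for util.IsInXLAContext(op).

from collections import namedtuple

# Hypothetical stand-in for a graph operation; only `type` matters here.
FakeOp = namedtuple("FakeOp", ["type"])


def may_move_to_forward(op, in_xla_context):
  """Mirrors the guard added in WhileContext above.

  Shape/Size/Rank outputs may be moved to the forward pass only outside an
  XLA context: pushing a value onto a stack and popping it in the gradient
  pass would strip the constant property that XLA compilation relies on.
  """
  return (not in_xla_context) and op.type in {"Shape", "Size", "Rank"}


# Outside XLA, the rewrite still fires for the shape-like ops.
assert may_move_to_forward(FakeOp("Shape"), in_xla_context=False)
# Inside XLA (and for all other op types) the tensors stay in place, so
# BroadcastGradientArgs receives the Shape outputs directly, which is
# exactly what the new testXLAGradInLoop asserts.
assert not may_move_to_forward(FakeOp("Shape"), in_xla_context=True)
assert not may_move_to_forward(FakeOp("Add"), in_xla_context=False)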