Do not move constant ops to forward pass in XLAContext.

Do not move constant ops to the forward pass in an XLA context: pushing to and
popping from a stack makes an op non-constant and breaks XLA compilation,
which requires certain inputs to be compile-time constants for certain ops.

PiperOrigin-RevId: 239268614
Yunxing Dai 2019-03-19 14:22:39 -07:00 committed by TensorFlower Gardener
parent efb872186d
commit b5692e1e80
2 changed files with 37 additions and 1 deletion
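
For context, a minimal sketch, not part of this commit, of the predicate the fix relies on: control_flow_util.IsInXLAContext reports whether an op was built inside an XLAControlFlowContext, and the new guard below calls it through the "util" alias used in TensorFlow's control flow code. The setup assumes TF 1.x graph mode and TF-internal module paths.

# Sketch only (assumed setup, TF 1.x graph mode; these are TF-internal modules).
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util

xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
inside = constant_op.constant(1.0)   # op created inside the XLA context
xla_context.Exit()
outside = constant_op.constant(2.0)  # op created outside any XLA context

# IsInXLAContext walks the op's control-flow-context chain and returns True
# when an enclosing XLAControlFlowContext is found.
print(control_flow_util.IsInXLAContext(inside.op))   # True
print(control_flow_util.IsInXLAContext(outside.op))  # False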


@@ -1384,6 +1384,37 @@ class ControlFlowTest(test.TestCase):
        lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
    self.assertEqual(1, self.evaluate(r))
+  @test_util.run_v1_only("b/120545219")
+  def testXLAGradInLoop(self):
+    # We have an optimization that moves certain reduction ops to the forward
+    # pass; this test makes sure we don't do that for XLA ops.
+    # Use dynamic inputs, which trigger the creation of "BroadcastGradientArgs"
+    # and "Shape" ops.
+    input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+    input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+
+    def cond(i1, i2):
+      return False
+
+    def body(i1, i2):
+      return math_ops.add(i1, i2), math_ops.add(i1, i2)
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    out1, _ = control_flow_ops.while_loop(
+        cond, body, (input1, input2), maximum_iterations=2)
+    g = gradients_impl.gradients(out1, [input1])
+    for op in out1.graph.get_operations():
+      # Check that "Shape" is passed directly to BroadcastGradientArgs
+      # instead of being pushed to and popped from a stack.
+      if op.type == "BroadcastGradientArgs":
+        self.assertEqual(op.inputs[0].op.type, "Shape")
+        self.assertEqual(op.inputs[1].op.type, "Shape")
+    xla_context.Exit()
  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
  @test_util.run_v1_only("b/120545219")
  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):


@@ -2455,7 +2455,12 @@ class WhileContext(ControlFlowContext):
    # store the tensor after the reduction as opposed to the tensor before
    # reduction, and therefore could significantly reduce memory consumption.
    # For now, we do this only for a few ops.
-   if op.type in {"Shape", "Size", "Rank"}:
+   #
+   # If in XLA context, do not move constant ops to forward pass as pushing to
+   # and popping from a stack removes the constant property of an op and breaks
+   # XLA compilation, which requires certain inputs to be constant for certain
+   # ops.
+   if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
      if grad_ctxt:
        grad_ctxt = grad_ctxt.GetWhileContext()
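
For contrast, a minimal sketch, not part of this commit, of the non-XLA path: outside an XLAControlFlowContext the optimization is still allowed to move the "Shape" ops into the forward pass, so the shape inputs of "BroadcastGradientArgs" may arrive through stack pop ops rather than directly from "Shape". It assumes TF 1.x v1 graph-mode APIs.

# Sketch only (assumed setup): same loop as the new test, built without an
# XLA context. Requires TensorFlow 1.x (placeholders, v1 control flow).
import tensorflow as tf

input1 = tf.placeholder(tf.float32, shape=[None, None])
input2 = tf.placeholder(tf.float32, shape=[None, None])

out1, _ = tf.while_loop(lambda i1, i2: tf.constant(False),
                        lambda i1, i2: (i1 + i2, i1 + i2),
                        (input1, input2), maximum_iterations=2)
tf.gradients(out1, [input1])

# Without an enclosing XLA context the forward-move optimization applies, so
# the producer op types printed here may be stack pops instead of "Shape".
for op in tf.get_default_graph().get_operations():
  if op.type == "BroadcastGradientArgs":
    print(op.name, [inp.op.type for inp in op.inputs])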