Do not move constant ops to forward pass in XLAContext.
Do not forward constant ops in XLA context as pushing to and popping from a stack makes an op non-constant and breaks XLA compilation, which requires certain inputs to be constant for certain ops. PiperOrigin-RevId: 239268614
This commit is contained in:
parent
efb872186d
commit
b5692e1e80
@ -1384,6 +1384,37 @@ class ControlFlowTest(test.TestCase):
|
||||
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
|
||||
self.assertEqual(1, self.evaluate(r))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testXLAGradInLoop(self):
|
||||
# We have an optimization that moves certain reduction ops, this test makes
|
||||
# sure we don't do that for XLA ops.
|
||||
|
||||
# Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs"
|
||||
# and "Shape" op.
|
||||
input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
|
||||
input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
|
||||
def cond(i1, i2):
|
||||
return False
|
||||
|
||||
def body(i1, i2):
|
||||
return math_ops.add(i1, i2), math_ops.add(i1, i2)
|
||||
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
|
||||
out1, _ = control_flow_ops.while_loop(
|
||||
cond, body, (input1, input2), maximum_iterations=2)
|
||||
g = gradients_impl.gradients(out1, [input1])
|
||||
|
||||
for op in out1.graph.get_operations():
|
||||
# Test that the "Shape" is directly passed to BroadcastGradientArgs
|
||||
# instead of being pushed to the stack.
|
||||
if op.type == "BroadcastGradientArgs":
|
||||
self.assertEqual(op.inputs[0].op.type, "Shape")
|
||||
self.assertEqual(op.inputs[1].op.type, "Shape")
|
||||
xla_context.Exit()
|
||||
|
||||
|
||||
@test_util.disable_control_flow_v2("b/115776323 (max_iters)")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
|
||||
|
@ -2455,7 +2455,12 @@ class WhileContext(ControlFlowContext):
|
||||
# store the tensor after the reduction as opposed to the tensor before
|
||||
# reduction, and therefore could significantly reduce memory consumption.
|
||||
# For now, we do this only for a few ops.
|
||||
if op.type in {"Shape", "Size", "Rank"}:
|
||||
#
|
||||
# If in XLA context, do not move constant ops to forward pass as pushing to
|
||||
# and popping from a stack removes the constant property of an op and breaks
|
||||
# XLA compilation, which requires certain inputs to be constant for certain
|
||||
# ops.
|
||||
if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
|
||||
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
|
||||
if grad_ctxt:
|
||||
grad_ctxt = grad_ctxt.GetWhileContext()
|
||||
|
Loading…
Reference in New Issue
Block a user