Do not move constant ops to forward pass in XLAContext.
Do not forward constant ops in XLA context as pushing to and popping from a stack makes an op non-constant and breaks XLA compilation, which requires certain inputs to be constant for certain ops. PiperOrigin-RevId: 239268614
This commit is contained in:
parent
efb872186d
commit
b5692e1e80
@ -1384,6 +1384,37 @@ class ControlFlowTest(test.TestCase):
|
|||||||
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
|
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
|
||||||
self.assertEqual(1, self.evaluate(r))
|
self.assertEqual(1, self.evaluate(r))
|
||||||
|
|
||||||
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
def testXLAGradInLoop(self):
|
||||||
|
# We have an optimization that moves certain reduction ops, this test makes
|
||||||
|
# sure we don't do that for XLA ops.
|
||||||
|
|
||||||
|
# Use dynamic inputs, which triggers the creation of "BroadcastGradientArgs"
|
||||||
|
# and "Shape" op.
|
||||||
|
input1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
|
||||||
|
input2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
|
||||||
|
def cond(i1, i2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def body(i1, i2):
|
||||||
|
return math_ops.add(i1, i2), math_ops.add(i1, i2)
|
||||||
|
|
||||||
|
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||||
|
xla_context.Enter()
|
||||||
|
|
||||||
|
out1, _ = control_flow_ops.while_loop(
|
||||||
|
cond, body, (input1, input2), maximum_iterations=2)
|
||||||
|
g = gradients_impl.gradients(out1, [input1])
|
||||||
|
|
||||||
|
for op in out1.graph.get_operations():
|
||||||
|
# Test that the "Shape" is directly passed to BroadcastGradientArgs
|
||||||
|
# instead of being pushed to the stack.
|
||||||
|
if op.type == "BroadcastGradientArgs":
|
||||||
|
self.assertEqual(op.inputs[0].op.type, "Shape")
|
||||||
|
self.assertEqual(op.inputs[1].op.type, "Shape")
|
||||||
|
xla_context.Exit()
|
||||||
|
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/115776323 (max_iters)")
|
@test_util.disable_control_flow_v2("b/115776323 (max_iters)")
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
|
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
|
||||||
|
@ -2455,7 +2455,12 @@ class WhileContext(ControlFlowContext):
|
|||||||
# store the tensor after the reduction as opposed to the tensor before
|
# store the tensor after the reduction as opposed to the tensor before
|
||||||
# reduction, and therefore could significantly reduce memory consumption.
|
# reduction, and therefore could significantly reduce memory consumption.
|
||||||
# For now, we do this only for a few ops.
|
# For now, we do this only for a few ops.
|
||||||
if op.type in {"Shape", "Size", "Rank"}:
|
#
|
||||||
|
# If in XLA context, do not move constant ops to forward pass as pushing to
|
||||||
|
# and popping from a stack removes the constant property of an op and breaks
|
||||||
|
# XLA compilation, which requires certain inputs to be constant for certain
|
||||||
|
# ops.
|
||||||
|
if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
|
||||||
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
|
grad_ctxt = ops.get_default_graph()._get_control_flow_context()
|
||||||
if grad_ctxt:
|
if grad_ctxt:
|
||||||
grad_ctxt = grad_ctxt.GetWhileContext()
|
grad_ctxt = grad_ctxt.GetWhileContext()
|
||||||
|
Loading…
Reference in New Issue
Block a user