[TF2XLA] Fix a copy'n paste bug in constant propagation in IF kernel

PiperOrigin-RevId: 355491661
Change-Id: If60b3a4713c11c825d2196f6afccc3be53bfe9b3
This commit is contained in:
George Karpenkov 2021-02-03 15:08:19 -08:00 committed by TensorFlower Gardener
parent 952fe2581d
commit 2ddb6e97f6
2 changed files with 14 additions and 1 deletions

View File

@ -216,7 +216,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_,
&then_branch_must_be_const_nodes,
&then_body));
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_,
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, else_branch_,
&else_branch_must_be_const_nodes,
&else_body));

View File

@ -977,6 +977,19 @@ class DefFunctionTest(xla_test.XLATestCase):
'def_function_xla_jit_test.py'):
update_var(arg)
def testMustBeConstantInsideCondition(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(jit_compile=True)
def f(x, d):
if math_ops.reduce_all(
math_ops.greater(x, random_ops.random_normal([10, 10]))):
return array_ops.reshape(x * 2, constant_op.constant([100]))
else:
return array_ops.reshape(x * 3, d)
f(random_ops.random_normal([10, 10]), constant_op.constant([100]))
if __name__ == '__main__':
ops.enable_eager_execution()