[TF2XLA] Fix a copy'n paste bug in constant propagation in IF kernel
PiperOrigin-RevId: 355491661 Change-Id: If60b3a4713c11c825d2196f6afccc3be53bfe9b3
This commit is contained in:
parent
952fe2581d
commit
2ddb6e97f6
@ -216,7 +216,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_,
|
||||
&then_branch_must_be_const_nodes,
|
||||
&then_body));
|
||||
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_,
|
||||
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, else_branch_,
|
||||
&else_branch_must_be_const_nodes,
|
||||
&else_body));
|
||||
|
||||
|
@ -977,6 +977,19 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
'def_function_xla_jit_test.py'):
|
||||
update_var(arg)
|
||||
|
||||
def testMustBeConstantInsideCondition(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(jit_compile=True)
|
||||
def f(x, d):
|
||||
if math_ops.reduce_all(
|
||||
math_ops.greater(x, random_ops.random_normal([10, 10]))):
|
||||
return array_ops.reshape(x * 2, constant_op.constant([100]))
|
||||
else:
|
||||
return array_ops.reshape(x * 3, d)
|
||||
|
||||
f(random_ops.random_normal([10, 10]), constant_op.constant([100]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
x
Reference in New Issue
Block a user