diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 3a88fcf4879..3b74ba0fdf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -216,7 +216,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_, &then_branch_must_be_const_nodes, &then_body)); - OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, then_branch_, + OP_REQUIRES_OK(ctx, FindMustBeConstNodes(ctx, else_branch_, &else_branch_must_be_const_nodes, &else_body)); diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index c7f478c4dee..88f0db3755b 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -977,6 +977,19 @@ class DefFunctionTest(xla_test.XLATestCase): 'def_function_xla_jit_test.py'): update_var(arg) + def testMustBeConstantInsideCondition(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(jit_compile=True) + def f(x, d): + if math_ops.reduce_all( + math_ops.greater(x, random_ops.random_normal([10, 10]))): + return array_ops.reshape(x * 2, constant_op.constant([100])) + else: + return array_ops.reshape(x * 3, d) + + f(random_ops.random_normal([10, 10]), constant_op.constant([100])) + if __name__ == '__main__': ops.enable_eager_execution()