diff --git a/tensorflow/python/tf_program/tests/mlir_gen_test.py b/tensorflow/python/tf_program/tests/mlir_gen_test.py index 5e1ca5b36e0..49737352d73 100644 --- a/tensorflow/python/tf_program/tests/mlir_gen_test.py +++ b/tensorflow/python/tf_program/tests/mlir_gen_test.py @@ -83,7 +83,7 @@ class MLIRGenTest(MLIRGenTestBase): CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1 CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1> CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1> - return %[[r1]] : tensor<*xi1> + CHECK: return %[[r1]] : tensor<*xi1> """ self._check_code(mlir_code, exp_mlir_code) @@ -158,7 +158,7 @@ class MLIRGenTest(MLIRGenTestBase): mlir_code = mlir_gen(test_fn) exp_mlir_code = r""" CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32 - + CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor) -> tensor<*xi1> CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( { CHECK: return %{{[0-9]+}} : tensor @@ -222,7 +222,7 @@ class MLIRGenTest(MLIRGenTestBase): CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor) -> tensor<*xi1> CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor) -> tensor<*xi1> CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - + CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( { CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32> CHECK-NEXT: }, {