diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index b281df37d1f..fb9f1c53abc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -2122,7 +2122,7 @@ func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK-LABEL: func @cast_c2f func @cast_c2f(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { - //CHECK: "mhlo.convert"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: tf.Cast %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> return %0 : tensor<2xf32> } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 6ae8941e9b9..0ab98208c21 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -618,7 +618,7 @@ foreach Mapping = [ // TODO(bixia): Lower Cast with a Complex type source operand or with // Truncate=True for floating point value conversions. -def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), +def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse), (HLO_ConvertOp $arg)>; def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), @@ -637,8 +637,8 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>; def BothElementTypesSameWidthIntOrFloat : Constraint, "element types must be integers or floats of same width">; diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 3dd50f4cc75..fae750a09f1 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -887,8 +887,6 @@ class UnaryOpsTest(xla_test.XLATestCase): [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), expected=np.array([14., 22.], dtype=np.float32)) - @test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation" - ) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] types = { @@ -936,8 +934,6 @@ class UnaryOpsTest(xla_test.XLATestCase): src, expected=dst) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.Bitcast compilation") def testBitcast(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32),