diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 2db64262094..a2cc33a8201 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64 // CHECK: return %0, %1 } +// CHECK-LABEL: testCompatibleCastType +func @testCompatibleCastType(%arg0: tensor) -> (tensor<10xf32>, tensor<10xf32>) { + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> + %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> + return %0, %1: tensor<10xf32>, tensor<10xf32> + +// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> +// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> +// CHECK: return %0, %1 +} + // CHECK-LABEL: testSameCastTypeAcrossBasicBlocks func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> { ^bb0(%arg0: tensor<8x16x32x64xf32>): diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index beb7583fc57..7c38b78f239 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def SingleResultAndOperandHaveSameElementType : Constraint< CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; +def SingleResultAndOperandHaveSameType : Constraint< + CPred<"$0->getType() == $1->getType()">>; + def IsRank2Tensor : Type, "Rank 2 tensor">; //===----------------------------------------------------------------------===// @@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)), def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate), (replaceWithValue $arg), - [(SingleResultAndOperandHaveSameElementType $res, - $arg)]>; + [(SingleResultAndOperandHaveSameType $res, $arg)]>; //===----------------------------------------------------------------------===// // Conj op patterns.