Don't canonicalize away casts between different types.

The previous code would incorrectly only check the element type, rather than exact type equality. Failure to do so can trigger many different kinds of verifier errors.

PiperOrigin-RevId: 283609199
Change-Id: I3bbd8b41a6a2c8edd2e9d97b32eda78e546975ac
This commit is contained in:
Sean Silva 2019-12-03 13:23:40 -08:00 committed by TensorFlower Gardener
parent 0c3af9326f
commit 1e7a91e26a
2 changed files with 15 additions and 2 deletions

View File

@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64
// CHECK: return %0, %1
}
// CHECK-LABEL: testCompatibleCastType
func @testCompatibleCastType(%arg0: tensor<?xf32>) -> (tensor<10xf32>, tensor<10xf32>) {
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
%1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
return %0, %1: tensor<10xf32>, tensor<10xf32>
// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<?xf32>) -> tensor<10xf32>
// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor<?xf32>) -> tensor<10xf32>
// CHECK: return %0, %1
}
// CHECK-LABEL: testSameCastTypeAcrossBasicBlocks
func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> {
^bb0(%arg0: tensor<8x16x32x64xf32>):

View File

@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def SingleResultAndOperandHaveSameElementType : Constraint<
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
def SingleResultAndOperandHaveSameType : Constraint<
CPred<"$0->getType() == $1->getType()">>;
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
//===----------------------------------------------------------------------===//
@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate),
(replaceWithValue $arg),
[(SingleResultAndOperandHaveSameElementType $res,
$arg)]>;
[(SingleResultAndOperandHaveSameType $res, $arg)]>;
//===----------------------------------------------------------------------===//
// Conj op patterns.