From 1e7a91e26abd93086a376ef6212bdf463c747dca Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 3 Dec 2019 13:23:40 -0800 Subject: [PATCH] Don't canonicalize away casts between different types. The previous code would incorrectly only check the element type, rather than exact type equality. Failure to do so can trigger many different kinds of verifier errors. PiperOrigin-RevId: 283609199 Change-Id: I3bbd8b41a6a2c8edd2e9d97b32eda78e546975ac --- .../compiler/mlir/tensorflow/tests/canonicalize.mlir | 11 +++++++++++ .../mlir/tensorflow/transforms/canonicalize.td | 6 ++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 2db64262094..a2cc33a8201 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -101,6 +101,17 @@ func @testDifferentCastType(%arg0: tensor<8x16x32x64xf32>) -> (tensor<8x16x32x64 // CHECK: return %0, %1 } +// CHECK-LABEL: testCompatibleCastType +func @testCompatibleCastType(%arg0: tensor) -> (tensor<10xf32>, tensor<10xf32>) { + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> + %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> + return %0, %1: tensor<10xf32>, tensor<10xf32> + +// CHECK: %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor) -> tensor<10xf32> +// CHECK: %1 = "tf.Cast"(%arg0) {Truncate = true} : (tensor) -> tensor<10xf32> +// CHECK: return %0, %1 +} + // CHECK-LABEL: testSameCastTypeAcrossBasicBlocks func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x32x64xf32> { ^bb0(%arg0: tensor<8x16x32x64xf32>): diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index beb7583fc57..7c38b78f239 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -22,6 +22,9 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def SingleResultAndOperandHaveSameElementType : Constraint< CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; +def SingleResultAndOperandHaveSameType : Constraint< + CPred<"$0->getType() == $1->getType()">>; + def IsRank2Tensor : Type, "Rank 2 tensor">; //===----------------------------------------------------------------------===// @@ -75,8 +78,7 @@ def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)), def CastSameType : Pat<(TF_CastOp:$res $arg, $truncate), (replaceWithValue $arg), - [(SingleResultAndOperandHaveSameElementType $res, - $arg)]>; + [(SingleResultAndOperandHaveSameType $res, $arg)]>; //===----------------------------------------------------------------------===// // Conj op patterns.