From da671e7e88ee60154a064e4be5ec7fbac7947929 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 19 Aug 2020 14:34:37 -0700 Subject: [PATCH] [MLIR] Extend canonicalization for ToBoolOp to handle most ranked tensors - Canonicalize ToBool with scalar tensors to element comparison with 0/empty string. - Canonicalize ToBool with non-scalar ranked tensors to numElements != 0. PiperOrigin-RevId: 327508290 Change-Id: I31ecb63decfa5995797c4ff867fd131c6654f55b --- .../mlir/lite/tests/end2end/if_op.pbtxt | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 51 ++++++++++++--- .../mlir/tensorflow/tests/canonicalize.mlir | 63 ++++++++++++++++++- 3 files changed, 103 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt index f482e3db6b9..a7f6040f211 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s +# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s node { name: "tf.Less" op: "Less" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 925a2af3f8b..45c32f631eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1796,26 +1796,57 @@ static LogicalResult Verify(TopKV2Op op) { //===----------------------------------------------------------------------===// namespace { -// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity -// function and can be removed. 
-class ToBoolOfZeroDBoolTensor : public OpRewritePattern { +// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded +// into an identity or an equality comparison. +class ToBoolOfRankedTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ToBoolOp op, PatternRewriter &rewriter) const override { - if (auto type = op.getOperand().getType().dyn_cast()) { - if (type.getRank() == 0 && type.getElementType().isInteger(1)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } + auto type = op.getOperand().getType().dyn_cast(); + // If the input is an unranked tensor, cannot rewrite. + if (!type) return failure(); + + // Expected return type of the ToBool operation. + auto result_type = op.getResult().getType().cast(); + + // If input is already a tensor, it can be folded into an identity. + if (type == result_type) { + rewriter.replaceOp(op, op.getOperand()); + return success(); } - return failure(); + + if (type.getRank() == 0) { + // If the input is a scalar tensor, the ToBool can be expanded to + // element != 0 (for numerical values) or element == empty (for string). + Type element_type = type.getElementType(); + Attribute zero_attr; + if (element_type.isIntOrFloat()) + zero_attr = rewriter.getZeroAttr(type); + else if (element_type.isa()) + zero_attr = DenseStringElementsAttr::get(type, {""}); + + if (!zero_attr) return failure(); + + auto zero_const = rewriter.create(op.getLoc(), zero_attr); + rewriter.replaceOpWithNewOp( + op, result_type, op.getOperand(), zero_const, false); + } else { + // If the input is a non-scalar ranked tensor, ToBool can be expanded + // to numElements != 0. numElements will be 0 iff one of the dimensions is + // zero. 
+ bool any_zero = + llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; }); + rewriter.replaceOpWithNewOp( + op, result_type, DenseElementsAttr::get(result_type, {!any_zero})); + } + return success(); } }; } // namespace void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 8c3e8dc41a6..0227b4fdf9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -742,13 +742,72 @@ func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @ToBool_0DScalar -func @ToBool_0DScalar(%arg0: tensor) -> tensor { +// CHECK-LABEL: func @ToBool_0DScalarI1 +func @ToBool_0DScalarI1(%arg0: tensor) -> tensor { // CHECK: return %arg0 %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @ToBool_0DScalarInt +func @ToBool_0DScalarInt(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarFloat +func @ToBool_0DScalarFloat(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarString +func @ToBool_0DScalarString(%arg0: tensor) -> tensor { + // CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor} : () -> 
tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor, tensor) -> tensor + // CHECK: return [[NE]] : tensor + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensor +func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensorZeroDim +func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensor +func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensorZeroDim +func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor + return %0 : tensor +} + // CHECK-LABEL: testReadVariableOpOfCast func @testReadVariableOpOfCast(%arg0: tensor>>) -> tensor<8x40xf32> { %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor<*x!tf.resource>