[MLIR] Extend canonicalization for ToBoolOp to handle most ranked tensors

- Canonicalize ToBool with scalar tensors to element comparison with
  0/empty string.
- Canonicalize ToBool with non-scalar ranked tensors to numElements != 0.

PiperOrigin-RevId: 327508290
Change-Id: I31ecb63decfa5995797c4ff867fd131c6654f55b
This commit is contained in:
Rahul Joshi 2020-08-19 14:34:37 -07:00 committed by TensorFlower Gardener
parent 699178a5d7
commit da671e7e88
3 changed files with 103 additions and 13 deletions

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
node {
name: "tf.Less"
op: "Less"

View File

@ -1796,26 +1796,57 @@ static LogicalResult Verify(TopKV2Op op) {
//===----------------------------------------------------------------------===//
namespace {
// If the input to ToBoolOp is a `tensor<i1>`, then the ToBoolOp is an identity
// function and can be removed.
class ToBoolOfZeroDBoolTensor : public OpRewritePattern<ToBoolOp> {
// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
// into an identity or an equality comparison.
class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
using OpRewritePattern<ToBoolOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ToBoolOp op,
PatternRewriter &rewriter) const override {
if (auto type = op.getOperand().getType().dyn_cast<RankedTensorType>()) {
if (type.getRank() == 0 && type.getElementType().isInteger(1)) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
// If the input is an unranked tensor, we cannot rewrite.
if (!type) return failure();
// Expected return type of the ToBool operation.
auto result_type = op.getResult().getType().cast<RankedTensorType>();
// If input is already a tensor<i1>, it can be folded into an identity.
if (type == result_type) {
rewriter.replaceOp(op, op.getOperand());
return success();
}
return failure();
if (type.getRank() == 0) {
// If the input is a scalar tensor, the ToBool can be expanded to
// element != 0 (for numerical values) or element != "" (for strings).
Type element_type = type.getElementType();
Attribute zero_attr;
if (element_type.isIntOrFloat())
zero_attr = rewriter.getZeroAttr(type);
else if (element_type.isa<TF::StringType>())
zero_attr = DenseStringElementsAttr::get(type, {""});
if (!zero_attr) return failure();
auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
op, result_type, op.getOperand(), zero_const, false);
} else {
// If the input is a non-scalar ranked tensor, ToBool can be expanded
// to numElements != 0. numElements will be 0 iff one of the dimensions is
// zero.
bool any_zero =
llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
rewriter.replaceOpWithNewOp<TF::ConstOp>(
op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
}
return success();
}
};
} // namespace
void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ToBoolOfZeroDBoolTensor>(context);
results.insert<ToBoolOfRankedTensor>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -742,13 +742,72 @@ func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @ToBool_0DScalar
func @ToBool_0DScalar(%arg0: tensor<i1>) -> tensor<i1> {
// A scalar i1 operand already has ToBool's result type, so the op folds
// to an identity: the operand is returned directly and the op disappears.
// CHECK-LABEL: func @ToBool_0DScalarI1
func @ToBool_0DScalarI1(%arg0: tensor<i1>) -> tensor<i1> {
// CHECK: return %arg0
%0 = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
return %0 : tensor<i1>
}
// A numeric scalar is "true" iff it is nonzero: ToBool canonicalizes to a
// tf.NotEqual comparison against a zero constant of the operand's type.
// CHECK-LABEL: func @ToBool_0DScalarInt
func @ToBool_0DScalarInt(%arg0: tensor<i32>) -> tensor<i1> {
// CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
// CHECK: return [[NE]]
%0 = "tf.ToBool"(%arg0) : (tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
// Same as the integer case but for f32: the zero constant is a float zero,
// and the scalar is truthy iff it compares not-equal to 0.0.
// CHECK-LABEL: func @ToBool_0DScalarFloat
func @ToBool_0DScalarFloat(%arg0: tensor<f32>) -> tensor<i1> {
// CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
// CHECK: return [[NE]]
%0 = "tf.ToBool"(%arg0) : (tensor<f32>) -> tensor<i1>
return %0 : tensor<i1>
}
// A scalar string is "true" iff it is nonempty: ToBool canonicalizes to
// tf.NotEqual against an empty-string constant. The comparison carries
// incompatible_shape_error = false, matching the pattern's rewrite.
// CHECK-LABEL: func @ToBool_0DScalarString
func @ToBool_0DScalarString(%arg0: tensor<!tf.string>) -> tensor<i1> {
// CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor<!tf.string>} : () -> tensor<!tf.string>
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor<!tf.string>, tensor<!tf.string>) -> tensor<i1>
// CHECK: return [[NE]] : tensor<i1>
%0 = "tf.ToBool"(%arg0) : (tensor<!tf.string>) -> tensor<i1>
return %0 : tensor<i1>
}
// A non-scalar ranked tensor is "true" iff it has at least one element.
// tensor<1xf32> has no zero-sized dimension, so ToBool folds to constant true.
// CHECK-LABEL: func @ToBool_1DTensor
func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor<i1> {
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
// CHECK: return [[Const]]
%0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor<i1>
return %0 : tensor<i1>
}
// tensor<0xf32> has a zero-sized dimension, hence zero elements, so
// ToBool folds to constant false.
// CHECK-LABEL: func @ToBool_1DTensorZeroDim
func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor<i1> {
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
// CHECK: return [[Const]]
%0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor<i1>
return %0 : tensor<i1>
}
// 2-D case with all dimensions nonzero (1x5 = 5 elements): the tensor is
// nonempty, so ToBool folds to constant true.
// CHECK-LABEL: func @ToBool_2DTensor
func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor<i1> {
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
// CHECK: return [[Const]]
%0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor<i1>
return %0 : tensor<i1>
}
// 2-D case with one zero-sized dimension (1x0 = 0 elements): the tensor is
// empty, so ToBool folds to constant false.
// CHECK-LABEL: func @ToBool_2DTensorZeroDim
func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor<i1> {
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
// CHECK: return [[Const]]
%0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK-LABEL: testReadVariableOpOfCast
func @testReadVariableOpOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
%0 = "tf.Cast"(%arg0) : (tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<*x!tf.resource>