[MLIR] Extend canonicalization for ToBoolOp to handle most ranked tensors
- Canonicalize ToBool with scalar tensors to element comparison with 0/empty string. - Canonicalize ToBool with non-scalar ranked tensors to numElements != 0. PiperOrigin-RevId: 327508290 Change-Id: I31ecb63decfa5995797c4ff867fd131c6654f55b
This commit is contained in:
parent
699178a5d7
commit
da671e7e88
@ -1,4 +1,4 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
|
||||
node {
|
||||
name: "tf.Less"
|
||||
op: "Less"
|
||||
|
@ -1796,26 +1796,57 @@ static LogicalResult Verify(TopKV2Op op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// If the input to ToBoolOp is a `tensor<i1>`, then the ToBoolOp is an identity
|
||||
// function and can be removed.
|
||||
class ToBoolOfZeroDBoolTensor : public OpRewritePattern<ToBoolOp> {
|
||||
// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded
|
||||
// into an identity or an equality comparison.
|
||||
class ToBoolOfRankedTensor : public OpRewritePattern<ToBoolOp> {
|
||||
using OpRewritePattern<ToBoolOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ToBoolOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (auto type = op.getOperand().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (type.getRank() == 0 && type.getElementType().isInteger(1)) {
|
||||
rewriter.replaceOp(op, op.getOperand());
|
||||
return success();
|
||||
}
|
||||
auto type = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
// If the input is an unranked tensor, cannpt rewrite.
|
||||
if (!type) return failure();
|
||||
|
||||
// Expected return type of the ToBool operation.
|
||||
auto result_type = op.getResult().getType().cast<RankedTensorType>();
|
||||
|
||||
// If input is already a tensor<i1>, it can be folded into an identity.
|
||||
if (type == result_type) {
|
||||
rewriter.replaceOp(op, op.getOperand());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
||||
if (type.getRank() == 0) {
|
||||
// If the input is a scalar tensor, the ToBool can be expanded to
|
||||
// element != 0 (for numerical values) or element == empty (for string).
|
||||
Type element_type = type.getElementType();
|
||||
Attribute zero_attr;
|
||||
if (element_type.isIntOrFloat())
|
||||
zero_attr = rewriter.getZeroAttr(type);
|
||||
else if (element_type.isa<TF::StringType>())
|
||||
zero_attr = DenseStringElementsAttr::get(type, {""});
|
||||
|
||||
if (!zero_attr) return failure();
|
||||
|
||||
auto zero_const = rewriter.create<TF::ConstOp>(op.getLoc(), zero_attr);
|
||||
rewriter.replaceOpWithNewOp<TF::NotEqualOp>(
|
||||
op, result_type, op.getOperand(), zero_const, false);
|
||||
} else {
|
||||
// If the input is a non-scalar ranked tensor, ToBool can be expanded
|
||||
// to numElements != 0. numElements will be 0 iff one of the dimensions is
|
||||
// zero.
|
||||
bool any_zero =
|
||||
llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; });
|
||||
rewriter.replaceOpWithNewOp<TF::ConstOp>(
|
||||
op, result_type, DenseElementsAttr::get(result_type, {!any_zero}));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ToBoolOfZeroDBoolTensor>(context);
|
||||
results.insert<ToBoolOfRankedTensor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -742,13 +742,72 @@ func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_0DScalar
|
||||
func @ToBool_0DScalar(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// CHECK-LABEL: func @ToBool_0DScalarI1
|
||||
func @ToBool_0DScalarI1(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// CHECK: return %arg0
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_0DScalarInt
|
||||
func @ToBool_0DScalarInt(%arg0: tensor<i32>) -> tensor<i1> {
|
||||
// CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
|
||||
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
|
||||
// CHECK: return [[NE]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_0DScalarFloat
|
||||
func @ToBool_0DScalarFloat(%arg0: tensor<f32>) -> tensor<i1> {
|
||||
// CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]])
|
||||
// CHECK: return [[NE]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<f32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_0DScalarString
|
||||
func @ToBool_0DScalarString(%arg0: tensor<!tf.string>) -> tensor<i1> {
|
||||
// CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
// CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor<!tf.string>, tensor<!tf.string>) -> tensor<i1>
|
||||
// CHECK: return [[NE]] : tensor<i1>
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<!tf.string>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_1DTensor
|
||||
func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor<i1> {
|
||||
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
// CHECK: return [[Const]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_1DTensorZeroDim
|
||||
func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor<i1> {
|
||||
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||
// CHECK: return [[Const]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_2DTensor
|
||||
func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor<i1> {
|
||||
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
// CHECK: return [[Const]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @ToBool_2DTensorZeroDim
|
||||
func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor<i1> {
|
||||
// CHECK: [[Const:%.*]] = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
|
||||
// CHECK: return [[Const]]
|
||||
%0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testReadVariableOpOfCast
|
||||
func @testReadVariableOpOfCast(%arg0: tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<8x40xf32> {
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<!tf.resource<tensor<8x40xf32>>>) -> tensor<*x!tf.resource>
|
||||
|
Loading…
Reference in New Issue
Block a user