Cast down i64 to i32 for begin/size in the slice op, since NNAPI does not support i64 begin/size; this should be safe for other targets.
PiperOrigin-RevId: 307998537
Change-Id: I82898af697c82ea19605f0ffe45e845c530fa081
parent 5d615f2f05
commit 35e3d0f075
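In effect, the new canonicalization rewrites constant i64 begin/size operands of tfl.slice into equivalent i32 constants (the MLIR test case at the end of this diff shows the before/after). The narrowing itself is a plain element-wise cast, safe as long as the values fit in 32 bits. A minimal standalone sketch of that step, with illustrative names that are not part of the commit:

#include <cstdint>
#include <vector>

// Illustrative only: narrow 64-bit begin/size values to 32 bits, mirroring
// what the helper added below does for constant operands.
std::vector<int32_t> NarrowToInt32(const std::vector<int64_t>& values) {
  std::vector<int32_t> narrowed;
  narrowed.reserve(values.size());
  for (int64_t v : values) narrowed.push_back(static_cast<int32_t>(v));
  return narrowed;
}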
@@ -1014,6 +1014,75 @@ static LogicalResult Verify(SliceOp op) {
  return success();
}

TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
                                             RankedTensorType value_type,
                                             Location loc, OpBuilder *builder) {
  if (input_op == nullptr) return nullptr;

  mlir::DenseIntElementsAttr attr;
  if (!matchPattern(input_op, m_Constant(&attr))) {
    return nullptr;
  }

  auto value_shape_type = mlir::RankedTensorType::get(
      value_type.getShape(), builder->getIntegerType(32));

  SmallVector<int32_t, 4> value_i32;
  value_i32.reserve(value_type.getRank());
  for (const auto &size : attr) {
    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
  }
  auto new_value_i32_attr =
      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);

  return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
}

// This will cast down int64 values for the TFL slice op.
// It requires that begin & size are constants.
struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
  using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
                                PatternRewriter &rewriter) const override {
    auto begin = slice_op.begin();
    auto size = slice_op.size();
    auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
    auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
    auto begin_op = begin.getDefiningOp();
    auto size_op = size.getDefiningOp();

    if (begin_op == nullptr && size_op == nullptr) return failure();

    if (begin_type == nullptr && size_type == nullptr) return failure();

    // Handle begin.
    if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
      auto new_begin = NarrowDownInt64InputValuesForOp(
          begin_op, begin_type, slice_op.getLoc(), &rewriter);
      if (new_begin != nullptr) {
        slice_op.setOperand(1, new_begin);
      }
    }

    // Handle size.
    if (size_op && size_type && size_type.getElementType().isInteger(64)) {
      auto new_size = NarrowDownInt64InputValuesForOp(
          size_op, size_type, slice_op.getLoc(), &rewriter);
      if (new_size != nullptr) {
        slice_op.setOperand(2, new_size);
      }
    }

    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<CastDonwInt64BeginEndToInt32>(context);
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
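For context, patterns registered through getCanonicalizationPatterns are normally picked up by MLIR's standard canonicalize pass. Below is a minimal sketch of driving just SliceOp's patterns by hand, assuming the greedy rewrite driver API of this era; the function name, includes, and overall wiring are illustrative and not part of the commit:

#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Illustrative only: collect SliceOp's canonicalization patterns (which now
// include CastDonwInt64BeginEndToInt32) and apply them greedily to a function.
// The stock canonicalize pass does essentially this for all registered ops.
void RunSliceCanonicalization(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  mlir::TFL::SliceOp::getCanonicalizationPatterns(patterns, func.getContext());
  (void)mlir::applyPatternsAndFoldGreedily(func, patterns);
}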
@@ -1839,6 +1839,8 @@ equivalent to setting:
  );

  let verifier = [{ return Verify(*this); }];

  let hasCanonicalizer = 1;
}

def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
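Setting "let hasCanonicalizer = 1;" is what makes ODS emit the hook defined in the tfl_ops.cc hunk above. The declaration generated on the op class is roughly the following; this is a sketch of the generated header, not code from this commit:

// Declared on TFL::SliceOp by ODS when hasCanonicalizer is set; the
// definition is the one added in tfl_ops.cc in this change.
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context);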
@@ -98,3 +98,16 @@ func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5x
  // CHECK-NOT: pack
  // CHECK: return %arg0, %[[UNPACK]]#0 : tensor<2x5xf32>, tensor<5xf32>
}

// -----

func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
  %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
  %1 = "tfl.pseudo_const"() {value = dense<[1, 128, 32]> : tensor<3xi64>} : () -> tensor<3xi64>
  %2 = "tfl.slice"(%arg0, %0, %1) : (tensor<4x128x32xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x128x32xf32>
  return %2 : tensor<1x128x32xf32>

  // CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<3xi32>
  // CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
  // CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
}