diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 47a7b32d7e3..1b25ac6ef18 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -1014,6 +1014,84 @@ static LogicalResult Verify(SliceOp op) {
   return success();
 }
 
+// Converts an int64 constant operation into an equivalent int32
+// TFL::ConstOp. Returns a null op when input_op does not fold to a
+// constant, in which case the caller leaves the operand untouched.
+TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op,
+                                             RankedTensorType value_type,
+                                             Location loc, OpBuilder *builder) {
+  if (input_op == nullptr) return nullptr;
+
+  mlir::DenseIntElementsAttr attr;
+  if (!matchPattern(input_op, m_Constant(&attr))) {
+    return nullptr;
+  }
+
+  auto value_shape_type = mlir::RankedTensorType::get(
+      value_type.getShape(), builder->getIntegerType(32));
+
+  SmallVector<int32_t, 4> value_i32;
+  value_i32.reserve(value_type.getRank());
+  for (const auto &size : attr) {
+    value_i32.push_back(static_cast<int32_t>(size.getSExtValue()));
+  }
+  auto new_value_i32_attr =
+      mlir::DenseIntElementsAttr::get(value_shape_type, value_i32);
+
+  return builder->create<TFL::ConstOp>(loc, new_value_i32_attr);
+}
+
+// Casts down int64 begin/size operands of a TFL slice op to int32.
+// This requires that begin & size are constants.
+struct CastDownInt64BeginEndToInt32 : public OpRewritePattern<TFL::SliceOp> {
+  using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
+                                PatternRewriter &rewriter) const override {
+    auto begin = slice_op.begin();
+    auto size = slice_op.size();
+    auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
+    auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
+    auto begin_op = begin.getDefiningOp();
+    auto size_op = size.getDefiningOp();
+
+    if (begin_op == nullptr && size_op == nullptr) return failure();
+
+    if (begin_type == nullptr && size_type == nullptr) return failure();
+
+    // Report failure when nothing is rewritten; otherwise the greedy
+    // pattern driver would keep re-applying this pattern forever.
+    bool changed = false;
+
+    // Handle begin.
+    if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) {
+      auto new_begin = NarrowDownInt64InputValuesForOp(
+          begin_op, begin_type, slice_op.getLoc(), &rewriter);
+      if (new_begin != nullptr) {
+        slice_op.setOperand(1, new_begin);
+        changed = true;
+      }
+    }
+
+    // Handle size.
+    if (size_op && size_type && size_type.getElementType().isInteger(64)) {
+      auto new_size = NarrowDownInt64InputValuesForOp(
+          size_op, size_type, slice_op.getLoc(), &rewriter);
+      if (new_size != nullptr) {
+        slice_op.setOperand(2, new_size);
+        changed = true;
+      }
+    }
+
+    return success(changed);
+  }
+};
+
+void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<CastDownInt64BeginEndToInt32>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index bf7b78e805d..22982850cbd 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -1839,6 +1839,8 @@ equivalent to setting:
   );
 
   let verifier = [{ return Verify(*this); }];
+
+  let hasCanonicalizer = 1;
 }
 
 def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
index c94eb1bf087..1f067aae685 100644
--- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir
@@ -98,3 +98,16 @@ func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5xf32>) {
 // CHECK-NOT: pack
 // CHECK: return %arg0, %[[UNPACK]]#0 : tensor<2x5xf32>, tensor<5xf32>
 }
+
+// -----
+
+func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
+  %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
+  %1 = "tfl.pseudo_const"() {value = dense<[1, 128, 32]> : tensor<3xi64>} : () -> tensor<3xi64>
+  %2 = "tfl.slice"(%arg0, %0, %1) : (tensor<4x128x32xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x128x32xf32>
+  return %2 : tensor<1x128x32xf32>
+
+// CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<3xi32>
+// CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
+// CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
+}