Add a pass in the MLIR optimizer for strided_slice.

Currently the TFLite kernel does not support strided_slice with a non-zero new_axis_mask, so we add a reshape op right before the strided_slice. It should always be safe to do so.

PiperOrigin-RevId: 258891977
parent 0939d54145
commit a355cfad30
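As context for the diff below, here is a minimal standalone sketch of the new-shape computation the pass performs. The helper name ExpandShapeForNewAxisMask is hypothetical; the loop mirrors the one added to optimize.cc in this commit.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical standalone helper mirroring the shape computation in the
    // pattern below: for each of dim_size output dimensions, a set bit in
    // new_axis_mask inserts a size-1 axis; otherwise the next original
    // dimension is copied through.
    std::vector<int64_t> ExpandShapeForNewAxisMask(
        const std::vector<int64_t>& original_shape, uint64_t new_axis_mask,
        int dim_size) {
      std::vector<int64_t> new_shape;
      int index = 0;
      for (int i = 0; i < dim_size; ++i) {
        if ((uint64_t{1} << i) & new_axis_mask) {
          new_shape.push_back(1);  // New size-1 axis at dimension i.
        } else {
          new_shape.push_back(original_shape[index++]);
        }
      }
      return new_shape;
    }

    int main() {
      // new_axis_mask = 9 is [1 0 0 1]: new axes at dims 0 and 3, so [2, 3]
      // becomes [1, 2, 3, 1], as in the PadStridedSliceNewAxisMask test below.
      for (int64_t d : ExpandShapeForNewAxisMask({2, 3}, 9, 4)) {
        std::cout << d << " ";  // Prints: 1 2 3 1
      }
      std::cout << "\n";
      return 0;
    }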
@@ -147,3 +147,32 @@ func @FuseFullyConnectedAdd(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>)
  // CHECK: %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32>
  // CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst)
}

// CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask
func @NoPadStridedSliceNonNewAxisMask(%arg0: tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32> {
  %cst = constant dense<0> : tensor<4xi32>
  %cst_0 = constant dense<1> : tensor<4xi32>
  %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32>
  %1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
  return %1 : tensor<1x2x3x1xf32>

  // CHECK: %cst = constant dense<0> : tensor<4xi32>
  // CHECK: %cst_0 = constant dense<1> : tensor<4xi32>
  // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x2x3x1xf32>) -> tensor<1x2x3x1xf32>
  // CHECK: %1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}

// CHECK-LABEL: @PadStridedSliceNewAxisMask
func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> {
  %cst = constant dense<0> : tensor<4xi32>
  %cst_0 = constant dense<1> : tensor<4xi32>
  %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  %1 = "tfl.strided_slice"(%0, %cst, %cst, %cst_0) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 9 : i32, shrink_axis_mask = 0 : i32} : (tensor<2x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
  return %1 : tensor<1x2x3x1xf32>

  // CHECK: %cst = constant dense<0> : tensor<4xi32>
  // CHECK: %cst_0 = constant dense<1> : tensor<4xi32>
  // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  // CHECK: %1 = "tfl.reshape"(%0) : (tensor<2x3xf32>) -> tensor<1x2x3x1xf32>
  // CHECK: %2 = "tfl.strided_slice"(%1, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32>
}
@@ -17,8 +17,12 @@ limitations under the License.
// optimizes them to resulting operations in TensorFlowLite dialect.

#include <climits>
#include <cstdint>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
#include "mlir/Support/Functional.h"  // TF:local_config_mlir
@@ -117,6 +121,87 @@ struct FuseFullyConnectedAndAdd : public RewritePattern {
  }
};

// StridedSlice can have complicated attributes like begin_mask, end_mask,
// ellipsis_mask, new_axis_mask, and shrink_axis_mask. These masks complicate
// the strided_slice computation logic; we can simplify the logic by inserting
// a reshape op to pad the inputs so the strided_slice is easier to handle.
//
// So the graph transformation looks like this:
//   original_input -> strided_slice -> output
//      (transforms into)
//   original_input -> reshape -> strided_slice -> output
//
// The new shape is computed based on the masks.
//
// An example for new_axis_mask: say the new_axis_mask is 9, which represents
// [1 0 0 1]; that means we're inserting two new axes at dims 0 & 3, so if the
// original shape is [2, 3], we reshape it into [1, 2, 3, 1].
struct PadStridedSliceDims : public RewritePattern {
  explicit PadStridedSliceDims(MLIRContext *context)
      : RewritePattern(TFL::StridedSliceOp::getOperationName(),
                       {"tfl.strided_slice", "tfl.strided_slice"}, 2, context) {
  }

  PatternMatchResult matchAndRewrite(Operation *strided_slice_op,
                                     PatternRewriter &rewriter) const override {
    // TODO(renjieliu): Consider expanding the transformation to the ellipsis
    // & shrink masks as well.
    TFL::StridedSliceOp strided_slice =
        llvm::cast<TFL::StridedSliceOp>(strided_slice_op);
    const uint64_t new_axis_mask = strided_slice.new_axis_mask().getZExtValue();
    if (new_axis_mask == 0) return matchFailure();

    // Insert a new reshape op.
    Value *original_input = strided_slice.input();
    const RankedTensorType &original_input_type =
        original_input->getType().template cast<RankedTensorType>();
    const ArrayRef<int64_t> &original_input_shape =
        original_input_type.getShape();
    const RankedTensorType &begin_type =
        strided_slice.begin()->getType().template cast<RankedTensorType>();
    const int dim_size = begin_type.getShape()[0];
    SmallVector<int64_t, 4> new_shape;
    int mask = 1;
    int index = 0;
    for (int i = 0; i < dim_size; ++i) {
      if (mask & new_axis_mask) {
        new_shape.emplace_back(1);
      } else {
        new_shape.emplace_back(original_input_shape[index]);
        ++index;
      }
      mask = mask << 1;
    }

    auto new_output_type =
        rewriter.getTensorType(new_shape, original_input_type.getElementType());

    TFL::ReshapeOp reshape = rewriter.create<TFL::ReshapeOp>(
        strided_slice.getLoc(), new_output_type, original_input);

    // Replace the original strided_slice.
    llvm::APInt new_begin_mask = strided_slice.begin_mask();
    llvm::APInt new_end_mask = strided_slice.end_mask();
    // Since we expand the dims, we need to apply them to the begin_mask &
    // end_mask.
    new_begin_mask |= strided_slice.new_axis_mask();
    new_end_mask |= strided_slice.new_axis_mask();

    auto attribute_type = rewriter.getIntegerType(32);
    rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
        strided_slice_op, strided_slice.getType(), reshape,
        strided_slice.begin(), strided_slice.end(), strided_slice.strides(),
        rewriter.getIntegerAttr(attribute_type, new_begin_mask),
        rewriter.getIntegerAttr(attribute_type, new_end_mask),
        rewriter.getIntegerAttr(attribute_type, strided_slice.ellipsis_mask()),
        rewriter.getI32IntegerAttr(0),
        rewriter.getIntegerAttr(attribute_type,
                                strided_slice.shrink_axis_mask()));
    return matchSuccess();
  }
};

void Optimize::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
@@ -124,6 +209,7 @@ void Optimize::runOnFunction() {
  TFL::populateWithGenerated(&getContext(), &patterns);
  patterns.push_back(
      llvm::make_unique<FuseFullyConnectedAndAdd>(&getContext()));
  patterns.push_back(llvm::make_unique<PadStridedSliceDims>(&getContext()));
  applyPatternsGreedily(func, std::move(patterns));
}
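The mask bookkeeping in the rewrite above is easy to sanity-check by hand. Using the values from the PadStridedSliceNewAxisMask test (begin_mask = end_mask = 6, new_axis_mask = 9), a minimal plain-C++ check of the OR-update the pattern performs, with no MLIR dependencies:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Because the reshape already materializes the new axes, the rewritten
      // strided_slice must treat those axes as "take the full range", which
      // is exactly what OR-ing new_axis_mask into begin_mask/end_mask does.
      uint32_t begin_mask = 6, end_mask = 6, new_axis_mask = 9;
      uint32_t new_begin_mask = begin_mask | new_axis_mask;
      uint32_t new_end_mask = end_mask | new_axis_mask;
      assert(new_begin_mask == 15 && new_end_mask == 15);  // Matches the CHECK lines.
      // new_axis_mask is reset to 0 on the rewritten op, since the reshape
      // has already inserted the axes.
      return 0;
    }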