[TFLite/MLIR] Adds a verifier to tfl.slice op.
PiperOrigin-RevId: 264725917
This commit is contained in:
parent
fe77c9191c
commit
f671dfffa9
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
@ -459,6 +460,74 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<RemoveAdjacentReshape>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(SliceOp op) {
|
||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
||||
auto begin_type = op.begin()->getType().cast<ShapedType>();
|
||||
auto size_type = op.size()->getType().cast<ShapedType>();
|
||||
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
|
||||
size_type.hasStaticShape()) {
|
||||
if (input_type.getRank() != begin_type.getNumElements()) {
|
||||
return op.emitError(
|
||||
"begin tensor elements size is not equal to input tensor rank");
|
||||
}
|
||||
|
||||
if (input_type.getRank() != size_type.getNumElements()) {
|
||||
return op.emitError(
|
||||
"size tensor elements size is not equal to input tensor rank");
|
||||
}
|
||||
}
|
||||
|
||||
DenseIntElementsAttr begin;
|
||||
if (matchPattern(op.begin(), m_Constant(&begin))) {
|
||||
int axis = 0;
|
||||
for (auto begin_i : llvm::enumerate(begin)) {
|
||||
if (begin_i.value().getSExtValue() < 0) {
|
||||
return op.emitError(
|
||||
llvm::formatv("begin[{0}] cannot be negative", axis));
|
||||
}
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
|
||||
DenseIntElementsAttr size;
|
||||
if (matchPattern(op.size(), m_Constant(&size))) {
|
||||
int axis = 0;
|
||||
for (auto size_i : llvm::enumerate(size)) {
|
||||
if (size_i.value().getSExtValue() < -1) {
|
||||
return op.emitError(
|
||||
llvm::formatv("size[{0}] cannot be negative other than -1", axis));
|
||||
}
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
|
||||
if (begin && size && input_type.hasStaticShape()) {
|
||||
const int input_rank = begin.getNumElements();
|
||||
for (uint64_t i = 0; i < input_rank; i++) {
|
||||
int begin_i =
|
||||
begin.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
|
||||
int size_i =
|
||||
size.getValue({i}).cast<IntegerAttr>().getValue().getSExtValue();
|
||||
int dim_i = input_type.getShape()[i];
|
||||
if (begin_i >= dim_i) {
|
||||
return op.emitOpError(llvm::formatv(
|
||||
"begin[{0}] cannot exceed dimension length: {1}", i, dim_i));
|
||||
}
|
||||
if (size_i >= 0 && begin_i + size_i > dim_i) {
|
||||
return op.emitError(llvm::formatv(
|
||||
"begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i,
|
||||
dim_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1264,6 +1264,11 @@ The output tensor is a tensor with dimensions described by 'size'
|
||||
whose values are extracted from 'input' starting at the offsets in
|
||||
'begin'.
|
||||
|
||||
`begin` is zero-based; `size` is one-based. If size[i] is -1, all remaining
|
||||
elements in dimension i are included in the slice. In other words, this is
|
||||
equivalent to setting:
|
||||
size[i] = input.dim_size(i) - begin[i]
|
||||
|
||||
*Requirements*:
|
||||
0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
|
||||
}];
|
||||
@ -1277,6 +1282,8 @@ whose values are extracted from 'input' starting at the offsets in
|
||||
let results = (outs
|
||||
AnyTensor:$output
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
|
||||
|
@ -1198,7 +1198,6 @@ func @testSvdfUnsupportedType(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>, %a
|
||||
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testDepthToSpace
|
||||
@ -1218,3 +1217,62 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid tfl.slice: begin/size ranks match the input rank and the operands are
// non-constant, so no elementwise checks fire.
func @testSlice(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
  %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// begin has 2 elements but the input is rank 3 — verifier must reject.
func @testSliceBadBeginDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
  // expected-error @+1 {{begin tensor elements size is not equal to input tensor rank}}
  %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<2xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// size has 2 elements but the input is rank 3 — verifier must reject.
func @testSliceBadSizeDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<2xi32>) -> tensor<?x3x5xf32> {
  // expected-error @+1 {{size tensor elements size is not equal to input tensor rank}}
  %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<2xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Constant begin contains a negative index (begin[1] = -1) — verifier must
// reject.
func @testSliceBadBegin(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
  %cst = constant dense<[2, -1, 5]> : tensor<3xi32>
  // expected-error @+1 {{begin[1] cannot be negative}}
  %0 = "tfl.slice"(%arg0, %cst, %arg1) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Constant size contains -2 (only -1 is a valid negative sentinel) — verifier
// must reject.
func @testSliceNegativeSize(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
  %cst = constant dense<[-2, -1, 5]> : tensor<3xi32>
  // expected-error @+1 {{size[0] cannot be negative other than -1}}
  %0 = "tfl.slice"(%arg0, %arg1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// begin[2] + size[2] = 1 + 5 = 6 exceeds input dim 2 length (5) — verifier
// must reject.
func @testSliceSizeOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
  %cst = constant dense<[2, 1, 5]> : tensor<3xi32>
  %cst_1 = constant dense<[0, 1, 1]> : tensor<3xi32>
  // expected-error @+1 {{begin[2] + size[2] cannot exceed dimension length: 5}}
  %0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// begin[0] = 2 is not a valid index into input dim 0 of length 2 — verifier
// must reject.
func @testSliceBeginOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
  %cst = constant dense<[1, 1, 1]> : tensor<3xi32>
  %cst_1 = constant dense<[2, 1, 3]> : tensor<3xi32>
  // expected-error @+1 {{begin[0] cannot exceed dimension length: 2}}
  %0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
  return %0 : tensor<?x3x5xf32>
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user