PR #41735: [MLIR:LITE] Verify unpack op
PiperOrigin-RevId: 327119507
Change-Id: I6b71381c8f1d6f31e2d4c854273a391908d8fb74

parent 1e87951747
commit fb78e1b6c1
@@ -29,7 +29,6 @@ filegroup(
         "ir/tfl_ops.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
-        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
         "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
         "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
     ],
@@ -228,7 +227,6 @@ cc_library(
         "@llvm-project//mlir:DerivedAttributeOpInterface",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:LoopLikeInterface",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:SideEffects",
@@ -502,7 +500,6 @@ gentbl(
     tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen",
     td_file = "ir/tfl_ops.td",
     td_srcs = [
-        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
         "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
         "ir/tfl_op_interfaces.td",
@@ -30,7 +30,6 @@ limitations under the License.
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
-#include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/Matchers.h"  // from @llvm-project
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
@@ -1446,59 +1445,12 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 // TODO(b/133486129): Implement shape inference for unpack
 
-LogicalResult UnpackOp::inferReturnTypes(
-    MLIRContext *context, Optional<Location> loc, ValueRange operands,
-    DictionaryAttr attributes, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  UnpackOpAdaptor op(operands, attributes);
-  // TODO(jpienaar): Refactor inferReturnTypes.
-  if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
-    return failure();
-
-  if (operands.size() != 1) {
-    return emitOptionalError(loc, "input count should be equal to 1");
-  }
-
-  const int64_t num_value = op.num().getInt();
-  auto input_type = operands[0].getType().dyn_cast<ShapedType>();
-  if (!input_type || !input_type.hasRank()) {
-    // If input is unranked, then so is output.
-    inferredReturnTypes.assign(
-        num_value, UnrankedTensorType::get(input_type.getElementType()));
-    return success();
-  }
-
-  if (input_type.getNumElements() <= 0) {
-    return emitOptionalError(
-        loc, "number of elements in input shoule be larger than 0");
-  }
-
-  const int64_t rank = input_type.getRank();
-  if (rank <= 0) {
-    return emitOptionalError(loc, "input should be of rank larger than 0");
-  }
-
-  int64_t axis_value = op.axis().getInt();
-  if (axis_value < 0) {
-    axis_value += rank;
-  }
-  if (axis_value < 0 || axis_value >= rank) {
-    return emitOptionalError(
-        loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
-        op.axis().getInt(), ", and rank = ", rank);
-  }
-
-  if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
-      input_type.getDimSize(axis_value) != num_value) {
-    return emitOptionalError(loc, "output count should match 'num' attribute");
-  }
-
-  auto output_shape = llvm::to_vector<4>(input_type.getShape());
-  output_shape.erase(output_shape.begin() + axis_value);
-
-  auto output_type =
-      RankedTensorType::get(output_shape, input_type.getElementType());
-  inferredReturnTypes.assign(num_value, output_type);
-
+static LogicalResult Verify(UnpackOp op) {
+  // TODO(antiagainst): Implement other checks as in
+  // tensorflow/lite/kernels/unpack.cc
+
+  if (op.getOperation()->getNumResults() != op.num())
+    return op.emitOpError("output count should match 'num' attribute");
+
   return success();
 }
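As an aside (not part of the diff): the restored Verify only checks that the op produces as many results as its `num` attribute claims, so a hypothetical op such as the one below, written in the same style as the tests further down, would be rejected with "output count should match 'num' attribute" (the function name is illustrative only):

    func @unpack_wrong_output_count(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
      // expected-error @+1 {{output count should match 'num' attribute}}
      %0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>)
      return %0#0 : tensor<2xi32>
    }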
@@ -26,7 +26,6 @@ limitations under the License.
 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "mlir/Interfaces/DerivedAttributeOpInterface.h"  // from @llvm-project
-#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/LoopLikeInterface.h"  // from @llvm-project
 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
@@ -19,7 +19,6 @@ limitations under the License.
 #define TFL_OPS
 
 include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
@@ -3029,8 +3028,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
 def TFL_UnpackOp : TFL_Op<"unpack", [
     NoSideEffect,
     SameOperandsAndResultElementType,
-    SameOperandsAndResultsScale,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    SameOperandsAndResultsScale]> {
   let summary = "Unpacks a tensor along a dimension into multiple tensors";
 
   let description = [{
@@ -3061,6 +3059,8 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
     TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
   );
 
+  let verifier = [{ return Verify(*this); }];
+
   let hasOptions = 1;
 }
 
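For context (an aside, not part of the change): the code block given to `let verifier` is spliced by the ODS generator into the op's generated verify hook, so conceptually the result looks roughly like the sketch below; the exact shape of the generated method is an assumption and varies across MLIR revisions.

    // Rough, illustrative sketch of the tblgen-generated hook:
    LogicalResult UnpackOp::verify() {
      // ...ODS-generated operand/attribute/result checks run first...
      // then the body supplied by `let verifier` in the op definition:
      return Verify(*this);
    }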
@@ -1189,22 +1189,7 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
   // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
   %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
   return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
-  // CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32}
-  %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
-  return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> {
-  // CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32}
-  %0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
-  return %0#0 : tensor<3xi32>
 }
 
 // -----
@@ -1225,45 +1210,6 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
 
 // -----
 
-func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
-  // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}}
-  %0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
-  return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
-  // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}}
-  %0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
-  return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<i32>) -> tensor<2xi32> {
-  // expected-error @+1 {{input should be of rank larger than 0}}
-  %0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
-  return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
-  // expected-error @+1 {{op inferred type incompatible with return type of operation}}
-  %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>)
-  return %0#0 : tensor<2xi32>
-}
-
-// -----
-
-func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
-  %0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
-  return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
-}
-
-// -----
-
 // CHECK-LABEL: testMean
 func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
   // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}