PR #41735: [MLIR:LITE] Verify unpack op

PiperOrigin-RevId: 327119507
Change-Id: I6b71381c8f1d6f31e2d4c854273a391908d8fb74
This commit is contained in:
Jacques Pienaar 2020-08-17 16:08:09 -07:00 committed by TensorFlower Gardener
parent 1e87951747
commit fb78e1b6c1
5 changed files with 8 additions and 114 deletions

View File

@ -29,7 +29,6 @@ filegroup(
"ir/tfl_ops.td", "ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
], ],
@ -228,7 +227,6 @@ cc_library(
"@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects", "@llvm-project//mlir:SideEffects",
@ -502,7 +500,6 @@ gentbl(
tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen",
td_file = "ir/tfl_ops.td", td_file = "ir/tfl_ops.td",
td_srcs = [ td_srcs = [
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"ir/tfl_op_interfaces.td", "ir/tfl_op_interfaces.td",

View File

@ -30,7 +30,6 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project
@ -1446,59 +1445,12 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// TODO(b/133486129): Implement shape inference for unpack // TODO(b/133486129): Implement shape inference for unpack
LogicalResult UnpackOp::inferReturnTypes( static LogicalResult Verify(UnpackOp op) {
MLIRContext *context, Optional<Location> loc, ValueRange operands, // TODO(antiagainst): Implement other checks as in
DictionaryAttr attributes, RegionRange regions, // tensorflow/lite/kernels/unpack.cc
SmallVectorImpl<Type> &inferredReturnTypes) {
UnpackOpAdaptor op(operands, attributes);
// TODO(jpienaar): Refactor inferReturnTypes.
if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context))))
return failure();
if (operands.size() != 1) { if (op.getOperation()->getNumResults() != op.num())
return emitOptionalError(loc, "input count should be equal to 1"); return op.emitOpError("output count should match 'num' attribute");
}
const int64_t num_value = op.num().getInt();
auto input_type = operands[0].getType().dyn_cast<ShapedType>();
if (!input_type || !input_type.hasRank()) {
// If input is unranked, then so is output.
inferredReturnTypes.assign(
num_value, UnrankedTensorType::get(input_type.getElementType()));
return success();
}
if (input_type.getNumElements() <= 0) {
return emitOptionalError(
loc, "number of elements in input shoule be larger than 0");
}
const int64_t rank = input_type.getRank();
if (rank <= 0) {
return emitOptionalError(loc, "input should be of rank larger than 0");
}
int64_t axis_value = op.axis().getInt();
if (axis_value < 0) {
axis_value += rank;
}
if (axis_value < 0 || axis_value >= rank) {
return emitOptionalError(
loc, "attribute 'axis' should be in range [-rank, rank), got axis = ",
op.axis().getInt(), ", and rank = ", rank);
}
if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) &&
input_type.getDimSize(axis_value) != num_value) {
return emitOptionalError(loc, "output count should match 'num' attribute");
}
auto output_shape = llvm::to_vector<4>(input_type.getShape());
output_shape.erase(output_shape.begin() + axis_value);
auto output_type =
RankedTensorType::get(output_shape, input_type.getElementType());
inferredReturnTypes.assign(num_value, output_type);
return success(); return success();
} }

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project

View File

@ -19,7 +19,6 @@ limitations under the License.
#define TFL_OPS #define TFL_OPS
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
@ -3029,8 +3028,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
def TFL_UnpackOp : TFL_Op<"unpack", [ def TFL_UnpackOp : TFL_Op<"unpack", [
NoSideEffect, NoSideEffect,
SameOperandsAndResultElementType, SameOperandsAndResultElementType,
SameOperandsAndResultsScale, SameOperandsAndResultsScale]> {
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors"; let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{ let description = [{
@ -3061,6 +3059,8 @@ def TFL_UnpackOp : TFL_Op<"unpack", [
TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs
); );
let verifier = [{ return Verify(*this); }];
let hasOptions = 1; let hasOptions = 1;
} }

View File

@ -1189,22 +1189,7 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32> return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32}
%0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>)
return %0#0 : tensor<3xi32>
} }
// ----- // -----
@ -1225,45 +1210,6 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// ----- // -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}}
%0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}}
%0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<i32>) -> tensor<2xi32> {
// expected-error @+1 {{input should be of rank larger than 0}}
%0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor<i32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{op inferred type incompatible with return type of operation}}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
%0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: testMean // CHECK-LABEL: testMean
func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> { func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false} // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}