Add a TFL_CustomOp definition and support for it in the exporter and importer.

Replace the three previously defined custom ops, TFL_MaxPoolingWithArgMax2DOp, TFL_MaxUnpooling2DOp, and TFL_Convolution2DTransposeBiasOp, with TFL_CustomOp instances.

PiperOrigin-RevId: 308755676
Change-Id: Id33e2987c475c6bdd10cbc88410ae4e561aabbdb
Chuan He 2020-04-27 21:05:58 -07:00 committed by TensorFlower Gardener
parent bcdbfb8a9c
commit fb7ea8f0e6
12 changed files with 112 additions and 451 deletions
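
To illustrate the migration, a before/after sketch in IR, taken from the updated ops tests below:

// Before: a dedicated op with typed attributes.
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>

// After: the generic op; the typed attributes are serialized into opaque custom_option bytes.
%0 = "tfl.custom"(%arg0, %arg1) {custom_code = "MaxUnpooling2D", custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>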


@@ -249,9 +249,9 @@ class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_unpooling_2d
// tfl.custom
template <>
class TFLiteCostEstimator<MaxUnpooling2DOp, hardware::GPU> {
class TFLiteCostEstimator<CustomOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "


@@ -403,17 +403,8 @@ class Translator {
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>>
BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::Operator>> BuildMaxUnpooling2DOperator(
Operation* inst, mlir::TFL::MaxUnpooling2DOp op,
BufferOffset<tflite::Operator> BuildCustomOperator(
Operation* inst, mlir::TFL::CustomOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
@@ -767,48 +758,21 @@ BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results);
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildConvolution2DTransposeBiasOperator(
Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op,
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
Operation* inst, mlir::TFL::CustomOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
TfLiteTransposeConvParams conv_params;
conv_params.stride_height = op.stride_h().getSExtValue();
conv_params.stride_width = op.stride_w().getSExtValue();
const auto padding = GetTflitePadding(inst, op.padding());
if (padding) {
conv_params.padding = *padding;
return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxPoolingWithArgMax2DOperator(
Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op,
operands, results);
}
return llvm::None;
}
Optional<BufferOffset<tflite::Operator>>
Translator::BuildMaxUnpooling2DOperator(Operation* inst,
mlir::TFL::MaxUnpooling2DOp op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
const auto pool_params = GetTflitePoolParams(inst, op);
if (pool_params) {
return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands,
results);
}
return llvm::None;
const std::string attrs =
op.custom_option().cast<mlir::OpaqueElementsAttr>().getValue().str();
std::vector<uint8_t> custom_option_vector(attrs.size());
memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
auto opcode_index =
GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0,
builder_.CreateVector<uint8_t>(custom_option_vector),
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
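
For reference, the exporter path above consumes IR like the following (taken from the round-trip test below). BuildCustomOperator copies the opaque custom_option bytes verbatim into the operator's custom_options, and GetOpcodeIndex registers custom_code under BuiltinOperator_CUSTOM:

%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>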
@@ -951,19 +915,8 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
return BuildNumericVerifyOperator(verify_op, operands, results);
}
if (auto conv_transpose_bias_op =
dyn_cast<mlir::TFL::Convolution2DTransposeBiasOp>(inst)) {
return BuildConvolution2DTransposeBiasOperator(
inst, conv_transpose_bias_op, operands, results);
}
if (auto max_pooling_with_arg_max_op =
dyn_cast<mlir::TFL::MaxPoolingWithArgMax2DOp>(inst)) {
return BuildMaxPoolingWithArgMax2DOperator(
inst, max_pooling_with_arg_max_op, operands, results);
}
if (auto max_unpooling_op = dyn_cast<mlir::TFL::MaxUnpooling2DOp>(inst)) {
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
return BuildCustomOperator(inst, custom_op, operands, results);
}
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
if (inst->getNumOperands() != inst->getNumResults()) {


@@ -246,23 +246,8 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
}
StatusOr<std::string> OpNameForOpCode(const tflite::OperatorCodeT opcode) {
// TODO(b/143872630): Support custom ops
if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) {
// Adding some custom op supported on GPU.
const absl::string_view custom_name = opcode.custom_code;
if (custom_name == "MaxPoolingWithArgmax2D") {
return std::string("tfl.max_pooling_with_argmax_2d");
}
if (custom_name == "Convolution2DTransposeBias") {
return std::string("tfl.convolution_2d_transpose_bias");
}
if (custom_name == "MaxUnpooling2D") {
return std::string("tfl.max_unpooling_2d");
}
// Use an unsupported op name instead of throwing an error here in case the
// op is pruned during the import.
return std::string(
llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str());
return std::string("tfl.custom");
}
if (opcode.builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
@@ -523,18 +508,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) {
}
}
// Returns true if this is a custom op.
bool IsCustomOp(const std::string& op_name) {
return op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d" ||
op_name == "tfl.convolution_2d_transpose_bias";
}
// TODO(krzysd) Handle function calls
StatusOr<Operation*> ConvertOp(
const tflite::OperatorT& op, const std::vector<Value>& vals_map,
const std::vector<mlir::TensorType>& intermediate_types,
Value optional_arg_marker, const std::vector<std::string>& op_names,
Value optional_arg_marker,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::TensorT>>& tensors, Location loc,
OpBuilder builder) {
@@ -547,6 +527,7 @@ StatusOr<Operation*> ConvertOp(
}
const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options);
const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index);
const std::string& op_name =
is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index);
OperationState op_state(loc, op_name);
@@ -638,9 +619,9 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (IsCustomOp(op_name)) {
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
builder, loc, &attrs);
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
auto status = mlir::CustomOptionsToAttributes(
op_code.custom_code, op.custom_options, builder, loc, &attrs);
if (!status.ok()) {
return emitError(loc, status.ToString()), status;
}
@@ -752,6 +733,7 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
// return nodes in ordered_output_arrays in the same order.
StatusOr<FuncOp> ConvertSubgraph(
const tflite::SubGraphT& subgraph, llvm::StringRef name,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>>& op_codes,
const std::vector<std::string>& op_names,
const std::vector<std::string>& func_names,
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
@@ -942,7 +924,8 @@ StatusOr<FuncOp> ConvertSubgraph(
TF_ASSIGN_OR_RETURN(
auto* mlir_op,
ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker,
op_names, func_names, subgraph.tensors, op_loc, op_builder));
op_codes, op_names, func_names, subgraph.tensors, op_loc,
op_builder));
// Add the results to the value maps. There are two cases: 1. the result
// tensor does not have min/max values, the original op result is used
@@ -1049,8 +1032,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
auto& subgraph = e.value();
std::string name = SubgraphName(e.index(), *subgraph);
auto func_or_error = ConvertSubgraph(
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
builder,
*subgraph, name, model->operator_codes, operator_names, func_names,
model->buffers, base_loc, builder,
// TODO(b/131175224,b/132239787) Support multiple entry points
/*is_entry_point=*/e.index() == 0,
/*use_external_constant=*/use_external_constant, ordered_input_arrays,


@@ -243,42 +243,22 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
}
Status mlir::CustomOptionsToAttributes(
const std::string& op_name, const std::vector<uint8_t>& custom_options,
const std::string& custom_code, const std::vector<uint8_t>& custom_options,
mlir::Builder builder, mlir::Location loc,
llvm::SmallVectorImpl<mlir::NamedAttribute>* attributes) {
if (op_name == "tfl.max_pooling_with_argmax_2d" ||
op_name == "tfl.max_unpooling_2d") {
auto* pool_params =
reinterpret_cast<const TfLitePoolParams*>(custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(pool_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(pool_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(pool_params->stride_width)));
attributes->emplace_back(builder.getNamedAttr(
"filter_h", builder.getI32IntegerAttr(pool_params->filter_height)));
attributes->emplace_back(builder.getNamedAttr(
"filter_w", builder.getI32IntegerAttr(pool_params->filter_width)));
return Status::OK();
attributes->emplace_back(
builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code)));
std::string content;
content.assign(reinterpret_cast<const char*>(custom_options.data()),
custom_options.size());
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(custom_options.size())}, builder.getIntegerType(8));
attributes->emplace_back(builder.getNamedAttr(
"custom_option",
OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"),
type, content)));
} else if (op_name == "tfl.convolution_2d_transpose_bias") {
auto* conv_params = reinterpret_cast<const TfLiteTransposeConvParams*>(
custom_options.data());
TF_ASSIGN_OR_RETURN(auto padding_attribute,
GetPaddingAttr(conv_params->padding, builder, loc));
attributes->emplace_back(
builder.getNamedAttr("padding", padding_attribute));
attributes->emplace_back(builder.getNamedAttr(
"stride_h", builder.getI32IntegerAttr(conv_params->stride_height)));
attributes->emplace_back(builder.getNamedAttr(
"stride_w", builder.getI32IntegerAttr(conv_params->stride_width)));
return Status::OK();
}
return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name));
return Status::OK();
}
// Pull in FlatBuffer writers for TFLite generated using TableGen
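
On import, the rewritten CustomOptionsToAttributes no longer decodes per-op parameter structs; for any payload it produces just two named attributes, e.g. for a 12-byte option buffer:

custom_code = "Convolution2DTransposeBias"
custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>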


@@ -61,11 +61,12 @@ void BuiltinOptionsToAttributes(
// operands from tflite op name.
llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name);
// Populates the array of mlir::NamedAttributes corresponding to the given
// custom_options.
// We use an out parameter per LLVM convention
// Populates `custom_code` and `custom_options` into attributes.
// `custom_code` is used to identify the CustomOp.
// `custom_options` is an opaque attribute used to store information for this
// custom op.
tensorflow::Status CustomOptionsToAttributes(
const std::string &op_name, const std::vector<uint8_t> &custom_options,
const std::string &custom_code, const std::vector<uint8_t> &custom_options,
mlir::Builder builder,
// NOLINTNEXTLINE
Location loc, llvm::SmallVectorImpl<mlir::NamedAttribute> *attributes);


@@ -2022,6 +2022,18 @@ LogicalResult Verify(WhileOp op) {
return success();
}
static LogicalResult Verify(CustomOp op) {
OpaqueElementsAttr opaque_attr =
op.custom_option().cast<OpaqueElementsAttr>();
if (!opaque_attr.getType().hasStaticShape())
return op.emitOpError("custom_option should have a static shape.");
if (opaque_attr.getValue().size() !=
opaque_attr.getType().cast<ShapedType>().getDimSize(0))
return op.emitOpError(
"custom_option should have the same length of content with shape.");
return success();
}
namespace {
// Canonicalize While op so that results and operands match and external values
// are via implicit capture rather than via block args.
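
As an example of what the new verifier catches, a hypothetical snippet (op name and bytes invented for illustration) whose payload is 4 bytes while the declared type claims 12 would be rejected with the length-mismatch error above:

%0 = "tfl.custom"(%arg0) {custom_code = "SomeCustomOp", custom_option = opaque<"tfl", "0x01000000"> : tensor<12xi8>} : (tensor<4xf32>) -> tensor<4xf32>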


@@ -99,6 +99,16 @@ def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [
// A type attribute containing the TensorType.
def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// An attribute containing opaque bytes, stored as an OpaqueElementsAttr with
// i8 element type.
def OpaqueBytesAttr : ElementsAttrBase<
And<[
CPred<"$_self.isa<OpaqueElementsAttr>() ">,
CPred<"$_self.cast<OpaqueElementsAttr>().getType()"
".getElementType().isInteger(8)">,
]>,
"opaque bytes attribute"
>;
//===----------------------------------------------------------------------===//
// Derived shape attribute class.
//===----------------------------------------------------------------------===//
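
Conversely, attributes like these (illustrative values) would fail the OpaqueBytesAttr constraint above: the first is not an OpaqueElementsAttr, and the second has a non-i8 element type:

custom_option = dense<[1, 0, 0, 0]> : tensor<4xi8>
custom_option = opaque<"tfl", "0x01000000"> : tensor<1xi32>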
@@ -507,38 +517,6 @@ def TFL_TransposeConvOp:
let verifier = [{ return Verify(*this); }];
}
def TFL_Convolution2DTransposeBiasOp :
Op<TFL_Dialect, "convolution_2d_transpose_bias", [
NoSideEffect,
TFL_OperandHasRank<0, 4>,
TFL_OperandHasRank<1, 4>,
TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<2, 1>
]> {
let summary = " Transpose convolution with bias operator";
let description = [{
Performs transpose convolution operation on inputs,
with the option of adding a bias.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
let arguments = (
ins TFL_FpTensor:$input,
TFL_FpTensor:$filter,
TFL_TensorOfOrNone<[F32]>:$bias,
TFL_PaddingAttr:$padding,
Confined<I32Attr, [IntPositive]>:$stride_h,
Confined<I32Attr, [IntPositive]>:$stride_w
);
let results = (outs TFL_FpTensor:$output);
}
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d",
[NoSideEffect,
@@ -1727,63 +1705,6 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
let customOption = "Pool2DOptions";
}
def TFL_MaxPoolingWithArgMax2DOp :
Op<TFL_Dialect, "max_pooling_with_argmax_2d", [NoSideEffect]> {
let summary = "Max Pool 2D with argmax op";
let description = [{
Performs max pooling on the input and outputs both max values and indices.
Each index is a flatten index in a sub-array of "filter_w" x "filter_h" size
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
}];
let arguments = (
ins AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs
AnyTensor:$value,
AnyTensor:$indices
);
}
def TFL_MaxUnpooling2DOp :
Op<TFL_Dialect, "max_unpooling_2d", [NoSideEffect, TFL_GpuTargetOp]> {
let summary = "Max Unpool 2D";
let description = [{
Performs max unpool operation.
To some extent this is the reverse operation of max pooling:
the elements in the input activation tensor is stored into the position
specified by the input indices.
Note this is a custom op that is not supported in the standard runtime.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the input indices
}];
let arguments = (
ins AnyTensor:$input,
AnyTensor:$indices,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_w,
I32Attr:$filter_h
);
let results = (outs AnyTensor:$outputs);
}
def TFL_MaximumOp : TFL_Op<"maximum", [
ResultsBroadcastableShape,
NoSideEffect,
@@ -3991,4 +3912,27 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
let hasCanonicalizer = 1;
}
def TFL_CustomOp : Op<TFL_Dialect, "custom", [NoSideEffect]> {
let summary = "Custom op";
let description = [{
A generic op for any TFLite custom operation.
input: A list of inputs in the original op.
custom_code: A string used to identify exactly which op this is; it
corresponds to operator_codes.custom_code in the flatbuffer.
custom_option: a holder that stores the op attributes as bytes.
output: A list of outputs in the original op.
}];
let arguments = (ins
Variadic<TFL_TensorOfOrNone<[AnyType]>>:$input,
StrAttr:$custom_code,
OpaqueBytesAttr:$custom_option
);
let results = (outs Variadic<AnyTensor>:$output);
let verifier = [{ return Verify(*this); }];
}
#endif // TFL_OPS


@@ -0,0 +1,8 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
// CHECK-LABEL: main
// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>


@@ -1,82 +0,0 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Convolution2DTransposeBias"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 32, 4, 4, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 42, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 64, 84, 32 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>)
// MLIR-SAME: -> tensor<1x64x84x32xf32>
// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2)
// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32>
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}


@@ -1,71 +0,0 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 64, 64, 32 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 32, 32, 32 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1, 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>)
// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0)
// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32}
// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}


@@ -1,71 +0,0 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s
func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MaxUnpooling2D"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 8, 8, 128 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.max_unpooling_2d",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>)
// MLIR-SAME: -> tensor<1x8x8x128xf32>
// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1)
// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32}
// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32>
// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32>
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}


@@ -517,14 +517,16 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32
// -----
func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
%0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
// custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
%0, %1 = "tfl.custom"(%arg0) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
// -----
func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> {
%0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
// custom op for "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
%0 = "tfl.custom"(%arg0, %arg1) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxUnpooling2D"} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>)
return %0 : tensor<1x8x8x128xf32>
}
@@ -2040,7 +2042,8 @@ func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %ar
// -----
func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
// custom op for "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>, custom_code = "Convolution2DTransposeBias"} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
@@ -2048,7 +2051,8 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso
func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> {
%cst = constant unit
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
// custom op for "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
%0 = "tfl.custom"(%arg0, %arg1, %cst) {custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>, custom_code = "Convolution2DTransposeBias"} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}