Add MaxPoolWithArgmax as a TFLite custom op

This kernel differs from MaxPool as follows (a usage sketch follows the commit
metadata below):
- It returns both the pooling result and the flattened argmax indices.
- Its parameters are retrieved from the op's custom options instead of builtin op data.

PiperOrigin-RevId: 347808061
Change-Id: I064b62c5313ba3860f6f52d965747fa13d3042b1
Authored by Thai Nguyen on 2020-12-16 06:04:04 -08:00; committed by TensorFlower Gardener
parent 305984e9e7
commit 2091277ccc
7 changed files with 558 additions and 0 deletions
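
For orientation before the diff, here is a minimal sketch of how a client could
serialize this op's parameters and register it with a resolver. The option keys
("ksize", "strides", "padding", "include_batch_in_index") and the
AddPerceptionOps/RegisterMaxPoolWithArgmax entry points come from this change;
the helper names, the 2x2 window and stride values, and the
mutable_op_resolver.h include are illustrative assumptions, not part of this
commit.

  #include <cstdint>
  #include <vector>

  #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
  #include "tensorflow/lite/kernels/perception/perception_ops.h"
  #include "tensorflow/lite/mutable_op_resolver.h"

  // Hypothetical helper: pack the MaxPoolWithArgmax attributes into the
  // flexbuffer map that the kernel's Init() parses. A 2x2 window with 2x2
  // strides and SAME padding is used purely as an example.
  std::vector<uint8_t> BuildMaxPoolWithArgmaxOptions() {
    flexbuffers::Builder fbb;
    size_t map_start = fbb.StartMap();
    fbb.Bool("include_batch_in_index", false);  // Only false is supported.
    fbb.String("padding", "SAME");              // Or "VALID".
    auto ksize = fbb.StartVector("ksize");      // NHWC window; first/last are 1.
    fbb.Add(1);
    fbb.Add(2);
    fbb.Add(2);
    fbb.Add(1);
    fbb.EndVector(ksize, /*typed=*/true, /*fixed=*/false);
    auto strides = fbb.StartVector("strides");  // NHWC strides; first/last are 1.
    fbb.Add(1);
    fbb.Add(2);
    fbb.Add(2);
    fbb.Add(1);
    fbb.EndVector(strides, /*typed=*/true, /*fixed=*/false);
    fbb.EndMap(map_start);
    fbb.Finish();
    return fbb.GetBuffer();
  }

  // Hypothetical helper: registers MaxUnpooling2D and MaxPoolWithArgmax on a
  // resolver so an interpreter built from it can run these custom ops.
  void RegisterPerceptionCustomOps(tflite::MutableOpResolver* resolver) {
    tflite::ops::custom::AddPerceptionOps(resolver);
  }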


@@ -10,6 +10,7 @@ package(
cc_library(
name = "perception_ops",
srcs = [
"max_pool_with_argmax.cc",
"max_unpooling_2d.cc",
"perception_ops.cc",
],
@@ -23,8 +24,11 @@ cc_library(
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:padding",
"//tensorflow/lite/kernels/internal:common",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:tensor_utils",
"//tensorflow/lite/kernels/internal:types",
"@flatbuffers",
],
)
@@ -32,6 +36,7 @@ cc_test(
name = "perception_ops_test",
size = "small",
srcs = [
"max_pool_with_argmax_test.cc",
"max_unpooling_2d_test.cc",
],
deps = [
@@ -40,5 +45,6 @@ cc_test(
"//tensorflow/lite/kernels:test_main",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/testing:util",
"@flatbuffers",
],
)


@@ -0,0 +1,248 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace custom {
namespace max_pool_with_argmax {
namespace {
// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
// this op to a builtin op.
template <typename T>
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
const RuntimeShape& output_shape, const T* input_data,
T* output_data, int32_t* indices_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
const int32_t input_height = input_shape.Dims(1);
const int32_t input_width = input_shape.Dims(2);
const int32_t output_height = output_shape.Dims(1);
const int32_t output_width = output_shape.Dims(2);
const int32_t stride_height = params.stride_height;
const int32_t stride_width = params.stride_width;
for (int32_t batch = 0; batch < batches; ++batch) {
for (int32_t out_y = 0; out_y < output_height; ++out_y) {
for (int32_t out_x = 0; out_x < output_width; ++out_x) {
for (int32_t channel = 0; channel < depth; ++channel) {
const int32_t in_x_origin =
(out_x * stride_width) - params.padding_values.width;
const int32_t in_y_origin =
(out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int32_t filter_x_start = std::max(0, -in_x_origin);
const int32_t filter_x_end =
std::min(params.filter_width, input_width - in_x_origin);
const int32_t filter_y_start = std::max(0, -in_y_origin);
const int32_t filter_y_end =
std::min(params.filter_height, input_height - in_y_origin);
float max = std::numeric_limits<float>::lowest();
int32_t max_x = 0;
int32_t max_y = 0;
for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
++filter_x) {
const int32_t in_x = in_x_origin + filter_x;
const int32_t in_y = in_y_origin + filter_y;
float cur =
input_data[Offset(input_shape, batch, in_y, in_x, channel)];
if (cur > max) {
max = cur;
max_x = in_x;
max_y = in_y;
}
}
}
int32_t output_idx =
Offset(output_shape, batch, out_y, out_x, channel);
output_data[output_idx] = ActivationFunctionWithMinMax(
max, params.float_activation_min, params.float_activation_max);
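          // Flatten the position of the maximum within one batch image as
          // (y * width + x) * depth + channel, matching
          // tf.nn.max_pool_with_argmax with include_batch_in_index=false.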
indices_data[output_idx] =
(max_y * input_width + max_x) * depth + channel;
}
}
}
}
}
} // namespace
constexpr int kDataInputTensor = 0;
constexpr int kDataOutputTensor = 0;
constexpr int kIndicesOutputTensor = 1;
constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
constexpr const char kPoolSizeStr[] = "ksize";
constexpr const char kStridesStr[] = "strides";
constexpr const char kPaddingStr[] = "padding";
constexpr const char kPaddingSameStr[] = "SAME";
constexpr const char kPaddingValidStr[] = "VALID";
struct OpData {
TfLitePoolParams params;
bool include_batch_in_index;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const flexbuffers::Map& m =
flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
.AsMap();
OpData* op_data = new OpData;
op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
// The custom options do not carry a fused activation, so default it to none;
// otherwise CalculateActivationRange in Eval would read an uninitialized value.
op_data->params.activation = kTfLiteActNone;
op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
const std::string padding = m[kPaddingStr].AsString().str();
if (padding == kPaddingValidStr) {
op_data->params.padding = kTfLitePaddingValid;
} else if (padding == kPaddingSameStr) {
op_data->params.padding = kTfLitePaddingSame;
} else {
op_data->params.padding = kTfLitePaddingUnknown;
}
// The first and last element of pool_size are always 1.
const auto pool_size = m[kPoolSizeStr].AsTypedVector();
TFLITE_CHECK_EQ(pool_size.size(), 4);
TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
op_data->params.filter_height = pool_size[1].AsInt32();
op_data->params.filter_width = pool_size[2].AsInt32();
// The first and last element of strides are always 1.
const auto strides = m[kStridesStr].AsTypedVector();
TFLITE_CHECK_EQ(strides.size(), 4);
TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
op_data->params.stride_height = strides[1].AsInt32();
op_data->params.stride_width = strides[2].AsInt32();
return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
TfLiteTensor *output, *indices;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kDataOutputTensor, &output));
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
TF_LITE_ENSURE_MSG(
context, !op_data->include_batch_in_index,
"Include batch dimension in flattened index is not yet supported.");
int batches = input->dims->data[0];
int height = input->dims->data[1];
int width = input->dims->data[2];
int channels_out = input->dims->data[3];
// Matching GetWindowedOutputSize in TensorFlow.
int out_width, out_height;
op_data->params.computed.padding = ComputePaddingHeightWidth(
op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
width, op_data->params.filter_height, op_data->params.filter_width,
op_data->params.padding, &out_height, &out_width);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = batches;
output_size->data[1] = out_height;
output_size->data[2] = out_width;
output_size->data[3] = channels_out;
TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
float activation_min, activation_max;
CalculateActivationRange(op_data->params.activation, &activation_min,
&activation_max);
tflite::PoolParams op_params;
op_params.stride_height = op_data->params.stride_height;
op_params.stride_width = op_data->params.stride_width;
op_params.filter_height = op_data->params.filter_height;
op_params.filter_width = op_data->params.filter_width;
op_params.padding_values.height = op_data->params.computed.padding.height;
op_params.padding_values.width = op_data->params.computed.padding.width;
op_params.float_activation_min = activation_min;
op_params.float_activation_max = activation_max;
TfLiteTensor *output, *indices;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kDataOutputTensor, &output));
TF_LITE_ENSURE_OK(
context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kDataInputTensor, &input));
switch (input->type) {
case kTfLiteFloat32:
MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
GetTensorData<float>(input), GetTensorData<float>(output),
GetTensorData<int32_t>(indices));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace max_pool_with_argmax
TfLiteRegistration* RegisterMaxPoolWithArgmax() {
static TfLiteRegistration r = {
max_pool_with_argmax::Init, max_pool_with_argmax::Free,
max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite


@@ -0,0 +1,298 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <vector>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace {
using testing::ElementsAreArray;
class MaxpoolingWithArgMaxOpModel : public SingleOpModel {
public:
MaxpoolingWithArgMaxOpModel(const TensorData& input, int stride_height,
int stride_width, int filter_height,
int filter_width, TfLitePadding padding,
const TensorData& output,
const TensorData& indices) {
input_ = AddInput(input);
output_ = AddOutput(output);
indices_ = AddOutput(indices);
std::vector<uint8_t> custom_option = CreateCustomOptions(
stride_height, stride_width, filter_height, filter_width, padding);
SetCustomOp("MaxPoolWithArgmax", custom_option, RegisterMaxPoolWithArgmax);
BuildInterpreter({GetShape(input_)});
}
void SetInput(const std::vector<float>& data) {
PopulateTensor(input_, data);
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
std::vector<int32_t> GetIndices() { return ExtractVector<int32_t>(indices_); }
std::vector<int> GetIndicesShape() { return GetTensorShape(indices_); }
protected:
int input_;
int output_;
int indices_;
private:
std::vector<uint8_t> CreateCustomOptions(int stride_height, int stride_width,
int filter_height, int filter_width,
TfLitePadding padding) {
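    // Serialize the op attributes into the flexbuffer map read by the
    // kernel's Init(): NHWC "ksize" and "strides" 4-vectors, a "padding"
    // string ("SAME" or "VALID"), and the "include_batch_in_index" flag.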
auto flex_builder = std::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();
flex_builder->Bool("include_batch_in_index", false);
if (padding == kTfLitePaddingValid) {
flex_builder->String("padding", "VALID");
} else {
flex_builder->String("padding", "SAME");
}
auto start = flex_builder->StartVector("ksize");
flex_builder->Add(1);
flex_builder->Add(filter_height);
flex_builder->Add(filter_width);
flex_builder->Add(1);
flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
auto strides_start = flex_builder->StartVector("strides");
flex_builder->Add(1);
flex_builder->Add(stride_height);
flex_builder->Add(stride_width);
flex_builder->Add(1);
flex_builder->EndVector(strides_start, /*typed=*/true, /*fixed=*/false);
flex_builder->EndMap(map_start);
flex_builder->Finish();
return flex_builder->GetBuffer();
}
};
TEST(MaxpoolWithArgMaxTest, UnsupportedInt64Test) {
EXPECT_DEATH_IF_SUPPORTED(MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT64, {}});
, "indices->type == kTfLiteInt32 was not true.");
}
TEST(MaxpoolWithArgMaxTest, SimpleTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({0, 13, 2, 0, 0, 1, 4, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({13, 4}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 1, 2, 1}));
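  // The maxima 13 and 4 sit at (y=0, x=1) and (y=1, x=2) of the 2x4x1 input,
  // so their flattened indices are (0 * 4 + 1) * 1 = 1 and (1 * 4 + 2) * 1 = 6.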
EXPECT_THAT(model.GetIndices(), ElementsAreArray({1, 6}));
}
TEST(MaxpoolWithArgMaxTest, Strides2x1Test) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
/*stride_height=*/2, /*stride_width=*/1,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 2}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 4, 0, 4, 7, 8, 0, 8}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 2}));
EXPECT_THAT(model.GetIndices(),
ElementsAreArray({4, 7, 2, 7, 12, 15, 10, 15}));
}
TEST(MaxpoolWithArgMaxTest, Strides2x2Test) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 8, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0,
0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 8});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 4, 0, 0, 7, 6, 8}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
EXPECT_THAT(model.GetIndices(),
ElementsAreArray({0, 10, 13, 6, 16, 27, 20, 31}));
}
TEST(MaxpoolWithArgMaxTest, Strides2x2UnfitTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 7, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4,
0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 2, 4, 0, 0, 5, 7}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
EXPECT_THAT(model.GetIndices(),
ElementsAreArray({0, 10, 5, 13, 14, 16, 19, 27}));
}
TEST(MaxpoolWithArgMaxTest, PaddingValidTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 5, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/3,
/*padding=*/kTfLitePaddingValid,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput(
{0, 0, 0, 0, 0, 0, 7, 0, 0, 10, 0, 0, 0, 0, 0, 0, 20, 0, 0, 19});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 20, 19}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 16, 19}));
}
TEST(MaxpoolWithArgMaxTest, PaddingValidUnfitTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 6, 1}},
/*stride_height=*/2, /*stride_width=*/2,
/*filter_height=*/2, /*filter_width=*/3,
/*padding=*/kTfLitePaddingValid,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({0, 0, 0, 0, 0, 0, 7, 0, 0, 10, 0, 0,
0, 0, 0, 0, 20, 0, 0, 19, 24, 1, 2, 44});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 24, 24}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 20, 20}));
}
TEST(MaxpoolWithArgMaxTest, InputWithBatchTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {2, 4, 12, 2}},
/*stride_height=*/2, /*stride_width=*/3,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingSame,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6,
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0,
0, 16, 0, 0, 0, 0, 0, 0, 11, 0, 0, 12, 0, 0, 0, 14,
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0,
0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0,
0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0,
0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32,
0, 0, 0, 0, 25, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({1, 0, 3, 4, 5, 6, 9, 8, 11, 12, 13,
14, 15, 0, 0, 0, 17, 18, 19, 20, 21, 0,
23, 24, 27, 28, 29, 0, 31, 32, 25, 26}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
EXPECT_THAT(model.GetIndices(),
ElementsAreArray({2, 1, 8, 9, 12, 15, 44, 43, 72, 75, 80,
79, 62, 61, 66, 67, 0, 1, 30, 7, 14, 13,
42, 21, 50, 51, 56, 55, 86, 63, 68, 69}));
}
TEST(MaxpoolWithArgMaxTest, InputWithBatchAndPaddingValidTest) {
MaxpoolingWithArgMaxOpModel model(
/*input=*/{TensorType_FLOAT32, {2, 4, 11, 2}},
/*stride_height=*/2, /*stride_width=*/3,
/*filter_height=*/2, /*filter_width=*/2,
/*padding=*/kTfLitePaddingValid,
/*output=*/{TensorType_FLOAT32, {}},
/*indices=*/{TensorType_INT32, {}});
model.SetInput({0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6,
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0,
0, 16, 0, 0, 0, 0, 0, 0, 11, 0, 0, 12, 0, 0, 0, 14,
13, 0, 0, 0, 0, 0, 0, 0, 17, 18, 0, 0, 0, 30, 0, 20,
0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0, 0,
0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0,
0, 0, 23, 0, 0, 0, 0, 0, 0, 0, 27, 28, 0, 0, 0, 0,
29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 0, 31, 32}));
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
EXPECT_THAT(model.GetIndices(),
ElementsAreArray({2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72,
75, 80, 79, 62, 65, 0, 1, 30, 7, 14, 35,
42, 21, 68, 69, 50, 51, 56, 57, 86, 63}));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite


@@ -22,6 +22,8 @@ namespace custom {
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
resolver->AddCustom("MaxUnpooling2D",
tflite::ops::custom::RegisterMaxUnpooling2D());
resolver->AddCustom("MaxPoolWithArgmax",
tflite::ops::custom::RegisterMaxPoolWithArgmax());
}
} // namespace custom


@@ -23,6 +23,7 @@ namespace ops {
namespace custom {
TfLiteRegistration* RegisterMaxUnpooling2D();
TfLiteRegistration* RegisterMaxPoolWithArgmax();
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);


@@ -239,6 +239,7 @@ cc_library(
"//tensorflow/lite/kernels:reference_ops",
"//tensorflow/lite/kernels:test_delegate_providers_lib",
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
"//tensorflow/lite/kernels/perception:perception_ops",
"//tensorflow/lite/tools/evaluation:utils",
] + select({
"//tensorflow:ios": [],


@@ -26,6 +26,7 @@ limitations under the License.
#endif
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/kernels/test_delegate_providers.h"
@@ -370,6 +371,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
tflite::ops::custom::AddPerceptionOps(buildinop_resolver_);
}
switch (delegate_type) {