Add MaxPoolWithArgmax as a TFLite custom op
This kernel is different than MaxPool as following: - Returns both pooling and argmax results. - Parameters are retrieved from custom option instead of builtin op data. PiperOrigin-RevId: 347808061 Change-Id: I064b62c5313ba3860f6f52d965747fa13d3042b1
This commit is contained in:
parent
305984e9e7
commit
2091277ccc
@ -10,6 +10,7 @@ package(
|
||||
cc_library(
|
||||
name = "perception_ops",
|
||||
srcs = [
|
||||
"max_pool_with_argmax.cc",
|
||||
"max_unpooling_2d.cc",
|
||||
"perception_ops.cc",
|
||||
],
|
||||
@ -23,8 +24,11 @@ cc_library(
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:padding",
|
||||
"//tensorflow/lite/kernels/internal:common",
|
||||
"//tensorflow/lite/kernels/internal:compatibility",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
"//tensorflow/lite/kernels/internal:tensor_utils",
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -32,6 +36,7 @@ cc_test(
|
||||
name = "perception_ops_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"max_pool_with_argmax_test.cc",
|
||||
"max_unpooling_2d_test.cc",
|
||||
],
|
||||
deps = [
|
||||
@ -40,5 +45,6 @@ cc_test(
|
||||
"//tensorflow/lite/kernels:test_main",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
248
tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
Normal file
248
tensorflow/lite/kernels/perception/max_pool_with_argmax.cc
Normal file
@ -0,0 +1,248 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace max_pool_with_argmax {
|
||||
namespace {
|
||||
// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
|
||||
// this op to a builtin op.
|
||||
template <typename T>
|
||||
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
|
||||
const RuntimeShape& output_shape, const T* input_data,
|
||||
T* output_data, int32_t* indices_data) {
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
|
||||
const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
|
||||
const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
|
||||
const int32_t input_height = input_shape.Dims(1);
|
||||
const int32_t input_width = input_shape.Dims(2);
|
||||
const int32_t output_height = output_shape.Dims(1);
|
||||
const int32_t output_width = output_shape.Dims(2);
|
||||
const int32_t stride_height = params.stride_height;
|
||||
const int32_t stride_width = params.stride_width;
|
||||
for (int32_t batch = 0; batch < batches; ++batch) {
|
||||
for (int32_t out_y = 0; out_y < output_height; ++out_y) {
|
||||
for (int32_t out_x = 0; out_x < output_width; ++out_x) {
|
||||
for (int32_t channel = 0; channel < depth; ++channel) {
|
||||
const int32_t in_x_origin =
|
||||
(out_x * stride_width) - params.padding_values.width;
|
||||
const int32_t in_y_origin =
|
||||
(out_y * stride_height) - params.padding_values.height;
|
||||
// Compute the boundaries of the filter region clamped so as to
|
||||
// ensure that the filter window fits in the input array.
|
||||
const int32_t filter_x_start = std::max(0, -in_x_origin);
|
||||
const int32_t filter_x_end =
|
||||
std::min(params.filter_width, input_width - in_x_origin);
|
||||
const int32_t filter_y_start = std::max(0, -in_y_origin);
|
||||
const int32_t filter_y_end =
|
||||
std::min(params.filter_height, input_height - in_y_origin);
|
||||
float max = std::numeric_limits<float>::lowest();
|
||||
int32_t max_x = 0;
|
||||
int32_t max_y = 0;
|
||||
|
||||
for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
|
||||
++filter_y) {
|
||||
for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
|
||||
++filter_x) {
|
||||
const int32_t in_x = in_x_origin + filter_x;
|
||||
const int32_t in_y = in_y_origin + filter_y;
|
||||
float cur =
|
||||
input_data[Offset(input_shape, batch, in_y, in_x, channel)];
|
||||
if (cur > max) {
|
||||
max = cur;
|
||||
max_x = in_x;
|
||||
max_y = in_y;
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t output_idx =
|
||||
Offset(output_shape, batch, out_y, out_x, channel);
|
||||
output_data[output_idx] = ActivationFunctionWithMinMax(
|
||||
max, params.float_activation_min, params.float_activation_max);
|
||||
indices_data[output_idx] =
|
||||
(max_y * input_width + max_x) * depth + channel;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
constexpr int kDataInputTensor = 0;
|
||||
constexpr int kDataOutputTensor = 0;
|
||||
constexpr int kIndicesOutputTensor = 1;
|
||||
|
||||
constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
|
||||
constexpr const char kPoolSizeStr[] = "ksize";
|
||||
constexpr const char kStridesStr[] = "strides";
|
||||
constexpr const char kPaddingStr[] = "padding";
|
||||
constexpr const char kPaddingSameStr[] = "SAME";
|
||||
constexpr const char kPaddingValidStr[] = "VALID";
|
||||
|
||||
struct OpData {
|
||||
TfLitePoolParams params;
|
||||
bool include_batch_in_index;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
const flexbuffers::Map& m =
|
||||
flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
|
||||
.AsMap();
|
||||
|
||||
OpData* op_data = new OpData;
|
||||
op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
|
||||
op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
|
||||
|
||||
const std::string padding = m[kPaddingStr].AsString().str();
|
||||
if (padding == kPaddingValidStr) {
|
||||
op_data->params.padding = kTfLitePaddingValid;
|
||||
} else if (padding == kPaddingSameStr) {
|
||||
op_data->params.padding = kTfLitePaddingSame;
|
||||
} else {
|
||||
op_data->params.padding = kTfLitePaddingUnknown;
|
||||
}
|
||||
|
||||
// The first and last element of pool_size are always 1.
|
||||
const auto pool_size = m[kPoolSizeStr].AsTypedVector();
|
||||
TFLITE_CHECK_EQ(pool_size.size(), 4);
|
||||
TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
|
||||
TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
|
||||
op_data->params.filter_height = pool_size[1].AsInt32();
|
||||
op_data->params.filter_width = pool_size[2].AsInt32();
|
||||
|
||||
// The first and last element of strides are always 1.
|
||||
const auto strides = m[kStridesStr].AsTypedVector();
|
||||
TFLITE_CHECK_EQ(strides.size(), 4);
|
||||
TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
|
||||
TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
|
||||
op_data->params.stride_height = strides[1].AsInt32();
|
||||
op_data->params.stride_width = strides[2].AsInt32();
|
||||
|
||||
return op_data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
|
||||
TfLiteTensor *output, *indices;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kDataOutputTensor, &output));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kDataInputTensor, &input));
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
|
||||
TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
|
||||
TF_LITE_ENSURE_MSG(
|
||||
context, !op_data->include_batch_in_index,
|
||||
"Include batch dimension in flattened index is not yet supported.");
|
||||
|
||||
int batches = input->dims->data[0];
|
||||
int height = input->dims->data[1];
|
||||
int width = input->dims->data[2];
|
||||
int channels_out = input->dims->data[3];
|
||||
|
||||
// Matching GetWindowedOutputSize in TensorFlow.
|
||||
int out_width, out_height;
|
||||
op_data->params.computed.padding = ComputePaddingHeightWidth(
|
||||
op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
|
||||
width, op_data->params.filter_height, op_data->params.filter_width,
|
||||
op_data->params.padding, &out_height, &out_width);
|
||||
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
|
||||
output_size->data[0] = batches;
|
||||
output_size->data[1] = out_height;
|
||||
output_size->data[2] = out_width;
|
||||
output_size->data[3] = channels_out;
|
||||
TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
float activation_min, activation_max;
|
||||
CalculateActivationRange(op_data->params.activation, &activation_min,
|
||||
&activation_max);
|
||||
|
||||
tflite::PoolParams op_params;
|
||||
op_params.stride_height = op_data->params.stride_height;
|
||||
op_params.stride_width = op_data->params.stride_width;
|
||||
op_params.filter_height = op_data->params.filter_height;
|
||||
op_params.filter_width = op_data->params.filter_width;
|
||||
op_params.padding_values.height = op_data->params.computed.padding.height;
|
||||
op_params.padding_values.width = op_data->params.computed.padding.width;
|
||||
op_params.float_activation_min = activation_min;
|
||||
op_params.float_activation_max = activation_max;
|
||||
|
||||
TfLiteTensor *output, *indices;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kDataOutputTensor, &output));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetInputSafe(context, node, kDataInputTensor, &input));
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
|
||||
GetTensorData<float>(input), GetTensorData<float>(output),
|
||||
GetTensorData<int32_t>(indices));
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace max_pool_with_argmax
|
||||
|
||||
TfLiteRegistration* RegisterMaxPoolWithArgmax() {
|
||||
static TfLiteRegistration r = {
|
||||
max_pool_with_argmax::Init, max_pool_with_argmax::Free,
|
||||
max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
298
tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc
Normal file
298
tensorflow/lite/kernels/perception/max_pool_with_argmax_test.cc
Normal file
@ -0,0 +1,298 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/perception/perception_ops.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
namespace {
|
||||
|
||||
using testing::ElementsAreArray;
|
||||
|
||||
class MaxpoolingWithArgMaxOpModel : public SingleOpModel {
|
||||
public:
|
||||
MaxpoolingWithArgMaxOpModel(const TensorData& input, int stride_height,
|
||||
int stride_width, int filter_height,
|
||||
int filter_width, TfLitePadding padding,
|
||||
const TensorData& output,
|
||||
const TensorData& indices) {
|
||||
input_ = AddInput(input);
|
||||
output_ = AddOutput(output);
|
||||
indices_ = AddOutput(indices);
|
||||
|
||||
std::vector<uint8_t> custom_option = CreateCustomOptions(
|
||||
stride_height, stride_width, filter_height, filter_width, padding);
|
||||
SetCustomOp("MaxPoolWithArgmax", custom_option, RegisterMaxPoolWithArgmax);
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
void SetInput(const std::vector<float>& data) {
|
||||
PopulateTensor(input_, data);
|
||||
}
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
std::vector<int32_t> GetIndices() { return ExtractVector<int32_t>(indices_); }
|
||||
|
||||
std::vector<int> GetIndicesShape() { return GetTensorShape(indices_); }
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int output_;
|
||||
int indices_;
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> CreateCustomOptions(int stride_height, int stride_width,
|
||||
int filter_height, int filter_width,
|
||||
TfLitePadding padding) {
|
||||
auto flex_builder = std::make_unique<flexbuffers::Builder>();
|
||||
size_t map_start = flex_builder->StartMap();
|
||||
flex_builder->Bool("include_batch_in_index", false);
|
||||
if (padding == kTfLitePaddingValid) {
|
||||
flex_builder->String("padding", "VALID");
|
||||
} else {
|
||||
flex_builder->String("padding", "SAME");
|
||||
}
|
||||
|
||||
auto start = flex_builder->StartVector("ksize");
|
||||
flex_builder->Add(1);
|
||||
flex_builder->Add(filter_height);
|
||||
flex_builder->Add(filter_width);
|
||||
flex_builder->Add(1);
|
||||
flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||
|
||||
auto strides_start = flex_builder->StartVector("strides");
|
||||
flex_builder->Add(1);
|
||||
flex_builder->Add(stride_height);
|
||||
flex_builder->Add(stride_width);
|
||||
flex_builder->Add(1);
|
||||
flex_builder->EndVector(strides_start, /*typed=*/true, /*fixed=*/false);
|
||||
|
||||
flex_builder->EndMap(map_start);
|
||||
flex_builder->Finish();
|
||||
return flex_builder->GetBuffer();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, UnsupportedInt64Test) {
|
||||
EXPECT_DEATH_IF_SUPPORTED(MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT64, {}});
|
||||
, "indices->type == kTfLiteInt32 was not true.");
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, SimpleTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
model.SetInput({0, 13, 2, 0, 0, 1, 4, 0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({13, 4}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 1, 2, 1}));
|
||||
EXPECT_THAT(model.GetIndices(), ElementsAreArray({1, 6}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, Strides2x1Test) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
|
||||
/*stride_height=*/2, /*stride_width=*/1,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 2}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 4, 0, 4, 7, 8, 0, 8}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 2}));
|
||||
EXPECT_THAT(model.GetIndices(),
|
||||
ElementsAreArray({4, 7, 2, 7, 12, 15, 10, 15}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, Strides2x2Test) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 8, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0,
|
||||
0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 8});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 4, 0, 0, 7, 6, 8}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
|
||||
EXPECT_THAT(model.GetIndices(),
|
||||
ElementsAreArray({0, 10, 13, 6, 16, 27, 20, 31}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, Strides2x2UnfitTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 7, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 4,
|
||||
0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 7});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 2, 4, 0, 0, 5, 7}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 4, 1}));
|
||||
EXPECT_THAT(model.GetIndices(),
|
||||
ElementsAreArray({0, 10, 5, 13, 14, 16, 19, 27}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, PaddingValidTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 5, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/3,
|
||||
/*padding=*/kTfLitePaddingValid,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput(
|
||||
{0, 0, 0, 0, 0, 0, 7, 0, 0, 10, 0, 0, 0, 0, 0, 0, 20, 0, 0, 19});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 20, 19}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
|
||||
EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 16, 19}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, PaddingValidUnfitTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 6, 1}},
|
||||
/*stride_height=*/2, /*stride_width=*/2,
|
||||
/*filter_height=*/2, /*filter_width=*/3,
|
||||
/*padding=*/kTfLitePaddingValid,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({0, 0, 0, 0, 0, 0, 7, 0, 0, 10, 0, 0,
|
||||
0, 0, 0, 0, 20, 0, 0, 19, 24, 1, 2, 44});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({7, 10, 24, 24}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({1, 2, 2, 1}));
|
||||
EXPECT_THAT(model.GetIndices(), ElementsAreArray({6, 9, 20, 20}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, InputWithBatchTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {2, 4, 12, 2}},
|
||||
/*stride_height=*/2, /*stride_width=*/3,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingSame,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6,
|
||||
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0,
|
||||
0, 16, 0, 0, 0, 0, 0, 0, 11, 0, 0, 12, 0, 0, 0, 14,
|
||||
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0,
|
||||
0, 0, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0,
|
||||
0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0,
|
||||
0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32,
|
||||
0, 0, 0, 0, 25, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({1, 0, 3, 4, 5, 6, 9, 8, 11, 12, 13,
|
||||
14, 15, 0, 0, 0, 17, 18, 19, 20, 21, 0,
|
||||
23, 24, 27, 28, 29, 0, 31, 32, 25, 26}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
|
||||
EXPECT_THAT(model.GetIndices(),
|
||||
ElementsAreArray({2, 1, 8, 9, 12, 15, 44, 43, 72, 75, 80,
|
||||
79, 62, 61, 66, 67, 0, 1, 30, 7, 14, 13,
|
||||
42, 21, 50, 51, 56, 55, 86, 63, 68, 69}));
|
||||
}
|
||||
|
||||
TEST(MaxpoolWithArgMaxTest, InputWithBatchAndPaddingValidTest) {
|
||||
MaxpoolingWithArgMaxOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {2, 4, 11, 2}},
|
||||
/*stride_height=*/2, /*stride_width=*/3,
|
||||
/*filter_height=*/2, /*filter_width=*/2,
|
||||
/*padding=*/kTfLitePaddingValid,
|
||||
/*output=*/{TensorType_FLOAT32, {}},
|
||||
/*indices=*/{TensorType_INT32, {}});
|
||||
|
||||
model.SetInput({0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6,
|
||||
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0,
|
||||
0, 16, 0, 0, 0, 0, 0, 0, 11, 0, 0, 12, 0, 0, 0, 14,
|
||||
13, 0, 0, 0, 0, 0, 0, 0, 17, 18, 0, 0, 0, 30, 0, 20,
|
||||
0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0,
|
||||
0, 0, 23, 0, 0, 0, 0, 0, 0, 0, 27, 28, 0, 0, 0, 0,
|
||||
29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4, 2}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
|
||||
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
|
||||
23, 24, 25, 26, 27, 28, 29, 0, 31, 32}));
|
||||
EXPECT_THAT(model.GetIndicesShape(), ElementsAreArray({2, 2, 4, 2}));
|
||||
EXPECT_THAT(model.GetIndices(),
|
||||
ElementsAreArray({2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72,
|
||||
75, 80, 79, 62, 65, 0, 1, 30, 7, 14, 35,
|
||||
42, 21, 68, 69, 50, 51, 56, 57, 86, 63}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
@ -22,6 +22,8 @@ namespace custom {
|
||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
|
||||
resolver->AddCustom("MaxUnpooling2D",
|
||||
tflite::ops::custom::RegisterMaxUnpooling2D());
|
||||
resolver->AddCustom("MaxPoolWithArgmax",
|
||||
tflite::ops::custom::RegisterMaxPoolWithArgmax());
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
|
@ -23,6 +23,7 @@ namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* RegisterMaxUnpooling2D();
|
||||
TfLiteRegistration* RegisterMaxPoolWithArgmax();
|
||||
|
||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
|
||||
|
||||
|
@ -239,6 +239,7 @@ cc_library(
|
||||
"//tensorflow/lite/kernels:reference_ops",
|
||||
"//tensorflow/lite/kernels:test_delegate_providers_lib",
|
||||
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
|
||||
"//tensorflow/lite/kernels/perception:perception_ops",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
] + select({
|
||||
"//tensorflow:ios": [],
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#endif
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/hashtable/hashtable_ops.h"
|
||||
#include "tensorflow/lite/kernels/perception/perception_ops.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/register_ref.h"
|
||||
#include "tensorflow/lite/kernels/test_delegate_providers.h"
|
||||
@ -370,6 +371,7 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
|
||||
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
|
||||
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
|
||||
tflite::ops::custom::AddPerceptionOps(buildinop_resolver_);
|
||||
}
|
||||
|
||||
switch (delegate_type) {
|
||||
|
Loading…
Reference in New Issue
Block a user