Add DenseImageWarp custom op to TFLite
PiperOrigin-RevId: 351517211 Change-Id: Ibeb1e639017b3d0a35b98ba761e98021a8bb8958
This commit is contained in:
parent
670cc3fa48
commit
0acc4e3260
@ -14,6 +14,7 @@ package(
|
||||
cc_library(
|
||||
name = "perception_ops",
|
||||
srcs = [
|
||||
"dense_image_warp.cc",
|
||||
"max_pool_with_argmax.cc",
|
||||
"max_unpooling_2d.cc",
|
||||
"perception_ops.cc",
|
||||
@ -40,6 +41,7 @@ cc_test(
|
||||
name = "perception_ops_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"dense_image_warp_test.cc",
|
||||
"max_pool_with_argmax_test.cc",
|
||||
"max_unpooling_2d_test.cc",
|
||||
],
|
||||
|
149
tensorflow/lite/kernels/perception/dense_image_warp.cc
Normal file
149
tensorflow/lite/kernels/perception/dense_image_warp.cc
Normal file
@ -0,0 +1,149 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace dense_image_warp {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kFlowTensor = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
inline void DenseImageWarp(const RuntimeShape& input_shape,
|
||||
const float* input_data,
|
||||
const RuntimeShape& flow_shape,
|
||||
const float* flow_data, float* output_data) {
|
||||
const int batches = MatchingDim(input_shape, 0, flow_shape, 0);
|
||||
const int height = MatchingDim(input_shape, 1, flow_shape, 1);
|
||||
const int width = MatchingDim(input_shape, 2, flow_shape, 2);
|
||||
const int channels = input_shape.Dims(3);
|
||||
TFLITE_CHECK_EQ(flow_shape.Dims(3), 2);
|
||||
|
||||
// Max values to make sure the indexes are not out of bound.
|
||||
const int max_floor_y = height - 2;
|
||||
const int max_floor_x = width - 2;
|
||||
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
for (int in_y = 0; in_y < height; ++in_y) {
|
||||
for (int in_x = 0; in_x < width; ++in_x) {
|
||||
float querry_point_y =
|
||||
in_y - flow_data[Offset(flow_shape, batch, in_y, in_x, 0)];
|
||||
float querry_point_x =
|
||||
in_x - flow_data[Offset(flow_shape, batch, in_y, in_x, 1)];
|
||||
|
||||
int floor_y =
|
||||
std::min(std::max(0, static_cast<int>(std::floor(querry_point_y))),
|
||||
max_floor_y);
|
||||
int floor_x =
|
||||
std::min(std::max(0, static_cast<int>(std::floor(querry_point_x))),
|
||||
max_floor_x);
|
||||
float alpha_y =
|
||||
std::min(std::max(0.0f, querry_point_y - floor_y), 1.0f);
|
||||
float alpha_x =
|
||||
std::min(std::max(0.0f, querry_point_x - floor_x), 1.0f);
|
||||
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
float top_left =
|
||||
input_data[Offset(input_shape, batch, floor_y, floor_x, c)];
|
||||
float top_right =
|
||||
input_data[Offset(input_shape, batch, floor_y, floor_x + 1, c)];
|
||||
float bottom_left =
|
||||
input_data[Offset(input_shape, batch, floor_y + 1, floor_x, c)];
|
||||
float bottom_right = input_data[Offset(input_shape, batch,
|
||||
floor_y + 1, floor_x + 1, c)];
|
||||
|
||||
float interp_top = alpha_x * (top_right - top_left) + top_left;
|
||||
float interp_bottom =
|
||||
alpha_x * (bottom_right - bottom_left) + bottom_left;
|
||||
float interp = alpha_y * (interp_bottom - interp_top) + interp_top;
|
||||
output_data[Offset(input_shape, batch, in_y, in_x, c)] = interp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Check inputs and output.
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
|
||||
TF_LITE_ENSURE(context, flow != nullptr);
|
||||
|
||||
// Check types.
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, flow->type, kTfLiteFloat32);
|
||||
|
||||
// Check shapes. If input has shape of [b, h, w, c], flow must have shape of
|
||||
// [b, h, w, 2].
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(flow), 4);
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||
const RuntimeShape input_shape = GetTensorShape(input);
|
||||
const RuntimeShape flow_shape = GetTensorShape(flow);
|
||||
TF_LITE_ENSURE_EQ(context, input_shape.Dims(0), flow_shape.Dims(0));
|
||||
TF_LITE_ENSURE_EQ(context, input_shape.Dims(1), flow_shape.Dims(1));
|
||||
TF_LITE_ENSURE_EQ(context, input_shape.Dims(2), flow_shape.Dims(2));
|
||||
TF_LITE_ENSURE_MSG(context, input_shape.Dims(1) >= 2,
|
||||
"Image height must be at least 2.");
|
||||
TF_LITE_ENSURE_MSG(context, input_shape.Dims(2) >= 2,
|
||||
"Image width must be at least 2.");
|
||||
TF_LITE_ENSURE_MSG(context, flow_shape.Dims(3) == 2,
|
||||
"The last dimension of flow tensor must be 2.");
|
||||
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
|
||||
return context->ResizeTensor(context, output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TF_LITE_ENSURE(context, output != nullptr);
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TF_LITE_ENSURE(context, input != nullptr);
|
||||
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
|
||||
TF_LITE_ENSURE(context, flow != nullptr);
|
||||
|
||||
DenseImageWarp(GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(flow), GetTensorData<float>(flow),
|
||||
GetTensorData<float>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace dense_image_warp
|
||||
|
||||
TfLiteRegistration* RegisterDenseImageWarp() {
|
||||
static TfLiteRegistration reg = {/*init=*/nullptr,
|
||||
/*free=*/nullptr, dense_image_warp::Prepare,
|
||||
dense_image_warp::Eval};
|
||||
return ®
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
138
tensorflow/lite/kernels/perception/dense_image_warp_test.cc
Normal file
138
tensorflow/lite/kernels/perception/dense_image_warp_test.cc
Normal file
@ -0,0 +1,138 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/perception/perception_ops.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
namespace {
|
||||
|
||||
using testing::ElementsAreArray;
|
||||
|
||||
class DenseImageWarpOpModel : public SingleOpModel {
|
||||
public:
|
||||
DenseImageWarpOpModel(const TensorData& input, const TensorData& flow,
|
||||
const TensorData& output) {
|
||||
input_ = AddInput(input);
|
||||
flow_ = AddInput(flow);
|
||||
output_ = AddOutput(output);
|
||||
|
||||
std::vector<uint8_t> custom_option;
|
||||
SetCustomOp("DenseImageWarp", custom_option, RegisterDenseImageWarp);
|
||||
BuildInterpreter({GetShape(input_), GetShape(flow_)});
|
||||
}
|
||||
|
||||
void SetInput(const std::vector<float>& data) {
|
||||
PopulateTensor(input_, data);
|
||||
}
|
||||
void SetFlow(const std::vector<float>& data) { PopulateTensor(flow_, data); }
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int flow_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(DenseImageWarpOpTest, MismatchedSizeTest) {
|
||||
EXPECT_DEATH_IF_SUPPORTED(
|
||||
DenseImageWarpOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
/*flow=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
|
||||
/*output=*/{TensorType_FLOAT32, {}});
|
||||
, "input_shape.Dims.2. != flow_shape.Dims.2. .4 != 2.");
|
||||
}
|
||||
|
||||
TEST(DenseImageWarpOpTest, WrongFlowSizeTest) {
|
||||
EXPECT_DEATH_IF_SUPPORTED(DenseImageWarpOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
/*output=*/{TensorType_FLOAT32, {}});
|
||||
, "The last dimension of flow tensor must be 2.");
|
||||
}
|
||||
|
||||
TEST(DenseImageWarpOpTest, SimpleTest) {
|
||||
DenseImageWarpOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
/*output=*/{TensorType_FLOAT32, {}});
|
||||
model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
|
||||
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 0, 3, 3, 0, 3, 2, 0,
|
||||
0, 3, 12, 15, 12, 0}));
|
||||
}
|
||||
|
||||
TEST(DenseImageWarpOpTest, RoundTest) {
|
||||
DenseImageWarpOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||
/*output=*/{TensorType_FLOAT32, {}});
|
||||
model.SetInput({0.2, 1.5, 2.4, 3.5, 4.6, 5.1, 6.3, 7.2, 8.5, 9.6, 10.9, 11.6,
|
||||
12.8, 13.2, 14.4, 15.5});
|
||||
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({0.2, 0.2, 0.2, 0.2, 3.5, 3.5, 0.2, 3.5, 2.4,
|
||||
0.2, 0.2, 3.5, 12.8, 15.5, 12.8, 0.2}));
|
||||
}
|
||||
|
||||
TEST(DenseImageWarpOpTest, WithBatchandChannelTest) {
|
||||
DenseImageWarpOpModel model(
|
||||
/*input=*/{TensorType_FLOAT32, {2, 4, 4, 3}},
|
||||
/*flow=*/{TensorType_FLOAT32, {2, 4, 4, 2}},
|
||||
/*output=*/{TensorType_FLOAT32, {}});
|
||||
|
||||
std::vector<float> input_data;
|
||||
for (int i = 0; i < 96; ++i) input_data.push_back(i);
|
||||
model.SetInput(input_data);
|
||||
model.SetFlow({2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
|
||||
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
|
||||
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 4, 3}));
|
||||
EXPECT_THAT(
|
||||
model.GetOutput(),
|
||||
ElementsAreArray({6, 7, 8, 0, 1, 2, 0, 1, 2, 9, 10, 11, 36, 37,
|
||||
38, 45, 46, 47, 36, 37, 38, 0, 1, 2, 0, 1, 2, 0,
|
||||
1, 2, 0, 1, 2, 0, 1, 2, 9, 10, 11, 21, 22, 23,
|
||||
0, 1, 2, 9, 10, 11, 54, 55, 56, 48, 49, 50, 48, 49,
|
||||
50, 57, 58, 59, 84, 85, 86, 93, 94, 95, 84, 85, 86, 48,
|
||||
49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50,
|
||||
57, 58, 59, 69, 70, 71, 48, 49, 50, 57, 58, 59}));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
@ -20,10 +20,12 @@ namespace ops {
|
||||
namespace custom {
|
||||
|
||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
|
||||
resolver->AddCustom("MaxUnpooling2D",
|
||||
tflite::ops::custom::RegisterMaxUnpooling2D());
|
||||
resolver->AddCustom("DenseImageWarp",
|
||||
tflite::ops::custom::RegisterDenseImageWarp());
|
||||
resolver->AddCustom("MaxPoolWithArgmax",
|
||||
tflite::ops::custom::RegisterMaxPoolWithArgmax());
|
||||
resolver->AddCustom("MaxUnpooling2D",
|
||||
tflite::ops::custom::RegisterMaxUnpooling2D());
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
|
@ -22,8 +22,9 @@ namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* RegisterMaxUnpooling2D();
|
||||
TfLiteRegistration* RegisterDenseImageWarp();
|
||||
TfLiteRegistration* RegisterMaxPoolWithArgmax();
|
||||
TfLiteRegistration* RegisterMaxUnpooling2D();
|
||||
|
||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user