Add DenseImageWarp custom op to TFLite

PiperOrigin-RevId: 351517211
Change-Id: Ibeb1e639017b3d0a35b98ba761e98021a8bb8958
Thai Nguyen 2021-01-12 22:36:31 -08:00 committed by TensorFlower Gardener
parent 670cc3fa48
commit 0acc4e3260
5 changed files with 295 additions and 3 deletions
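
The new op warps a float NHWC image with a per-pixel flow field: for every output location (y, x), the kernel reads the flow vector stored at that location, samples the input at the query point (y - flow_y, x - flow_x) with bilinear interpolation, and clamps the sampling window to the image borders. A minimal standalone sketch of that per-pixel rule follows; the function and variable names are illustrative only and not part of this change, which operates on whole tensors in dense_image_warp.cc below.

// Illustrative sketch only (not part of the commit): bilinear warp of a single
// pixel of a one-channel height x width image stored row-major.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float WarpPixel(const std::vector<float>& image, int height, int width, int y,
                int x, float flow_y, float flow_x) {
  // The output pixel pulls its value from the flow-displaced query point.
  const float query_y = y - flow_y;
  const float query_x = x - flow_x;
  // Clamp the top-left corner of the 2x2 sampling window into the image.
  const int floor_y = std::min(
      std::max(0, static_cast<int>(std::floor(query_y))), height - 2);
  const int floor_x = std::min(
      std::max(0, static_cast<int>(std::floor(query_x))), width - 2);
  // Interpolation weights, clamped to [0, 1].
  const float alpha_y = std::min(std::max(0.0f, query_y - floor_y), 1.0f);
  const float alpha_x = std::min(std::max(0.0f, query_x - floor_x), 1.0f);
  const float top_left = image[floor_y * width + floor_x];
  const float top_right = image[floor_y * width + floor_x + 1];
  const float bottom_left = image[(floor_y + 1) * width + floor_x];
  const float bottom_right = image[(floor_y + 1) * width + floor_x + 1];
  const float top = alpha_x * (top_right - top_left) + top_left;
  const float bottom = alpha_x * (bottom_right - bottom_left) + bottom_left;
  return alpha_y * (bottom - top) + top;
}

int main() {
  // 4x4 ramp image; at pixel (2, 1) a flow of (0.5, -0.5) samples the point
  // (1.5, 1.5), blending the neighbors 5, 6, 9 and 10 into 7.5.
  std::vector<float> image(16);
  for (int i = 0; i < 16; ++i) image[i] = static_cast<float>(i);
  std::printf("%.1f\n", WarpPixel(image, 4, 4, 2, 1, 0.5f, -0.5f));
  return 0;
}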

tensorflow/lite/kernels/perception/BUILD

@@ -14,6 +14,7 @@ package(
cc_library(
name = "perception_ops",
srcs = [
"dense_image_warp.cc",
"max_pool_with_argmax.cc",
"max_unpooling_2d.cc",
"perception_ops.cc",
@@ -40,6 +41,7 @@ cc_test(
name = "perception_ops_test",
size = "small",
srcs = [
"dense_image_warp_test.cc",
"max_pool_with_argmax_test.cc",
"max_unpooling_2d_test.cc",
],

tensorflow/lite/kernels/perception/dense_image_warp.cc

@@ -0,0 +1,149 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace custom {
namespace dense_image_warp {
constexpr int kInputTensor = 0;
constexpr int kFlowTensor = 1;
constexpr int kOutputTensor = 0;
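// Bilinear dense image warp: output(b, y, x, c) samples the input at the
// query point (y - flow_y, x - flow_x) given by the flow tensor at (b, y, x),
// clamping the 2x2 sampling window so every read stays inside the image.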
inline void DenseImageWarp(const RuntimeShape& input_shape,
const float* input_data,
const RuntimeShape& flow_shape,
const float* flow_data, float* output_data) {
const int batches = MatchingDim(input_shape, 0, flow_shape, 0);
const int height = MatchingDim(input_shape, 1, flow_shape, 1);
const int width = MatchingDim(input_shape, 2, flow_shape, 2);
const int channels = input_shape.Dims(3);
TFLITE_CHECK_EQ(flow_shape.Dims(3), 2);
// Max values to make sure the indices are not out of bounds.
const int max_floor_y = height - 2;
const int max_floor_x = width - 2;
for (int batch = 0; batch < batches; ++batch) {
for (int in_y = 0; in_y < height; ++in_y) {
for (int in_x = 0; in_x < width; ++in_x) {
float query_point_y =
in_y - flow_data[Offset(flow_shape, batch, in_y, in_x, 0)];
float query_point_x =
in_x - flow_data[Offset(flow_shape, batch, in_y, in_x, 1)];
int floor_y =
std::min(std::max(0, static_cast<int>(std::floor(query_point_y))),
max_floor_y);
int floor_x =
std::min(std::max(0, static_cast<int>(std::floor(query_point_x))),
max_floor_x);
float alpha_y =
std::min(std::max(0.0f, query_point_y - floor_y), 1.0f);
float alpha_x =
std::min(std::max(0.0f, query_point_x - floor_x), 1.0f);
for (int c = 0; c < channels; ++c) {
float top_left =
input_data[Offset(input_shape, batch, floor_y, floor_x, c)];
float top_right =
input_data[Offset(input_shape, batch, floor_y, floor_x + 1, c)];
float bottom_left =
input_data[Offset(input_shape, batch, floor_y + 1, floor_x, c)];
float bottom_right = input_data[Offset(input_shape, batch,
floor_y + 1, floor_x + 1, c)];
float interp_top = alpha_x * (top_right - top_left) + top_left;
float interp_bottom =
alpha_x * (bottom_right - bottom_left) + bottom_left;
float interp = alpha_y * (interp_bottom - interp_top) + interp_top;
output_data[Offset(input_shape, batch, in_y, in_x, c)] = interp;
}
}
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check inputs and output.
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
TF_LITE_ENSURE(context, flow != nullptr);
// Check types.
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, flow->type, kTfLiteFloat32);
// Check shapes. If input has shape of [b, h, w, c], flow must have shape of
// [b, h, w, 2].
TF_LITE_ENSURE_EQ(context, NumDimensions(flow), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
const RuntimeShape input_shape = GetTensorShape(input);
const RuntimeShape flow_shape = GetTensorShape(flow);
TF_LITE_ENSURE_EQ(context, input_shape.Dims(0), flow_shape.Dims(0));
TF_LITE_ENSURE_EQ(context, input_shape.Dims(1), flow_shape.Dims(1));
TF_LITE_ENSURE_EQ(context, input_shape.Dims(2), flow_shape.Dims(2));
TF_LITE_ENSURE_MSG(context, input_shape.Dims(1) >= 2,
"Image height must be at least 2.");
TF_LITE_ENSURE_MSG(context, input_shape.Dims(2) >= 2,
"Image width must be at least 2.");
TF_LITE_ENSURE_MSG(context, flow_shape.Dims(3) == 2,
"The last dimension of flow tensor must be 2.");
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
TF_LITE_ENSURE(context, flow != nullptr);
DenseImageWarp(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(flow), GetTensorData<float>(flow),
GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace dense_image_warp
TfLiteRegistration* RegisterDenseImageWarp() {
static TfLiteRegistration reg = {/*init=*/nullptr,
/*free=*/nullptr, dense_image_warp::Prepare,
dense_image_warp::Eval};
return &reg;
}
} // namespace custom
} // namespace ops
} // namespace tflite

tensorflow/lite/kernels/perception/dense_image_warp_test.cc

@@ -0,0 +1,138 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace {
using testing::ElementsAreArray;
class DenseImageWarpOpModel : public SingleOpModel {
public:
DenseImageWarpOpModel(const TensorData& input, const TensorData& flow,
const TensorData& output) {
input_ = AddInput(input);
flow_ = AddInput(flow);
output_ = AddOutput(output);
std::vector<uint8_t> custom_option;
SetCustomOp("DenseImageWarp", custom_option, RegisterDenseImageWarp);
BuildInterpreter({GetShape(input_), GetShape(flow_)});
}
void SetInput(const std::vector<float>& data) {
PopulateTensor(input_, data);
}
void SetFlow(const std::vector<float>& data) { PopulateTensor(flow_, data); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
int input_;
int flow_;
int output_;
};
TEST(DenseImageWarpOpTest, MismatchedSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(
DenseImageWarpOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
/*flow=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
/*output=*/{TensorType_FLOAT32, {}});
, "input_shape.Dims.2. != flow_shape.Dims.2. .4 != 2.");
}
TEST(DenseImageWarpOpTest, WrongFlowSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(DenseImageWarpOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
/*output=*/{TensorType_FLOAT32, {}});
, "The last dimension of flow tensor must be 2.");
}
TEST(DenseImageWarpOpTest, SimpleTest) {
DenseImageWarpOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
/*output=*/{TensorType_FLOAT32, {}});
model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 0, 3, 3, 0, 3, 2, 0,
0, 3, 12, 15, 12, 0}));
}
TEST(DenseImageWarpOpTest, RoundTest) {
DenseImageWarpOpModel model(
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
/*output=*/{TensorType_FLOAT32, {}});
model.SetInput({0.2, 1.5, 2.4, 3.5, 4.6, 5.1, 6.3, 7.2, 8.5, 9.6, 10.9, 11.6,
12.8, 13.2, 14.4, 15.5});
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({0.2, 0.2, 0.2, 0.2, 3.5, 3.5, 0.2, 3.5, 2.4,
0.2, 0.2, 3.5, 12.8, 15.5, 12.8, 0.2}));
}
TEST(DenseImageWarpOpTest, WithBatchAndChannelTest) {
DenseImageWarpOpModel model(
/*input=*/{TensorType_FLOAT32, {2, 4, 4, 3}},
/*flow=*/{TensorType_FLOAT32, {2, 4, 4, 2}},
/*output=*/{TensorType_FLOAT32, {}});
std::vector<float> input_data;
for (int i = 0; i < 96; ++i) input_data.push_back(i);
model.SetInput(input_data);
model.SetFlow({2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 4, 3}));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({6, 7, 8, 0, 1, 2, 0, 1, 2, 9, 10, 11, 36, 37,
38, 45, 46, 47, 36, 37, 38, 0, 1, 2, 0, 1, 2, 0,
1, 2, 0, 1, 2, 0, 1, 2, 9, 10, 11, 21, 22, 23,
0, 1, 2, 9, 10, 11, 54, 55, 56, 48, 49, 50, 48, 49,
50, 57, 58, 59, 84, 85, 86, 93, 94, 95, 84, 85, 86, 48,
49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50,
57, 58, 59, 69, 70, 71, 48, 49, 50, 57, 58, 59}));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
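
As a cross-check on the expected values, the kernel's arithmetic for one pixel of SimpleTest can be traced by hand. For output location (batch 0, y 1, x 0) the flow entry is (10, -4), so the query point is (1 - 10, 0 - (-4)) = (-9, 4). With a 4x4 input, max_floor_y = max_floor_x = 2, so the floor coordinates clamp to (0, 2) and the weights clamp to alpha_y = 0 and alpha_x = 1; the bilinear blend then collapses to the top-right sample input[0][3] = 3, which is the fifth value in the expected output above.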

tensorflow/lite/kernels/perception/perception_ops.cc

@@ -20,10 +20,12 @@ namespace ops {
 namespace custom {
 extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
-  resolver->AddCustom("MaxUnpooling2D",
-                      tflite::ops::custom::RegisterMaxUnpooling2D());
+  resolver->AddCustom("DenseImageWarp",
+                      tflite::ops::custom::RegisterDenseImageWarp());
   resolver->AddCustom("MaxPoolWithArgmax",
                       tflite::ops::custom::RegisterMaxPoolWithArgmax());
+  resolver->AddCustom("MaxUnpooling2D",
+                      tflite::ops::custom::RegisterMaxUnpooling2D());
 }
 } // namespace custom

tensorflow/lite/kernels/perception/perception_ops.h

@@ -22,8 +22,9 @@ namespace tflite {
 namespace ops {
 namespace custom {
-TfLiteRegistration* RegisterMaxUnpooling2D();
+TfLiteRegistration* RegisterDenseImageWarp();
 TfLiteRegistration* RegisterMaxPoolWithArgmax();
+TfLiteRegistration* RegisterMaxUnpooling2D();
 extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
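
With the resolver hook above, clients pick up the new op the same way as the other perception ops. Below is a minimal sketch assuming a model converted with DenseImageWarp kept as a custom op; the helper name and the choice of the builtin resolver as a base are illustrative, not part of this change.

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/perception/perception_ops.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Builds an interpreter whose op resolver knows the perception custom ops,
// including the DenseImageWarp registration added in this change.
std::unique_ptr<tflite::Interpreter> BuildPerceptionInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::ops::custom::AddPerceptionOps(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}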