Add DenseImageWarp custom op to TFLite
PiperOrigin-RevId: 351517211 Change-Id: Ibeb1e639017b3d0a35b98ba761e98021a8bb8958
This commit is contained in:
parent
670cc3fa48
commit
0acc4e3260
@ -14,6 +14,7 @@ package(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "perception_ops",
|
name = "perception_ops",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"dense_image_warp.cc",
|
||||||
"max_pool_with_argmax.cc",
|
"max_pool_with_argmax.cc",
|
||||||
"max_unpooling_2d.cc",
|
"max_unpooling_2d.cc",
|
||||||
"perception_ops.cc",
|
"perception_ops.cc",
|
||||||
@ -40,6 +41,7 @@ cc_test(
|
|||||||
name = "perception_ops_test",
|
name = "perception_ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"dense_image_warp_test.cc",
|
||||||
"max_pool_with_argmax_test.cc",
|
"max_pool_with_argmax_test.cc",
|
||||||
"max_unpooling_2d_test.cc",
|
"max_unpooling_2d_test.cc",
|
||||||
],
|
],
|
||||||
|
149
tensorflow/lite/kernels/perception/dense_image_warp.cc
Normal file
149
tensorflow/lite/kernels/perception/dense_image_warp.cc
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/common.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/kernels/padding.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace custom {
|
||||||
|
namespace dense_image_warp {
|
||||||
|
|
||||||
|
constexpr int kInputTensor = 0;
|
||||||
|
constexpr int kFlowTensor = 1;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
inline void DenseImageWarp(const RuntimeShape& input_shape,
|
||||||
|
const float* input_data,
|
||||||
|
const RuntimeShape& flow_shape,
|
||||||
|
const float* flow_data, float* output_data) {
|
||||||
|
const int batches = MatchingDim(input_shape, 0, flow_shape, 0);
|
||||||
|
const int height = MatchingDim(input_shape, 1, flow_shape, 1);
|
||||||
|
const int width = MatchingDim(input_shape, 2, flow_shape, 2);
|
||||||
|
const int channels = input_shape.Dims(3);
|
||||||
|
TFLITE_CHECK_EQ(flow_shape.Dims(3), 2);
|
||||||
|
|
||||||
|
// Max values to make sure the indexes are not out of bound.
|
||||||
|
const int max_floor_y = height - 2;
|
||||||
|
const int max_floor_x = width - 2;
|
||||||
|
|
||||||
|
for (int batch = 0; batch < batches; ++batch) {
|
||||||
|
for (int in_y = 0; in_y < height; ++in_y) {
|
||||||
|
for (int in_x = 0; in_x < width; ++in_x) {
|
||||||
|
float querry_point_y =
|
||||||
|
in_y - flow_data[Offset(flow_shape, batch, in_y, in_x, 0)];
|
||||||
|
float querry_point_x =
|
||||||
|
in_x - flow_data[Offset(flow_shape, batch, in_y, in_x, 1)];
|
||||||
|
|
||||||
|
int floor_y =
|
||||||
|
std::min(std::max(0, static_cast<int>(std::floor(querry_point_y))),
|
||||||
|
max_floor_y);
|
||||||
|
int floor_x =
|
||||||
|
std::min(std::max(0, static_cast<int>(std::floor(querry_point_x))),
|
||||||
|
max_floor_x);
|
||||||
|
float alpha_y =
|
||||||
|
std::min(std::max(0.0f, querry_point_y - floor_y), 1.0f);
|
||||||
|
float alpha_x =
|
||||||
|
std::min(std::max(0.0f, querry_point_x - floor_x), 1.0f);
|
||||||
|
|
||||||
|
for (int c = 0; c < channels; ++c) {
|
||||||
|
float top_left =
|
||||||
|
input_data[Offset(input_shape, batch, floor_y, floor_x, c)];
|
||||||
|
float top_right =
|
||||||
|
input_data[Offset(input_shape, batch, floor_y, floor_x + 1, c)];
|
||||||
|
float bottom_left =
|
||||||
|
input_data[Offset(input_shape, batch, floor_y + 1, floor_x, c)];
|
||||||
|
float bottom_right = input_data[Offset(input_shape, batch,
|
||||||
|
floor_y + 1, floor_x + 1, c)];
|
||||||
|
|
||||||
|
float interp_top = alpha_x * (top_right - top_left) + top_left;
|
||||||
|
float interp_bottom =
|
||||||
|
alpha_x * (bottom_right - bottom_left) + bottom_left;
|
||||||
|
float interp = alpha_y * (interp_bottom - interp_top) + interp_top;
|
||||||
|
output_data[Offset(input_shape, batch, in_y, in_x, c)] = interp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
// Check inputs and output.
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
|
||||||
|
TF_LITE_ENSURE(context, flow != nullptr);
|
||||||
|
|
||||||
|
// Check types.
|
||||||
|
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
|
||||||
|
TF_LITE_ENSURE_TYPES_EQ(context, flow->type, kTfLiteFloat32);
|
||||||
|
|
||||||
|
// Check shapes. If input has shape of [b, h, w, c], flow must have shape of
|
||||||
|
// [b, h, w, 2].
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumDimensions(flow), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||||
|
const RuntimeShape input_shape = GetTensorShape(input);
|
||||||
|
const RuntimeShape flow_shape = GetTensorShape(flow);
|
||||||
|
TF_LITE_ENSURE_EQ(context, input_shape.Dims(0), flow_shape.Dims(0));
|
||||||
|
TF_LITE_ENSURE_EQ(context, input_shape.Dims(1), flow_shape.Dims(1));
|
||||||
|
TF_LITE_ENSURE_EQ(context, input_shape.Dims(2), flow_shape.Dims(2));
|
||||||
|
TF_LITE_ENSURE_MSG(context, input_shape.Dims(1) >= 2,
|
||||||
|
"Image height must be at least 2.");
|
||||||
|
TF_LITE_ENSURE_MSG(context, input_shape.Dims(2) >= 2,
|
||||||
|
"Image width must be at least 2.");
|
||||||
|
TF_LITE_ENSURE_MSG(context, flow_shape.Dims(3) == 2,
|
||||||
|
"The last dimension of flow tensor must be 2.");
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
const TfLiteTensor* flow = GetInput(context, node, kFlowTensor);
|
||||||
|
TF_LITE_ENSURE(context, flow != nullptr);
|
||||||
|
|
||||||
|
DenseImageWarp(GetTensorShape(input), GetTensorData<float>(input),
|
||||||
|
GetTensorShape(flow), GetTensorData<float>(flow),
|
||||||
|
GetTensorData<float>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dense_image_warp
|
||||||
|
|
||||||
|
TfLiteRegistration* RegisterDenseImageWarp() {
|
||||||
|
static TfLiteRegistration reg = {/*init=*/nullptr,
|
||||||
|
/*free=*/nullptr, dense_image_warp::Prepare,
|
||||||
|
dense_image_warp::Eval};
|
||||||
|
return ®
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace custom
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
138
tensorflow/lite/kernels/perception/dense_image_warp_test.cc
Normal file
138
tensorflow/lite/kernels/perception/dense_image_warp_test.cc
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/perception/perception_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/testing/util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace custom {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using testing::ElementsAreArray;
|
||||||
|
|
||||||
|
class DenseImageWarpOpModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
DenseImageWarpOpModel(const TensorData& input, const TensorData& flow,
|
||||||
|
const TensorData& output) {
|
||||||
|
input_ = AddInput(input);
|
||||||
|
flow_ = AddInput(flow);
|
||||||
|
output_ = AddOutput(output);
|
||||||
|
|
||||||
|
std::vector<uint8_t> custom_option;
|
||||||
|
SetCustomOp("DenseImageWarp", custom_option, RegisterDenseImageWarp);
|
||||||
|
BuildInterpreter({GetShape(input_), GetShape(flow_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInput(const std::vector<float>& data) {
|
||||||
|
PopulateTensor(input_, data);
|
||||||
|
}
|
||||||
|
void SetFlow(const std::vector<float>& data) { PopulateTensor(flow_, data); }
|
||||||
|
|
||||||
|
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||||
|
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int input_;
|
||||||
|
int flow_;
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(DenseImageWarpOpTest, MismatchedSizeTest) {
|
||||||
|
EXPECT_DEATH_IF_SUPPORTED(
|
||||||
|
DenseImageWarpOpModel model(
|
||||||
|
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||||
|
/*flow=*/{TensorType_FLOAT32, {1, 4, 2, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32, {}});
|
||||||
|
, "input_shape.Dims.2. != flow_shape.Dims.2. .4 != 2.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DenseImageWarpOpTest, WrongFlowSizeTest) {
|
||||||
|
EXPECT_DEATH_IF_SUPPORTED(DenseImageWarpOpModel model(
|
||||||
|
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||||
|
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||||
|
/*output=*/{TensorType_FLOAT32, {}});
|
||||||
|
, "The last dimension of flow tensor must be 2.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DenseImageWarpOpTest, SimpleTest) {
|
||||||
|
DenseImageWarpOpModel model(
|
||||||
|
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||||
|
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32, {}});
|
||||||
|
model.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
|
||||||
|
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||||
|
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 0, 3, 3, 0, 3, 2, 0,
|
||||||
|
0, 3, 12, 15, 12, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DenseImageWarpOpTest, RoundTest) {
|
||||||
|
DenseImageWarpOpModel model(
|
||||||
|
/*input=*/{TensorType_FLOAT32, {1, 4, 4, 1}},
|
||||||
|
/*flow=*/{TensorType_FLOAT32, {1, 4, 4, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32, {}});
|
||||||
|
model.SetInput({0.2, 1.5, 2.4, 3.5, 4.6, 5.1, 6.3, 7.2, 8.5, 9.6, 10.9, 11.6,
|
||||||
|
12.8, 13.2, 14.4, 15.5});
|
||||||
|
model.SetFlow({4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||||
|
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({0.2, 0.2, 0.2, 0.2, 3.5, 3.5, 0.2, 3.5, 2.4,
|
||||||
|
0.2, 0.2, 3.5, 12.8, 15.5, 12.8, 0.2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DenseImageWarpOpTest, WithBatchandChannelTest) {
|
||||||
|
DenseImageWarpOpModel model(
|
||||||
|
/*input=*/{TensorType_FLOAT32, {2, 4, 4, 3}},
|
||||||
|
/*flow=*/{TensorType_FLOAT32, {2, 4, 4, 2}},
|
||||||
|
/*output=*/{TensorType_FLOAT32, {}});
|
||||||
|
|
||||||
|
std::vector<float> input_data;
|
||||||
|
for (int i = 0; i < 96; ++i) input_data.push_back(i);
|
||||||
|
model.SetInput(input_data);
|
||||||
|
model.SetFlow({2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
|
||||||
|
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0,
|
||||||
|
2, -2, 10, 6, 4, 4, 2, -4, -4, 10, -4, -4, -2, 6, 4, 6,
|
||||||
|
4, 10, 6, 10, 4, 2, 6, 6, 10, -4, 2, -2, 6, 8, 6, 0});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4, 4, 3}));
|
||||||
|
EXPECT_THAT(
|
||||||
|
model.GetOutput(),
|
||||||
|
ElementsAreArray({6, 7, 8, 0, 1, 2, 0, 1, 2, 9, 10, 11, 36, 37,
|
||||||
|
38, 45, 46, 47, 36, 37, 38, 0, 1, 2, 0, 1, 2, 0,
|
||||||
|
1, 2, 0, 1, 2, 0, 1, 2, 9, 10, 11, 21, 22, 23,
|
||||||
|
0, 1, 2, 9, 10, 11, 54, 55, 56, 48, 49, 50, 48, 49,
|
||||||
|
50, 57, 58, 59, 84, 85, 86, 93, 94, 95, 84, 85, 86, 48,
|
||||||
|
49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50, 48, 49, 50,
|
||||||
|
57, 58, 59, 69, 70, 71, 48, 49, 50, 57, 58, 59}));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace custom
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
@ -20,10 +20,12 @@ namespace ops {
|
|||||||
namespace custom {
|
namespace custom {
|
||||||
|
|
||||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
|
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver) {
|
||||||
resolver->AddCustom("MaxUnpooling2D",
|
resolver->AddCustom("DenseImageWarp",
|
||||||
tflite::ops::custom::RegisterMaxUnpooling2D());
|
tflite::ops::custom::RegisterDenseImageWarp());
|
||||||
resolver->AddCustom("MaxPoolWithArgmax",
|
resolver->AddCustom("MaxPoolWithArgmax",
|
||||||
tflite::ops::custom::RegisterMaxPoolWithArgmax());
|
tflite::ops::custom::RegisterMaxPoolWithArgmax());
|
||||||
|
resolver->AddCustom("MaxUnpooling2D",
|
||||||
|
tflite::ops::custom::RegisterMaxUnpooling2D());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace custom
|
} // namespace custom
|
||||||
|
@ -22,8 +22,9 @@ namespace tflite {
|
|||||||
namespace ops {
|
namespace ops {
|
||||||
namespace custom {
|
namespace custom {
|
||||||
|
|
||||||
TfLiteRegistration* RegisterMaxUnpooling2D();
|
TfLiteRegistration* RegisterDenseImageWarp();
|
||||||
TfLiteRegistration* RegisterMaxPoolWithArgmax();
|
TfLiteRegistration* RegisterMaxPoolWithArgmax();
|
||||||
|
TfLiteRegistration* RegisterMaxUnpooling2D();
|
||||||
|
|
||||||
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
|
extern "C" void AddPerceptionOps(::tflite::MutableOpResolver* resolver);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user