Add Conv3D reference kernel to TFLite

This kernel currently only supports float type and has filter with the same data format as TensorFlow Conv3D op.
The conversion change will be added in a follow-up cl.

PiperOrigin-RevId: 351726918
Change-Id: Id2cc805cb89a1da1fda656a2260516eda8bd9119
This commit is contained in:
Thai Nguyen 2021-01-13 21:31:40 -08:00 committed by TensorFlower Gardener
parent 977921340c
commit ac8f6785d3
17 changed files with 1416 additions and 537 deletions

View File

@ -159,6 +159,7 @@ typedef enum {
kTfLiteBuiltinCallOnce = 129,
kTfLiteBuiltinBroadcastTo = 130,
kTfLiteBuiltinRfft2d = 131,
kTfLiteBuiltinConv3d = 132,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -67,8 +67,8 @@ typedef struct {
typedef enum {
kTfLiteActNone = 0,
kTfLiteActRelu,
kTfLiteActReluN1To1, // min(max(-1, x), 1)
kTfLiteActRelu6, // min(max(0, x), 6)
kTfLiteActReluN1To1, // min(max(-1, x), 1)
kTfLiteActRelu6, // min(max(0, x), 6)
kTfLiteActTanh,
kTfLiteActSignBit,
kTfLiteActSigmoid,
@ -87,6 +87,17 @@ typedef struct {
int dilation_height_factor;
} TfLiteConvParams;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
int stride_depth;
int dilation_width_factor;
int dilation_height_factor;
int dilation_depth_factor;
TfLiteFusedActivation activation;
} TfLiteConv3DParams;
typedef struct {
TfLitePadding padding;
int stride_width;

View File

@ -770,6 +770,23 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_CONV_3D: {
auto params = safe_allocator.Allocate<TfLiteConv3DParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* conv3d_params = op->builtin_options_as_Conv3DOptions()) {
params->padding = ConvertPadding(conv3d_params->padding());
params->activation =
ConvertActivation(conv3d_params->fused_activation_function());
params->stride_depth = conv3d_params->stride_d();
params->stride_height = conv3d_params->stride_h();
params->stride_width = conv3d_params->stride_w();
params->dilation_depth_factor = conv3d_params->dilation_d_factor();
params->dilation_height_factor = conv3d_params->dilation_h_factor();
params->dilation_width_factor = conv3d_params->dilation_w_factor();
}
*builtin_data = params.release();
return kTfLiteOk;
}
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are

View File

@ -538,6 +538,7 @@ cc_library(
copts = tflite_copts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/internal:types",
],
)
@ -559,6 +560,7 @@ BUILTIN_KERNEL_SRCS = [
"comparisons.cc",
"concatenation.cc",
"conv.cc",
"conv3d.cc",
"cumsum.cc",
"densify.cc",
"depth_to_space.cc",
@ -1079,6 +1081,21 @@ cc_test(
],
)
cc_test(
name = "conv3d_test",
size = "small",
srcs = ["conv3d_test.cc"],
deps = [
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite:string",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "densify_test",
size = "small",

View File

@ -45,6 +45,7 @@ TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_CEIL();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_CONV_3D();
TfLiteRegistration* Register_COS();
TfLiteRegistration* Register_CUMSUM();
TfLiteRegistration* Register_DENSIFY();

View File

@ -0,0 +1,167 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv3d.h"
#include <cstddef>
#include <cstdint>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace conv3d {
// Struct to carry data from Prepare to Eval.
struct OpData {
Padding3DValues padding;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
delete static_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Check number of inputs/outputs.
bool has_bias = node->inputs->size == 3;
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
// Check dimensionality of input, filter.
TF_LITE_ENSURE_EQ(context, input->dims->size, 5);
TF_LITE_ENSURE_EQ(context, filter->dims->size, 5);
// Check input channels matching filter.
TF_LITE_ENSURE_EQ(context, input->dims->data[4], filter->dims->data[3]);
// Check types.
TfLiteType input_type = input->type;
TF_LITE_ENSURE_TYPES_EQ(context, input_type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
// Check bias.
const TfLiteTensor* bias = nullptr;
if (has_bias) {
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
}
// Filter has shape of [filter_depth, filter_height, filter_width,
// in_channels, out_channels].
int batches = input->dims->data[0];
int channels_out = filter->dims->data[4];
int depth = input->dims->data[1];
int height = input->dims->data[2];
int width = input->dims->data[3];
int filter_depth = filter->dims->data[0];
int filter_height = filter->dims->data[1];
int filter_width = filter->dims->data[2];
// Matching GetWindowedOutputSize in TensorFlow.
int out_width, out_height, out_depth;
data->padding = ComputePadding3DValues(
params->stride_height, params->stride_width, params->stride_depth,
params->dilation_height_factor, params->dilation_width_factor,
params->dilation_depth_factor, height, width, depth, filter_height,
filter_width, filter_depth, params->padding, &out_height, &out_width,
&out_depth);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(5);
output_size->data[0] = batches;
output_size->data[1] = out_depth;
output_size->data[2] = out_height;
output_size->data[3] = out_width;
output_size->data[4] = channels_out;
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
bool has_bias = node->inputs->size == 3;
const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
Conv3DParams runtime_params;
runtime_params.padding_values = data->padding;
runtime_params.stride_depth = params->stride_depth;
runtime_params.stride_height = params->stride_height;
runtime_params.stride_width = params->stride_width;
runtime_params.dilation_depth = params->dilation_depth_factor;
runtime_params.dilation_height = params->dilation_height_factor;
runtime_params.dilation_width = params->dilation_width_factor;
runtime_params.float_activation_min = output_activation_min;
runtime_params.float_activation_max = output_activation_max;
switch (input->type) {
case kTfLiteFloat32:
reference_ops::Conv3D(runtime_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace conv3d
TfLiteRegistration* Register_CONV_3D() {
static TfLiteRegistration r = {conv3d::Init, conv3d::Free, conv3d::Prepare,
conv3d::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,256 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
class Conv3dOpModel : public SingleOpModel {
public:
Conv3dOpModel(const TensorData& input, const TensorData& filter,
const TensorData& bias, const TensorData& output,
Padding padding = Padding_VALID, int32_t stride_depth = 1,
int32_t stride_width = 1, int32_t stride_height = 1,
ActivationFunctionType activation = ActivationFunctionType_NONE,
int32_t dilation_depth = 1, int32_t dilation_width = 1,
int32_t dilation_height = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
bias_ = AddInput(bias);
output_ = AddOutput(output);
SetBuiltinOp(
BuiltinOperator_CONV_3D, BuiltinOptions_Conv3DOptions,
CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
stride_height, activation, dilation_depth,
dilation_width, dilation_height)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
Conv3dOpModel(const TensorData& input, const TensorData& filter,
const TensorData& output, Padding padding = Padding_VALID,
int32_t stride_depth = 1, int32_t stride_width = 1,
int32_t stride_height = 1,
ActivationFunctionType activation = ActivationFunctionType_NONE,
int32_t dilation_depth = 1, int32_t dilation_width = 1,
int32_t dilation_height = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
output_ = AddOutput(output);
SetBuiltinOp(
BuiltinOperator_CONV_3D, BuiltinOptions_Conv3DOptions,
CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
stride_height, activation, dilation_depth,
dilation_width, dilation_height)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_)});
}
void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
void SetInput(std::vector<float> data) { PopulateTensor(input_, data); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int filter_;
int bias_;
int output_;
};
template <typename T>
std::vector<T> CreateRangeVector(int N) {
std::vector<T> result;
for (int i = 0; i < N; ++i) result.push_back(i);
return result;
}
TEST(Conv3dOpModel, InvalidInputDimsTest) {
EXPECT_DEATH_IF_SUPPORTED(Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}}),
"input->dims->size != 5");
}
TEST(Conv3dOpModel, InvalidFilterDimsTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}}),
"filter->dims->size != 5");
}
TEST(Conv3dOpModel, MismatchChannelSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 1}},
{TensorType_FLOAT32, {1, 3, 2, 2, 2}},
{TensorType_FLOAT32, {}}),
"input->dims->data.4. != filter->dims->data.3.");
}
TEST(Conv3dOpModel, MismatchBiasSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 2}},
{TensorType_FLOAT32, {1, 3, 2, 2, 1}},
{TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {}}),
"NumElements.bias. != SizeOfDimension.filter, 4.");
}
TEST(Conv3dOpModel, SimpleFloat32Test) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}});
m.SetInput(CreateRangeVector<float>(32));
m.SetFilter({-1, -1, -1, -1, -1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1,
1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, 1, -1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 3, 2));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({30, 6, 26, 10, 22, 14}));
}
TEST(Conv3dOpModel, PaddingValidTest) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 3, 4, 5, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}});
m.SetInput(CreateRangeVector<float>(120));
m.SetFilter({-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1,
1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 3, 4, 2));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({-214, 266, -234, 270, -254, 274, -274, 278, -314, 286,
-334, 290, -354, 294, -374, 298, -414, 306, -434, 310,
-454, 314, -474, 318, -614, 346, -634, 350, -654, 354,
-674, 358, -714, 366, -734, 370, -754, 374, -774, 378,
-814, 386, -834, 390, -854, 394, -874, 398}));
}
TEST(Conv3dOpModel, PaddingSameTest) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 3, 4, 5, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_SAME);
m.SetInput(CreateRangeVector<float>(120));
m.SetFilter({1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1,
-1, 1, -1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 3, 4, 5, 2));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray(
{-172, 290, -176, 298, -180, 306, -184, 314, 36, 198, -192,
330, -196, 338, -200, 346, -204, 354, 56, 218, -212, 370,
-216, 378, -220, 386, -224, 394, 76, 238, -226, 82, -230,
82, -234, 82, -238, 82, -80, 80, -252, 450, -256, 458,
-260, 466, -264, 474, 116, 278, -272, 490, -276, 498, -280,
506, -284, 514, 136, 298, -292, 530, -296, 538, -300, 546,
-304, 554, 156, 318, -306, 82, -310, 82, -314, 82, -318,
82, -80, 80, 158, -158, 162, -162, 166, -166, 170, -170,
176, -176, 178, -178, 182, -182, 186, -186, 190, -190, 196,
-196, 198, -198, 202, -202, 206, -206, 210, -210, 216, -216,
220, -220, 224, -224, 228, -228, 232, -232, 237, -237}));
}
TEST(Conv3dOpModel, StrideTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_VALID, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 2, 2));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({52, 8, 68, 8, 244, 8, 260, 8}));
}
TEST(Conv3dOpModel, StrideAndPaddingSameTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_SAME, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({-1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1,
1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 2, 2, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({-70, -28, -86, -12, -82, -16, -90, -8, -262,
164, -278, 180, -178, 80, -186, 88}));
}
TEST(Conv3dOpModel, DilationTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_VALID, /*stride_depth=*/1,
/*stride_width=*/1, /*stride_height=*/1,
/*activation=*/ActivationFunctionType_NONE,
/*dilation_depth=*/1, /*dilation_width=*/1,
/*dilation_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 3, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({52, 8, 60, 8, 68, 8, 244, 8, 252, 8, 260, 8}));
}
TEST(Conv3dOpModel, BiasTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {}},
Padding_VALID, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.SetBias({1, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 2, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({53, 10, 69, 10, 245, 10, 261, 10}));
}
} // namespace
} // namespace tflite

View File

@ -452,6 +452,7 @@ cc_library(
"reference/comparisons.h",
"reference/concatenation.h",
"reference/conv.h",
"reference/conv3d.h",
"reference/densify.h",
"reference/depth_to_space.h",
"reference/depthwiseconv_float.h",
@ -565,6 +566,7 @@ cc_library(
"reference/comparisons.h",
"reference/concatenation.h",
"reference/conv.h",
"reference/conv3d.h",
"reference/densify.h",
"reference/depth_to_space.h",
"reference/depthwiseconv_float.h",

View File

@ -0,0 +1,114 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_num_channels = MatchingDim(input_shape, 4, filter_shape, 3);
const int output_num_channels = MatchingDim(filter_shape, 4, output_shape, 4);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_num_channels);
}
// Only NDHWC format is currently supported.
const int input_width = input_shape.Dims(3);
const int input_height = input_shape.Dims(2);
const int input_depth = input_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_depth = filter_shape.Dims(0);
const int output_width = output_shape.Dims(3);
const int output_height = output_shape.Dims(2);
const int output_depth = output_shape.Dims(1);
const int pad_width = params.padding_values.width;
const int pad_height = params.padding_values.height;
const int pad_depth = params.padding_values.depth;
for (int batch = 0; batch < batches; ++batch) {
for (int out_d = 0; out_d < output_depth; ++out_d) {
const int in_d_origin = (out_d * params.stride_depth) - pad_depth;
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * params.stride_height) - pad_height;
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * params.stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_num_channels;
++out_channel) {
float total = 0.f;
for (int filter_d = 0; filter_d < filter_depth; ++filter_d) {
const int in_d = in_d_origin + params.dilation_depth * filter_d;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y =
in_y_origin + params.dilation_height * filter_y;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int in_x =
in_x_origin + params.dilation_width * filter_x;
// Zero padding by omitting the areas outside the image.
const bool is_point_inside_image =
(in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height) && (in_d >= 0) &&
(in_d < input_depth);
if (!is_point_inside_image) {
continue;
}
for (int in_channel = 0; in_channel < input_num_channels;
++in_channel) {
float input_value = input_data[Offset(
input_shape, batch, in_d, in_y, in_x, in_channel)];
float filter_value =
filter_data[Offset(filter_shape, filter_d, filter_y,
filter_x, in_channel, out_channel)];
total += (input_value * filter_value);
}
}
}
}
float bias_value = 0.0f;
if (bias_data) {
bias_value = bias_data[out_channel];
}
output_data[Offset(output_shape, batch, out_d, out_y, out_x,
out_channel)] =
ActivationFunctionWithMinMax(total + bias_value,
params.float_activation_min,
params.float_activation_max);
}
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_

View File

@ -43,6 +43,20 @@ struct PaddingValues {
int16_t height_offset;
};
struct Padding3DValues {
int16_t width;
int16_t height;
int16_t depth;
// offset is used for calculating "remaining" padding, for example, `width`
// is 1 and `width_offset` is 1, so padding_left is 1 while padding_right is
// 1 + 1 = 2.
int16_t width_offset;
// Same as width_offset except it's over the height dimension.
int16_t height_offset;
// Same as width_offset except it's over the depth dimension.
int16_t depth_offset;
};
// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
@ -854,6 +868,19 @@ struct ConvParams {
float float_activation_max;
};
struct Conv3DParams {
Padding3DValues padding_values;
int stride_width;
int stride_height;
int stride_depth;
int dilation_width;
int dilation_height;
int dilation_depth;
// float activation params.
float float_activation_min;
float float_activation_max;
};
struct DepthToSpaceParams {
int32_t block_size;
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_PADDING_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -75,6 +76,36 @@ inline TfLitePaddingValues ComputePaddingHeightWidth(
padding_values.width_offset = offset;
return padding_values;
}
inline Padding3DValues ComputePadding3DValues(
int stride_height, int stride_width, int stride_depth,
int dilation_rate_height, int dilation_rate_width, int dilation_rate_depth,
int in_height, int in_width, int in_depth, int filter_height,
int filter_width, int filter_depth, TfLitePadding padding, int* out_height,
int* out_width, int* out_depth) {
*out_width = ComputeOutSize(padding, in_width, filter_width, stride_width,
dilation_rate_width);
*out_height = ComputeOutSize(padding, in_height, filter_height, stride_height,
dilation_rate_height);
*out_depth = ComputeOutSize(padding, in_depth, filter_depth, stride_depth,
dilation_rate_depth);
Padding3DValues padding_values;
int offset = 0;
padding_values.depth =
ComputePaddingWithOffset(stride_depth, dilation_rate_depth, in_depth,
filter_depth, *out_depth, &offset);
padding_values.depth_offset = offset;
padding_values.height =
ComputePaddingWithOffset(stride_height, dilation_rate_height, in_height,
filter_height, *out_height, &offset);
padding_values.height_offset = offset;
padding_values.width =
ComputePaddingWithOffset(stride_width, dilation_rate_width, in_width,
filter_width, *out_width, &offset);
padding_values.width_offset = offset;
return padding_values;
}
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_PADDING_H_

View File

@ -312,6 +312,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_CALL_ONCE,
tflite::ops::builtin::Register_CALL_ONCE());
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -157,6 +157,7 @@ TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CONV_3D();
namespace {
@ -461,6 +462,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL_REF(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that

View File

@ -357,6 +357,7 @@ enum BuiltinOperator : int32 {
CALL_ONCE = 129,
BROADCAST_TO = 130,
RFFT2D = 131,
CONV_3D = 132,
}
@ -467,6 +468,7 @@ union BuiltinOptions {
CallOnceOptions,
BroadcastToOptions,
Rfft2dOptions,
Conv3DOptions,
}
enum Padding : byte { SAME, VALID }
@ -489,6 +491,17 @@ table Conv2DOptions {
dilation_h_factor:int = 1;
}
table Conv3DOptions {
padding:Padding;
stride_d:int;
stride_w:int;
stride_h:int;
fused_activation_function:ActivationFunctionType;
dilation_d_factor:int = 1;
dilation_w_factor:int = 1;
dilation_h_factor:int = 1;
}
table Pool2DOptions {
padding:Padding;
stride_w:int;

File diff suppressed because it is too large Load Diff

View File

@ -82,6 +82,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteWhileParams",
"TfLiteCumsumParams",
"TfLiteCallOnceParams",
"TfLiteConv3DParams",
nullptr};
} // namespace
@ -331,10 +332,14 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
elem_name = "stride_width";
else if (elem_name == "stride_h")
elem_name = "stride_height";
else if (elem_name == "stride_d")
elem_name = "stride_depth";
else if (elem_name == "dilation_h_factor")
elem_name = "dilation_height_factor";
else if (elem_name == "dilation_w_factor")
elem_name = "dilation_width_factor";
else if (elem_name == "dilation_d_factor")
elem_name = "dilation_depth_factor";
else if (elem_name == "idx_out_type")
elem_name = "index_out_type";

View File

@ -340,6 +340,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_3D, 1}, kPendingReleaseVersion},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};