Add Conv3D reference kernel to TFLite

This kernel currently supports only the float type, and its filter uses the same data format as the TensorFlow Conv3D op.
The conversion change will be added in a follow-up CL.
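For context, resolving the new builtin from application code would look roughly like the sketch below once this change is in. This assumes the standard MutableOpResolver API together with the Register_CONV_3D hook added in this change; the stock BuiltinOpResolver (see register.cc below) already does this by default.

#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/mutable_op_resolver.h"

// Sketch: make the new builtin available to an interpreter. With the
// BuiltinOpResolver updated in this CL, this registration is automatic.
tflite::MutableOpResolver resolver;
resolver.AddBuiltin(tflite::BuiltinOperator_CONV_3D,
                    tflite::ops::builtin::Register_CONV_3D());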

PiperOrigin-RevId: 351726918
Change-Id: Id2cc805cb89a1da1fda656a2260516eda8bd9119
Thai Nguyen 2021-01-13 21:31:40 -08:00 committed by TensorFlower Gardener
parent 977921340c
commit ac8f6785d3
17 changed files with 1416 additions and 537 deletions

View File

@ -159,6 +159,7 @@ typedef enum {
kTfLiteBuiltinCallOnce = 129,
kTfLiteBuiltinBroadcastTo = 130,
kTfLiteBuiltinRfft2d = 131,
kTfLiteBuiltinConv3d = 132,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -87,6 +87,17 @@ typedef struct {
int dilation_height_factor;
} TfLiteConvParams;
typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
int stride_depth;
int dilation_width_factor;
int dilation_height_factor;
int dilation_depth_factor;
TfLiteFusedActivation activation;
} TfLiteConv3DParams;
typedef struct {
TfLitePadding padding;
int stride_width;

View File

@ -770,6 +770,23 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_CONV_3D: {
auto params = safe_allocator.Allocate<TfLiteConv3DParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* conv3d_params = op->builtin_options_as_Conv3DOptions()) {
params->padding = ConvertPadding(conv3d_params->padding());
params->activation =
ConvertActivation(conv3d_params->fused_activation_function());
params->stride_depth = conv3d_params->stride_d();
params->stride_height = conv3d_params->stride_h();
params->stride_width = conv3d_params->stride_w();
params->dilation_depth_factor = conv3d_params->dilation_d_factor();
params->dilation_height_factor = conv3d_params->dilation_h_factor();
params->dilation_width_factor = conv3d_params->dilation_w_factor();
}
*builtin_data = params.release();
return kTfLiteOk;
}
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are

View File

@ -538,6 +538,7 @@ cc_library(
copts = tflite_copts(),
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/internal:types",
],
)
@ -559,6 +560,7 @@ BUILTIN_KERNEL_SRCS = [
"comparisons.cc",
"concatenation.cc",
"conv.cc",
"conv3d.cc",
"cumsum.cc",
"densify.cc",
"depth_to_space.cc",
@ -1079,6 +1081,21 @@ cc_test(
],
)
cc_test(
name = "conv3d_test",
size = "small",
srcs = ["conv3d_test.cc"],
deps = [
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite:string",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "densify_test",
size = "small",

View File

@ -45,6 +45,7 @@ TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_CEIL();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_CONV_3D();
TfLiteRegistration* Register_COS();
TfLiteRegistration* Register_CUMSUM();
TfLiteRegistration* Register_DENSIFY();

View File

@ -0,0 +1,167 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv3d.h"
#include <cstddef>
#include <cstdint>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace conv3d {
// Struct to carry data from Prepare to Eval.
struct OpData {
Padding3DValues padding;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
delete static_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
// Check number of inputs/outputs.
bool has_bias = node->inputs->size == 3;
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
// Check dimensionality of input, filter.
TF_LITE_ENSURE_EQ(context, input->dims->size, 5);
TF_LITE_ENSURE_EQ(context, filter->dims->size, 5);
// Check input channels matching filter.
TF_LITE_ENSURE_EQ(context, input->dims->data[4], filter->dims->data[3]);
// Check types.
TfLiteType input_type = input->type;
TF_LITE_ENSURE_TYPES_EQ(context, input_type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
// Check bias.
const TfLiteTensor* bias = nullptr;
if (has_bias) {
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &bias));
TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
}
// Filter has shape of [filter_depth, filter_height, filter_width,
// in_channels, out_channels].
int batches = input->dims->data[0];
int channels_out = filter->dims->data[4];
int depth = input->dims->data[1];
int height = input->dims->data[2];
int width = input->dims->data[3];
int filter_depth = filter->dims->data[0];
int filter_height = filter->dims->data[1];
int filter_width = filter->dims->data[2];
// Matching GetWindowedOutputSize in TensorFlow.
int out_width, out_height, out_depth;
data->padding = ComputePadding3DValues(
params->stride_height, params->stride_width, params->stride_depth,
params->dilation_height_factor, params->dilation_width_factor,
params->dilation_depth_factor, height, width, depth, filter_height,
filter_width, filter_depth, params->padding, &out_height, &out_width,
&out_depth);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(5);
output_size->data[0] = batches;
output_size->data[1] = out_depth;
output_size->data[2] = out_height;
output_size->data[3] = out_width;
output_size->data[4] = channels_out;
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* filter;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
bool has_bias = node->inputs->size == 3;
const TfLiteTensor* bias = has_bias ? GetInput(context, node, 2) : nullptr;
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
Conv3DParams runtime_params;
runtime_params.padding_values = data->padding;
runtime_params.stride_depth = params->stride_depth;
runtime_params.stride_height = params->stride_height;
runtime_params.stride_width = params->stride_width;
runtime_params.dilation_depth = params->dilation_depth_factor;
runtime_params.dilation_height = params->dilation_height_factor;
runtime_params.dilation_width = params->dilation_width_factor;
runtime_params.float_activation_min = output_activation_min;
runtime_params.float_activation_max = output_activation_max;
switch (input->type) {
case kTfLiteFloat32:
reference_ops::Conv3D(runtime_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace conv3d
TfLiteRegistration* Register_CONV_3D() {
static TfLiteRegistration r = {conv3d::Init, conv3d::Free, conv3d::Prepare,
conv3d::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite
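As a sanity check on the Prepare logic above: with VALID padding and dilation 1, each spatial dimension reduces to out = (in - filter) / stride + 1. Using the shapes from SimpleFloat32Test in the new unit test below:

// input [1, 2, 2, 4, 2] (NDHWC), filter [2, 2, 2, 2, 2] (DHWIO), stride 1:
//   out_depth  = (2 - 2) / 1 + 1 = 1
//   out_height = (2 - 2) / 1 + 1 = 1
//   out_width  = (4 - 2) / 1 + 1 = 3
// giving the expected output shape [1, 1, 1, 3, 2].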

View File

@ -0,0 +1,256 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <initializer_list>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
class Conv3dOpModel : public SingleOpModel {
public:
Conv3dOpModel(const TensorData& input, const TensorData& filter,
const TensorData& bias, const TensorData& output,
Padding padding = Padding_VALID, int32_t stride_depth = 1,
int32_t stride_width = 1, int32_t stride_height = 1,
ActivationFunctionType activation = ActivationFunctionType_NONE,
int32_t dilation_depth = 1, int32_t dilation_width = 1,
int32_t dilation_height = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
bias_ = AddInput(bias);
output_ = AddOutput(output);
SetBuiltinOp(
BuiltinOperator_CONV_3D, BuiltinOptions_Conv3DOptions,
CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
stride_height, activation, dilation_depth,
dilation_width, dilation_height)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
Conv3dOpModel(const TensorData& input, const TensorData& filter,
const TensorData& output, Padding padding = Padding_VALID,
int32_t stride_depth = 1, int32_t stride_width = 1,
int32_t stride_height = 1,
ActivationFunctionType activation = ActivationFunctionType_NONE,
int32_t dilation_depth = 1, int32_t dilation_width = 1,
int32_t dilation_height = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
output_ = AddOutput(output);
SetBuiltinOp(
BuiltinOperator_CONV_3D, BuiltinOptions_Conv3DOptions,
CreateConv3DOptions(builder_, padding, stride_depth, stride_width,
stride_height, activation, dilation_depth,
dilation_width, dilation_height)
.Union());
BuildInterpreter({GetShape(input_), GetShape(filter_)});
}
void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
void SetInput(std::vector<float> data) { PopulateTensor(input_, data); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_;
int filter_;
int bias_;
int output_;
};
template <typename T>
std::vector<T> CreateRangeVector(int N) {
std::vector<T> result;
for (int i = 0; i < N; ++i) result.push_back(i);
return result;
}
TEST(Conv3dOpModel, InvalidInputDimsTest) {
EXPECT_DEATH_IF_SUPPORTED(Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}}),
"input->dims->size != 5");
}
TEST(Conv3dOpModel, InvalidFilterDimsTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}}),
"filter->dims->size != 5");
}
TEST(Conv3dOpModel, MismatchChannelSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 1}},
{TensorType_FLOAT32, {1, 3, 2, 2, 2}},
{TensorType_FLOAT32, {}}),
"input->dims->data.4. != filter->dims->data.3.");
}
TEST(Conv3dOpModel, MismatchBiasSizeTest) {
EXPECT_DEATH_IF_SUPPORTED(
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 2}},
{TensorType_FLOAT32, {1, 3, 2, 2, 1}},
{TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {}}),
"NumElements.bias. != SizeOfDimension.filter, 4.");
}
TEST(Conv3dOpModel, SimpleFloat32Test) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}});
m.SetInput(CreateRangeVector<float>(32));
m.SetFilter({-1, -1, -1, -1, -1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1,
1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, 1, -1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 3, 2));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({30, 6, 26, 10, 22, 14}));
}
TEST(Conv3dOpModel, PaddingValidTest) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 3, 4, 5, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}});
m.SetInput(CreateRangeVector<float>(120));
m.SetFilter({-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1,
1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 3, 4, 2));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({-214, 266, -234, 270, -254, 274, -274, 278, -314, 286,
-334, 290, -354, 294, -374, 298, -414, 306, -434, 310,
-454, 314, -474, 318, -614, 346, -634, 350, -654, 354,
-674, 358, -714, 366, -734, 370, -754, 374, -774, 378,
-814, 386, -834, 390, -854, 394, -874, 398}));
}
TEST(Conv3dOpModel, PaddingSameTest) {
Conv3dOpModel m({TensorType_FLOAT32, {1, 3, 4, 5, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_SAME);
m.SetInput(CreateRangeVector<float>(120));
m.SetFilter({1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1,
-1, 1, -1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 3, 4, 5, 2));
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray(
{-172, 290, -176, 298, -180, 306, -184, 314, 36, 198, -192,
330, -196, 338, -200, 346, -204, 354, 56, 218, -212, 370,
-216, 378, -220, 386, -224, 394, 76, 238, -226, 82, -230,
82, -234, 82, -238, 82, -80, 80, -252, 450, -256, 458,
-260, 466, -264, 474, 116, 278, -272, 490, -276, 498, -280,
506, -284, 514, 136, 298, -292, 530, -296, 538, -300, 546,
-304, 554, 156, 318, -306, 82, -310, 82, -314, 82, -318,
82, -80, 80, 158, -158, 162, -162, 166, -166, 170, -170,
176, -176, 178, -178, 182, -182, 186, -186, 190, -190, 196,
-196, 198, -198, 202, -202, 206, -206, 210, -210, 216, -216,
220, -220, 224, -224, 228, -228, 232, -232, 237, -237}));
}
TEST(Conv3dOpModel, StrideTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_VALID, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 2, 2));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({52, 8, 68, 8, 244, 8, 260, 8}));
}
TEST(Conv3dOpModel, StrideAndPaddingSameTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_SAME, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({-1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1,
1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 2, 2, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({-70, -28, -86, -12, -82, -16, -90, -8, -262,
164, -278, 180, -178, 80, -186, 88}));
}
TEST(Conv3dOpModel, DilationTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {}}, Padding_VALID, /*stride_depth=*/1,
/*stride_width=*/1, /*stride_height=*/1,
/*activation=*/ActivationFunctionType_NONE,
/*dilation_depth=*/1, /*dilation_width=*/1,
/*dilation_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 3, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({52, 8, 60, 8, 68, 8, 244, 8, 252, 8, 260, 8}));
}
TEST(Conv3dOpModel, BiasTest) {
Conv3dOpModel m({TensorType_FLOAT32, {2, 2, 3, 4, 2}},
{TensorType_FLOAT32, {2, 2, 2, 2, 2}},
{TensorType_FLOAT32, {2}}, {TensorType_FLOAT32, {}},
Padding_VALID, /*stride_depth=*/2,
/*stride_width=*/2, /*stride_height=*/2);
m.SetInput(CreateRangeVector<float>(96));
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
m.SetBias({1, 2});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 2, 2));
EXPECT_THAT(m.GetOutput(),
ElementsAreArray({53, 10, 69, 10, 245, 10, 261, 10}));
}
} // namespace
} // namespace tflite
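Assuming this test lives at tensorflow/lite/kernels/conv3d_test.cc, as the cc_test target added above suggests, it would typically be run with:

bazel test //tensorflow/lite/kernels:conv3d_test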

View File

@ -452,6 +452,7 @@ cc_library(
"reference/comparisons.h",
"reference/concatenation.h",
"reference/conv.h",
"reference/conv3d.h",
"reference/densify.h",
"reference/depth_to_space.h",
"reference/depthwiseconv_float.h",
@ -565,6 +566,7 @@ cc_library(
"reference/comparisons.h",
"reference/concatenation.h",
"reference/conv.h",
"reference/conv3d.h",
"reference/densify.h",
"reference/depth_to_space.h",
"reference/depthwiseconv_float.h",

View File

@ -0,0 +1,114 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
const float* bias_data, const RuntimeShape& output_shape,
float* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int input_num_channels = MatchingDim(input_shape, 4, filter_shape, 3);
const int output_num_channels = MatchingDim(filter_shape, 4, output_shape, 4);
if (bias_data) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_num_channels);
}
// Only NDHWC format is currently supported.
const int input_width = input_shape.Dims(3);
const int input_height = input_shape.Dims(2);
const int input_depth = input_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int filter_height = filter_shape.Dims(1);
const int filter_depth = filter_shape.Dims(0);
const int output_width = output_shape.Dims(3);
const int output_height = output_shape.Dims(2);
const int output_depth = output_shape.Dims(1);
const int pad_width = params.padding_values.width;
const int pad_height = params.padding_values.height;
const int pad_depth = params.padding_values.depth;
for (int batch = 0; batch < batches; ++batch) {
for (int out_d = 0; out_d < output_depth; ++out_d) {
const int in_d_origin = (out_d * params.stride_depth) - pad_depth;
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * params.stride_height) - pad_height;
for (int out_x = 0; out_x < output_width; ++out_x) {
const int in_x_origin = (out_x * params.stride_width) - pad_width;
for (int out_channel = 0; out_channel < output_num_channels;
++out_channel) {
float total = 0.f;
for (int filter_d = 0; filter_d < filter_depth; ++filter_d) {
const int in_d = in_d_origin + params.dilation_depth * filter_d;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y =
in_y_origin + params.dilation_height * filter_y;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int in_x =
in_x_origin + params.dilation_width * filter_x;
// Zero padding by omitting the areas outside the image.
const bool is_point_inside_image =
(in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height) && (in_d >= 0) &&
(in_d < input_depth);
if (!is_point_inside_image) {
continue;
}
for (int in_channel = 0; in_channel < input_num_channels;
++in_channel) {
float input_value = input_data[Offset(
input_shape, batch, in_d, in_y, in_x, in_channel)];
float filter_value =
filter_data[Offset(filter_shape, filter_d, filter_y,
filter_x, in_channel, out_channel)];
total += (input_value * filter_value);
}
}
}
}
float bias_value = 0.0f;
if (bias_data) {
bias_value = bias_data[out_channel];
}
output_data[Offset(output_shape, batch, out_d, out_y, out_x,
out_channel)] =
ActivationFunctionWithMinMax(total + bias_value,
params.float_activation_min,
params.float_activation_max);
}
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONV3D_H_
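A minimal standalone sketch of driving this reference function directly (a hypothetical driver, not part of this change): with a single-element all-ones filter, the convolution is an identity, so the output must equal the input.

#include <vector>
#include "tensorflow/lite/kernels/internal/reference/conv3d.h"

int main() {
  tflite::Conv3DParams params;
  params.padding_values = {0, 0, 0, 0, 0, 0};  // VALID: no padding needed.
  params.stride_width = params.stride_height = params.stride_depth = 1;
  params.dilation_width = params.dilation_height = params.dilation_depth = 1;
  params.float_activation_min = -1e30f;  // Effectively no fused activation.
  params.float_activation_max = 1e30f;
  // NDHWC input of two elements; DHWIO filter with a single weight of 1.
  tflite::RuntimeShape io_shape({1, 1, 1, 2, 1});
  tflite::RuntimeShape filter_shape({1, 1, 1, 1, 1});
  std::vector<float> input = {3.f, 5.f};
  std::vector<float> filter = {1.f};
  std::vector<float> output(2);
  tflite::reference_ops::Conv3D(params, io_shape, input.data(), filter_shape,
                                filter.data(), tflite::RuntimeShape(),
                                /*bias_data=*/nullptr, io_shape,
                                output.data());
  // output is now {3, 5}, identical to the input.
  return 0;
}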

View File

@ -43,6 +43,20 @@ struct PaddingValues {
int16_t height_offset;
};
struct Padding3DValues {
int16_t width;
int16_t height;
int16_t depth;
// offset is used for calculating the "remaining" padding. For example, if
// `width` is 1 and `width_offset` is 1, then padding_left is 1 while
// padding_right is 1 + 1 = 2.
int16_t width_offset;
// Same as width_offset except it's over the height dimension.
int16_t height_offset;
// Same as width_offset except it's over the depth dimension.
int16_t depth_offset;
};
// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
// runtime paths.
@ -854,6 +868,19 @@ struct ConvParams {
float float_activation_max;
};
struct Conv3DParams {
Padding3DValues padding_values;
int stride_width;
int stride_height;
int stride_depth;
int dilation_width;
int dilation_height;
int dilation_depth;
// float activation params.
float float_activation_min;
float float_activation_max;
};
struct DepthToSpaceParams {
int32_t block_size;
};

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_PADDING_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -75,6 +76,36 @@ inline TfLitePaddingValues ComputePaddingHeightWidth(
padding_values.width_offset = offset;
return padding_values;
}
inline Padding3DValues ComputePadding3DValues(
int stride_height, int stride_width, int stride_depth,
int dilation_rate_height, int dilation_rate_width, int dilation_rate_depth,
int in_height, int in_width, int in_depth, int filter_height,
int filter_width, int filter_depth, TfLitePadding padding, int* out_height,
int* out_width, int* out_depth) {
*out_width = ComputeOutSize(padding, in_width, filter_width, stride_width,
dilation_rate_width);
*out_height = ComputeOutSize(padding, in_height, filter_height, stride_height,
dilation_rate_height);
*out_depth = ComputeOutSize(padding, in_depth, filter_depth, stride_depth,
dilation_rate_depth);
Padding3DValues padding_values;
int offset = 0;
padding_values.depth =
ComputePaddingWithOffset(stride_depth, dilation_rate_depth, in_depth,
filter_depth, *out_depth, &offset);
padding_values.depth_offset = offset;
padding_values.height =
ComputePaddingWithOffset(stride_height, dilation_rate_height, in_height,
filter_height, *out_height, &offset);
padding_values.height_offset = offset;
padding_values.width =
ComputePaddingWithOffset(stride_width, dilation_rate_width, in_width,
filter_width, *out_width, &offset);
padding_values.width_offset = offset;
return padding_values;
}
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_PADDING_H_
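To make the SAME-padding arithmetic concrete, here is a worked pass through ComputePadding3DValues using the shapes from StrideAndPaddingSameTest above (input D/H/W = 2/3/4, filter 2x2x2, all strides 2, dilation 1):

// out_depth  = ceil(2 / 2) = 1
// out_height = ceil(3 / 2) = 2
// out_width  = ceil(4 / 2) = 2
// Total padding per axis = max((out - 1) * stride + filter - in, 0):
//   depth:  (1 - 1) * 2 + 2 - 2 = 0  -> padding 0, offset 0
//   height: (2 - 1) * 2 + 2 - 3 = 1  -> padding 0, offset 1 (one extra row)
//   width:  (2 - 1) * 2 + 2 - 4 = 0  -> padding 0, offset 0
// consistent with the test's expected output shape (2, 1, 2, 2, 2).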

View File

@ -312,6 +312,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_CALL_ONCE,
tflite::ops::builtin::Register_CALL_ONCE());
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -157,6 +157,7 @@ TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_BROADCAST_TO();
TfLiteRegistration* Register_CONV_3D();
namespace {
@ -461,6 +462,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL_REF(),
/* min_version = */ 1,
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
AddCustom("NumericVerify",
tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that

View File

@ -357,6 +357,7 @@ enum BuiltinOperator : int32 {
CALL_ONCE = 129,
BROADCAST_TO = 130,
RFFT2D = 131,
CONV_3D = 132,
}
@ -467,6 +468,7 @@ union BuiltinOptions {
CallOnceOptions,
BroadcastToOptions,
Rfft2dOptions,
Conv3DOptions,
}
enum Padding : byte { SAME, VALID }
@ -489,6 +491,17 @@ table Conv2DOptions {
dilation_h_factor:int = 1;
}
table Conv3DOptions {
padding:Padding;
stride_d:int;
stride_w:int;
stride_h:int;
fused_activation_function:ActivationFunctionType;
dilation_d_factor:int = 1;
dilation_w_factor:int = 1;
dilation_h_factor:int = 1;
}
table Pool2DOptions {
padding:Padding;
stride_w:int;

View File

@ -49,6 +49,9 @@ struct TensorT;
struct Conv2DOptions;
struct Conv2DOptionsT;
struct Conv3DOptions;
struct Conv3DOptionsT;
struct Pool2DOptions;
struct Pool2DOptionsT;
@ -807,11 +810,12 @@ enum BuiltinOperator {
BuiltinOperator_CALL_ONCE = 129,
BuiltinOperator_BROADCAST_TO = 130,
BuiltinOperator_RFFT2D = 131,
BuiltinOperator_CONV_3D = 132,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_CONV_3D
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[133] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -944,12 +948,15 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[132] {
BuiltinOperator_CUMSUM,
BuiltinOperator_CALL_ONCE,
BuiltinOperator_BROADCAST_TO,
BuiltinOperator_RFFT2D,
BuiltinOperator_CONV_3D
};
return values;
}
inline const char * const *EnumNamesBuiltinOperator() {
static const char * const names[134] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
"CONV_2D",
@ -1081,13 +1088,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"CALL_ONCE",
"BROADCAST_TO",
"RFFT2D",
"CONV_3D",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_CONV_3D)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -1199,11 +1207,12 @@ enum BuiltinOptions {
BuiltinOptions_CallOnceOptions = 103,
BuiltinOptions_BroadcastToOptions = 104,
BuiltinOptions_Rfft2dOptions = 105,
BuiltinOptions_Conv3DOptions = 106,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_Conv3DOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[107] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1310,12 +1319,15 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[106] {
BuiltinOptions_CumsumOptions,
BuiltinOptions_CallOnceOptions,
BuiltinOptions_BroadcastToOptions,
BuiltinOptions_Rfft2dOptions,
BuiltinOptions_Conv3DOptions
};
return values;
}
inline const char * const *EnumNamesBuiltinOptions() {
static const char * const names[108] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
"ConcatEmbeddingsOptions",
@ -1421,14 +1433,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"CallOnceOptions",
"BroadcastToOptions",
"Rfft2dOptions",
"Conv3DOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_Conv3DOptions)) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1853,11 +1865,14 @@ template<> struct BuiltinOptionsTraits<tflite::BroadcastToOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions;
};
template<> struct BuiltinOptionsTraits<tflite::Rfft2dOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_Rfft2dOptions;
};
template<> struct BuiltinOptionsTraits<tflite::Conv3DOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_Conv3DOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2723,14 +2738,20 @@ struct BuiltinOptionsUnion {
reinterpret_cast<const tflite::BroadcastToOptionsT *>(value) : nullptr;
}
tflite::Rfft2dOptionsT *AsRfft2dOptions() {
return type == BuiltinOptions_Rfft2dOptions ?
reinterpret_cast<tflite::Rfft2dOptionsT *>(value) : nullptr;
}
const tflite::Rfft2dOptionsT *AsRfft2dOptions() const {
return type == BuiltinOptions_Rfft2dOptions ?
reinterpret_cast<const tflite::Rfft2dOptionsT *>(value) : nullptr;
}
tflite::Conv3DOptionsT *AsConv3DOptions() {
return type == BuiltinOptions_Conv3DOptions ?
reinterpret_cast<tflite::Conv3DOptionsT *>(value) : nullptr;
}
const tflite::Conv3DOptionsT *AsConv3DOptions() const {
return type == BuiltinOptions_Conv3DOptions ?
reinterpret_cast<const tflite::Conv3DOptionsT *>(value) : nullptr;
}
};
@ -3928,6 +3949,144 @@ inline flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(
flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct Conv3DOptionsT : public flatbuffers::NativeTable {
typedef Conv3DOptions TableType;
tflite::Padding padding;
int32_t stride_d;
int32_t stride_w;
int32_t stride_h;
tflite::ActivationFunctionType fused_activation_function;
int32_t dilation_d_factor;
int32_t dilation_w_factor;
int32_t dilation_h_factor;
Conv3DOptionsT()
: padding(tflite::Padding_SAME),
stride_d(0),
stride_w(0),
stride_h(0),
fused_activation_function(tflite::ActivationFunctionType_NONE),
dilation_d_factor(1),
dilation_w_factor(1),
dilation_h_factor(1) {
}
};
struct Conv3DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef Conv3DOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_PADDING = 4,
VT_STRIDE_D = 6,
VT_STRIDE_W = 8,
VT_STRIDE_H = 10,
VT_FUSED_ACTIVATION_FUNCTION = 12,
VT_DILATION_D_FACTOR = 14,
VT_DILATION_W_FACTOR = 16,
VT_DILATION_H_FACTOR = 18
};
tflite::Padding padding() const {
return static_cast<tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
}
int32_t stride_d() const {
return GetField<int32_t>(VT_STRIDE_D, 0);
}
int32_t stride_w() const {
return GetField<int32_t>(VT_STRIDE_W, 0);
}
int32_t stride_h() const {
return GetField<int32_t>(VT_STRIDE_H, 0);
}
tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
int32_t dilation_d_factor() const {
return GetField<int32_t>(VT_DILATION_D_FACTOR, 1);
}
int32_t dilation_w_factor() const {
return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
}
int32_t dilation_h_factor() const {
return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING) &&
VerifyField<int32_t>(verifier, VT_STRIDE_D) &&
VerifyField<int32_t>(verifier, VT_STRIDE_W) &&
VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<int32_t>(verifier, VT_DILATION_D_FACTOR) &&
VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
verifier.EndTable();
}
Conv3DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(Conv3DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<Conv3DOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct Conv3DOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_padding(tflite::Padding padding) {
fbb_.AddElement<int8_t>(Conv3DOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
}
void add_stride_d(int32_t stride_d) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_D, stride_d, 0);
}
void add_stride_w(int32_t stride_w) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_W, stride_w, 0);
}
void add_stride_h(int32_t stride_h) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_H, stride_h, 0);
}
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(Conv3DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
void add_dilation_d_factor(int32_t dilation_d_factor) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_D_FACTOR, dilation_d_factor, 1);
}
void add_dilation_w_factor(int32_t dilation_w_factor) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
}
void add_dilation_h_factor(int32_t dilation_h_factor) {
fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
}
explicit Conv3DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
Conv3DOptionsBuilder &operator=(const Conv3DOptionsBuilder &);
flatbuffers::Offset<Conv3DOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<Conv3DOptions>(end);
return o;
}
};
inline flatbuffers::Offset<Conv3DOptions> CreateConv3DOptions(
flatbuffers::FlatBufferBuilder &_fbb,
tflite::Padding padding = tflite::Padding_SAME,
int32_t stride_d = 0,
int32_t stride_w = 0,
int32_t stride_h = 0,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
int32_t dilation_d_factor = 1,
int32_t dilation_w_factor = 1,
int32_t dilation_h_factor = 1) {
Conv3DOptionsBuilder builder_(_fbb);
builder_.add_dilation_h_factor(dilation_h_factor);
builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_dilation_d_factor(dilation_d_factor);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
builder_.add_stride_d(stride_d);
builder_.add_fused_activation_function(fused_activation_function);
builder_.add_padding(padding);
return builder_.Finish();
}
flatbuffers::Offset<Conv3DOptions> CreateConv3DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
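For illustration, hand-building these options with the generated helper might look like the following hypothetical snippet (the stride and activation values are arbitrary; the dilation factors keep their schema defaults of 1):

flatbuffers::FlatBufferBuilder fbb;
auto options = tflite::CreateConv3DOptions(
    fbb, tflite::Padding_SAME, /*stride_d=*/2, /*stride_w=*/2,
    /*stride_h=*/2, tflite::ActivationFunctionType_RELU);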
struct Pool2DOptionsT : public flatbuffers::NativeTable {
typedef Pool2DOptions TableType;
tflite::Padding padding;
@ -9604,22 +9763,19 @@ flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuffers::Fl
struct Rfft2dOptionsT : public flatbuffers::NativeTable {
typedef Rfft2dOptions TableType;
Rfft2dOptionsT() {
}
};
struct Rfft2dOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef Rfft2dOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
Rfft2dOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(Rfft2dOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<Rfft2dOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct Rfft2dOptionsBuilder {
@ -9643,9 +9799,7 @@ inline flatbuffers::Offset<Rfft2dOptions> CreateRfft2dOptions(
return builder_.Finish();
}
flatbuffers::Offset<Rfft2dOptions> CreateRfft2dOptions(flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
@ -10110,9 +10264,10 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast<const tflite::BroadcastToOptions *>(builtin_options()) : nullptr;
}
const tflite::Rfft2dOptions *builtin_options_as_Rfft2dOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_Rfft2dOptions ? static_cast<const tflite::Rfft2dOptions *>(builtin_options()) : nullptr;
}
const tflite::Conv3DOptions *builtin_options_as_Conv3DOptions() const {
return builtin_options_type() == tflite::BuiltinOptions_Conv3DOptions ? static_cast<const tflite::Conv3DOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
@ -10566,12 +10721,14 @@ template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as
return builtin_options_as_BroadcastToOptions();
}
template<> inline const tflite::Rfft2dOptions *Operator::builtin_options_as<tflite::Rfft2dOptions>() const {
return builtin_options_as_Rfft2dOptions();
}
template<> inline const tflite::Conv3DOptions *Operator::builtin_options_as<tflite::Conv3DOptions>() const {
return builtin_options_as_Conv3DOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11606,6 +11763,53 @@ inline flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(flatbuffers::FlatB
_dilation_h_factor);
}
inline Conv3DOptionsT *Conv3DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new Conv3DOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void Conv3DOptions::UnPackTo(Conv3DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = padding(); _o->padding = _e; }
{ auto _e = stride_d(); _o->stride_d = _e; }
{ auto _e = stride_w(); _o->stride_w = _e; }
{ auto _e = stride_h(); _o->stride_h = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = dilation_d_factor(); _o->dilation_d_factor = _e; }
{ auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }
{ auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }
}
inline flatbuffers::Offset<Conv3DOptions> Conv3DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateConv3DOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<Conv3DOptions> CreateConv3DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Conv3DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _padding = _o->padding;
auto _stride_d = _o->stride_d;
auto _stride_w = _o->stride_w;
auto _stride_h = _o->stride_h;
auto _fused_activation_function = _o->fused_activation_function;
auto _dilation_d_factor = _o->dilation_d_factor;
auto _dilation_w_factor = _o->dilation_w_factor;
auto _dilation_h_factor = _o->dilation_h_factor;
return tflite::CreateConv3DOptions(
_fbb,
_padding,
_stride_d,
_stride_w,
_stride_h,
_fused_activation_function,
_dilation_d_factor,
_dilation_w_factor,
_dilation_h_factor);
}
inline Pool2DOptionsT *Pool2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new Pool2DOptionsT();
UnPackTo(_o, _resolver);
@ -14329,38 +14533,27 @@ inline flatbuffers::Offset<BroadcastToOptions> CreateBroadcastToOptions(flatbuff
_fbb);
}
inline Rfft2dOptionsT *Rfft2dOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new Rfft2dOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void Rfft2dOptions::UnPackTo(Rfft2dOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<Rfft2dOptions> Rfft2dOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateRfft2dOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<Rfft2dOptions> CreateRfft2dOptions(flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Rfft2dOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateRfft2dOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@ -15258,6 +15451,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const tflite::Rfft2dOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_Conv3DOptions: {
auto ptr = reinterpret_cast<const tflite::Conv3DOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@ -15696,6 +15893,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const tflite::Rfft2dOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_Conv3DOptions: {
auto ptr = reinterpret_cast<const tflite::Conv3DOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -16122,6 +16323,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const tflite::Rfft2dOptionsT *>(value);
return CreateRfft2dOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_Conv3DOptions: {
auto ptr = reinterpret_cast<const tflite::Conv3DOptionsT *>(value);
return CreateConv3DOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -16545,8 +16750,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
break;
}
case BuiltinOptions_Rfft2dOptions: {
value = new tflite::Rfft2dOptionsT(*reinterpret_cast<tflite::Rfft2dOptionsT *>(u.value));
break;
}
case BuiltinOptions_Conv3DOptions: {
value = new tflite::Conv3DOptionsT(*reinterpret_cast<tflite::Conv3DOptionsT *>(u.value));
break;
}
default:
@ -17081,6 +17289,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_Conv3DOptions: {
auto ptr = reinterpret_cast<tflite::Conv3DOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@ -82,6 +82,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteWhileParams",
"TfLiteCumsumParams",
"TfLiteCallOnceParams",
"TfLiteConv3DParams",
nullptr};
} // namespace
@ -331,10 +332,14 @@ void GenerateImportForOp(FILE* fp, const std::string& op_name,
elem_name = "stride_width";
else if (elem_name == "stride_h")
elem_name = "stride_height";
else if (elem_name == "stride_d")
elem_name = "stride_depth";
else if (elem_name == "dilation_h_factor")
elem_name = "dilation_height_factor";
else if (elem_name == "dilation_w_factor")
elem_name = "dilation_width_factor";
else if (elem_name == "dilation_d_factor")
elem_name = "dilation_depth_factor";
else if (elem_name == "idx_out_type")
elem_name = "index_out_type";

View File

@ -340,6 +340,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_3D, 1}, kPendingReleaseVersion},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};