Add segment_sum op to Tensorflow Lite
PiperOrigin-RevId: 289377531 Change-Id: Ie8aa95ca9d6b32eb2c5eb8a11c96d6ed3b3464d9
This commit is contained in:
parent
c25b583371
commit
e07863b456
@ -151,6 +151,7 @@ typedef enum {
|
||||
kTfLiteBuiltinScatterNd = 122,
|
||||
kTfLiteBuiltinSelectV2 = 123,
|
||||
kTfLiteBuiltinDensify = 124,
|
||||
kTfLiteBuiltinSegmentSum = 125,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -826,6 +826,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
|
||||
case BuiltinOperator_SCATTER_ND:
|
||||
case BuiltinOperator_DENSIFY:
|
||||
case BuiltinOperator_SEGMENT_SUM:
|
||||
break;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
|
@ -481,6 +481,7 @@ cc_library(
|
||||
"reverse_sequence.cc",
|
||||
"round.cc",
|
||||
"scatter_nd.cc",
|
||||
"segment_sum.cc",
|
||||
"select.cc",
|
||||
"shape.cc",
|
||||
"skip_gram.cc",
|
||||
@ -2059,4 +2060,16 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "segment_sum_test",
|
||||
srcs = ["segment_sum_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
|
||||
|
@ -118,6 +118,7 @@ TfLiteRegistration* Register_RNN();
|
||||
TfLiteRegistration* Register_ROUND();
|
||||
TfLiteRegistration* Register_RSQRT();
|
||||
TfLiteRegistration* Register_SCATTER_ND();
|
||||
TfLiteRegistration* Register_SEGMENT_SUM();
|
||||
TfLiteRegistration* Register_SELECT();
|
||||
TfLiteRegistration* Register_SELECT_V2();
|
||||
TfLiteRegistration* Register_SHAPE();
|
||||
|
@ -3033,6 +3033,25 @@ inline void HardSwish(const HardSwishParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& segment_ids_shape,
|
||||
const int32_t* segment_ids_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
const int segment_flat_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
|
||||
|
||||
memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
|
||||
|
||||
for (int i = 0; i < input_shape.Dims(0); i++) {
|
||||
int output_index = segment_ids_data[i];
|
||||
for (int j = 0; j < segment_flat_size; ++j) {
|
||||
output_data[output_index * segment_flat_size + j] +=
|
||||
input_data[i * segment_flat_size + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -281,6 +281,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
Register_NON_MAX_SUPPRESSION_V5());
|
||||
AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
|
||||
AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
|
||||
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
|
||||
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -133,6 +133,7 @@ TfLiteRegistration* Register_QUANTIZE();
|
||||
TfLiteRegistration* Register_HARD_SWISH_REF();
|
||||
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
|
||||
TfLiteRegistration* Register_SELECT_V2();
|
||||
TfLiteRegistration* Register_SEGMENT_SUM();
|
||||
|
||||
namespace {
|
||||
|
||||
@ -286,6 +287,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
|
||||
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
|
||||
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
|
||||
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
112
tensorflow/lite/kernels/segment_sum.cc
Normal file
112
tensorflow/lite/kernels/segment_sum.cc
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace segment_sum {
|
||||
|
||||
static const int kInputDataTensor = 0;
|
||||
static const int kInputSegmentIdsTensor = 1;
|
||||
static const int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* data,
|
||||
const TfLiteTensor* segment_ids,
|
||||
TfLiteTensor* output) {
|
||||
int max_index = -1;
|
||||
const int segment_id_size = segment_ids->dims->data[0];
|
||||
if (segment_id_size > 0) {
|
||||
max_index = segment_ids->data.i32[segment_id_size - 1];
|
||||
}
|
||||
const int data_rank = NumDimensions(data);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
|
||||
output_shape->data[0] = max_index + 1;
|
||||
for (int i = 1; i < data_rank; ++i) {
|
||||
output_shape->data[i] = data->dims->data[i];
|
||||
}
|
||||
return context->ResizeTensor(context, output, output_shape);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
|
||||
const TfLiteTensor* segment_ids =
|
||||
GetInput(context, node, kInputSegmentIdsTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE(context,
|
||||
data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
|
||||
|
||||
if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) {
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
return ResizeOutputTensor(context, data, segment_ids, output);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
|
||||
const TfLiteTensor* segment_ids =
|
||||
GetInput(context, node, kInputSegmentIdsTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeOutputTensor(context, data, segment_ids, output));
|
||||
}
|
||||
|
||||
#define TF_LITE_SEGMENT_SUM(dtype) \
|
||||
reference_ops::SegmentSum<dtype>( \
|
||||
GetTensorShape(data), GetTensorData<dtype>(data), \
|
||||
GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \
|
||||
GetTensorShape(output), GetTensorData<dtype>(output));
|
||||
switch (data->type) {
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_SEGMENT_SUM(int32_t);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_SEGMENT_SUM(float);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Currently SegmentSum doesn't support type: %s",
|
||||
TfLiteTypeGetName(data->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
#undef TF_LITE_SEGMENT_SUM
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace segment_sum
|
||||
|
||||
TfLiteRegistration* Register_SEGMENT_SUM() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare,
|
||||
segment_sum::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
112
tensorflow/lite/kernels/segment_sum_test.cc
Normal file
112
tensorflow/lite/kernels/segment_sum_test.cc
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T>
|
||||
class SegmentSumOpModel : public SingleOpModel {
|
||||
public:
|
||||
SegmentSumOpModel(const TensorData& data, const TensorData& segment_ids) {
|
||||
data_id_ = AddInput(data);
|
||||
segment_ids_id_ = AddInput(segment_ids);
|
||||
output_id_ = AddOutput(data.type);
|
||||
SetBuiltinOp(BuiltinOperator_SEGMENT_SUM, BuiltinOptions_NONE, 0);
|
||||
BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_)});
|
||||
}
|
||||
|
||||
int data() const { return data_id_; }
|
||||
int segment_ids() const { return segment_ids_id_; }
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
|
||||
std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
|
||||
|
||||
protected:
|
||||
int data_id_;
|
||||
int segment_ids_id_;
|
||||
int output_id_;
|
||||
};
|
||||
|
||||
TEST(SegmentSumOpModelTest, Int32Test_Simple) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 4}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(),
|
||||
{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 6, 7, 8}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, Int32Test_OneDimension) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 3}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, Int32Test_ThreeDimensions) {
|
||||
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2, 1}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 5, 6}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, Float32Test_Simple) {
|
||||
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3, 4}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<float>(model.data(),
|
||||
{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8});
|
||||
model.PopulateTensor<int>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5.0f, 5.0f, 5.0f, 5.0f, 5.0f,
|
||||
6.0f, 7.0f, 8.0f}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, Float32Test_OneDimension) {
|
||||
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<float>(model.data(), {1, 2, 3});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3.0f, 3.0f}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
|
||||
}
|
||||
|
||||
TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
|
||||
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3, 2, 1}},
|
||||
{TensorType_INT32, {3}});
|
||||
model.PopulateTensor<float>(model.data(), {1, 2, 3, 4, 5, 6});
|
||||
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4.0f, 6.0f, 5.0f, 6.0f}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
@ -317,7 +317,8 @@ enum BuiltinOperator : byte {
|
||||
NON_MAX_SUPPRESSION_V5 = 121,
|
||||
SCATTER_ND = 122,
|
||||
SELECT_V2 = 123,
|
||||
DENSIFY = 124
|
||||
DENSIFY = 124,
|
||||
SEGMENT_SUM = 125
|
||||
}
|
||||
|
||||
|
||||
@ -421,7 +422,8 @@ union BuiltinOptions {
|
||||
NonMaxSuppressionV5Options,
|
||||
ScatterNdOptions,
|
||||
SelectV2Options,
|
||||
DensifyOptions
|
||||
DensifyOptions,
|
||||
SegmentSumOptions
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -911,6 +913,9 @@ table SelectV2Options {
|
||||
table DensifyOptions {
|
||||
}
|
||||
|
||||
table SegmentSumOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -334,6 +334,9 @@ struct SelectV2OptionsT;
|
||||
struct DensifyOptions;
|
||||
struct DensifyOptionsT;
|
||||
|
||||
struct SegmentSumOptions;
|
||||
struct SegmentSumOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -645,11 +648,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_SCATTER_ND = 122,
|
||||
BuiltinOperator_SELECT_V2 = 123,
|
||||
BuiltinOperator_DENSIFY = 124,
|
||||
BuiltinOperator_SEGMENT_SUM = 125,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_DENSIFY
|
||||
BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -775,7 +779,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] {
|
||||
BuiltinOperator_NON_MAX_SUPPRESSION_V5,
|
||||
BuiltinOperator_SCATTER_ND,
|
||||
BuiltinOperator_SELECT_V2,
|
||||
BuiltinOperator_DENSIFY
|
||||
BuiltinOperator_DENSIFY,
|
||||
BuiltinOperator_SEGMENT_SUM
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -907,13 +912,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"SCATTER_ND",
|
||||
"SELECT_V2",
|
||||
"DENSIFY",
|
||||
"SEGMENT_SUM",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||
if (e < BuiltinOperator_ADD || e > BuiltinOperator_DENSIFY) return "";
|
||||
if (e < BuiltinOperator_ADD || e > BuiltinOperator_SEGMENT_SUM) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOperator()[index];
|
||||
}
|
||||
@ -1019,11 +1025,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_ScatterNdOptions = 97,
|
||||
BuiltinOptions_SelectV2Options = 98,
|
||||
BuiltinOptions_DensifyOptions = 99,
|
||||
BuiltinOptions_SegmentSumOptions = 100,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_DensifyOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -1124,7 +1131,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] {
|
||||
BuiltinOptions_NonMaxSuppressionV5Options,
|
||||
BuiltinOptions_ScatterNdOptions,
|
||||
BuiltinOptions_SelectV2Options,
|
||||
BuiltinOptions_DensifyOptions
|
||||
BuiltinOptions_DensifyOptions,
|
||||
BuiltinOptions_SegmentSumOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1231,13 +1239,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"ScatterNdOptions",
|
||||
"SelectV2Options",
|
||||
"DensifyOptions",
|
||||
"SegmentSumOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
}
|
||||
|
||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||
if (e < BuiltinOptions_NONE || e > BuiltinOptions_DensifyOptions) return "";
|
||||
if (e < BuiltinOptions_NONE || e > BuiltinOptions_SegmentSumOptions) return "";
|
||||
const size_t index = static_cast<size_t>(e);
|
||||
return EnumNamesBuiltinOptions()[index];
|
||||
}
|
||||
@ -1642,6 +1651,10 @@ template<> struct BuiltinOptionsTraits<DensifyOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<SegmentSumOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2466,6 +2479,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_DensifyOptions ?
|
||||
reinterpret_cast<const DensifyOptionsT *>(value) : nullptr;
|
||||
}
|
||||
SegmentSumOptionsT *AsSegmentSumOptions() {
|
||||
return type == BuiltinOptions_SegmentSumOptions ?
|
||||
reinterpret_cast<SegmentSumOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const SegmentSumOptionsT *AsSegmentSumOptions() const {
|
||||
return type == BuiltinOptions_SegmentSumOptions ?
|
||||
reinterpret_cast<const SegmentSumOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -8659,6 +8680,46 @@ inline flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(
|
||||
|
||||
flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct SegmentSumOptionsT : public flatbuffers::NativeTable {
|
||||
typedef SegmentSumOptions TableType;
|
||||
SegmentSumOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef SegmentSumOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
SegmentSumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<SegmentSumOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct SegmentSumOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit SegmentSumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
SegmentSumOptionsBuilder &operator=(const SegmentSumOptionsBuilder &);
|
||||
flatbuffers::Offset<SegmentSumOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<SegmentSumOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
SegmentSumOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -9092,6 +9153,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const DensifyOptions *builtin_options_as_DensifyOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_DensifyOptions ? static_cast<const DensifyOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_SegmentSumOptions ? static_cast<const SegmentSumOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -9524,6 +9588,10 @@ template<> inline const DensifyOptions *Operator::builtin_options_as<DensifyOpti
|
||||
return builtin_options_as_DensifyOptions();
|
||||
}
|
||||
|
||||
template<> inline const SegmentSumOptions *Operator::builtin_options_as<SegmentSumOptions>() const {
|
||||
return builtin_options_as_SegmentSumOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -12818,6 +12886,29 @@ inline flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(flatbuffers::Fla
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline SegmentSumOptionsT *SegmentSumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new SegmentSumOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void SegmentSumOptions::UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<SegmentSumOptions> SegmentSumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateSegmentSumOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SegmentSumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateSegmentSumOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -13507,6 +13598,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const DensifyOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_SegmentSumOptions: {
|
||||
auto ptr = reinterpret_cast<const SegmentSumOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return true;
|
||||
}
|
||||
}
|
||||
@ -13921,6 +14016,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const DensifyOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_SegmentSumOptions: {
|
||||
auto ptr = reinterpret_cast<const SegmentSumOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -14323,6 +14422,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const DensifyOptionsT *>(value);
|
||||
return CreateDensifyOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_SegmentSumOptions: {
|
||||
auto ptr = reinterpret_cast<const SegmentSumOptionsT *>(value);
|
||||
return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -14725,6 +14828,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new DensifyOptionsT(*reinterpret_cast<DensifyOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_SegmentSumOptions: {
|
||||
value = new SegmentSumOptionsT(*reinterpret_cast<SegmentSumOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -15227,6 +15334,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_SegmentSumOptions: {
|
||||
auto ptr = reinterpret_cast<SegmentSumOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -146,6 +146,7 @@ enum class OperatorType : uint8 {
|
||||
// instead of being given as plain constant arrays. So we need to insert
|
||||
// special nodes in the graph to shuffle axes.
|
||||
kReorderAxes,
|
||||
kSegmentSum,
|
||||
kSelect,
|
||||
kSelectV2,
|
||||
kSparseToDense,
|
||||
|
@ -197,6 +197,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kLess, 2}, "1.14.0"},
|
||||
{{OperatorType::kLessEqual, 1}, "1.14.0"},
|
||||
{{OperatorType::kLessEqual, 2}, "1.14.0"},
|
||||
{{OperatorType::kSegmentSum, 1}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kSelect, 1}, "1.14.0"},
|
||||
{{OperatorType::kSelect, 2}, "1.14.0"},
|
||||
{{OperatorType::kSelectV2, 1}, kPendingReleaseOpVersion},
|
||||
|
Loading…
x
Reference in New Issue
Block a user