Add segment_sum op to TensorFlow Lite

PiperOrigin-RevId: 289377531
Change-Id: Ie8aa95ca9d6b32eb2c5eb8a11c96d6ed3b3464d9
Jaesung Chung 2020-01-12 22:00:06 -08:00 committed by TensorFlower Gardener
parent c25b583371
commit e07863b456
13 changed files with 391 additions and 10 deletions

View File

@@ -151,6 +151,7 @@ typedef enum {
kTfLiteBuiltinScatterNd = 122,
kTfLiteBuiltinSelectV2 = 123,
kTfLiteBuiltinDensify = 124,
kTfLiteBuiltinSegmentSum = 125,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@@ -826,6 +826,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_NON_MAX_SUPPRESSION_V5:
case BuiltinOperator_SCATTER_ND:
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
break;
}
return kTfLiteOk;

View File

@@ -481,6 +481,7 @@ cc_library(
"reverse_sequence.cc",
"round.cc",
"scatter_nd.cc",
"segment_sum.cc",
"select.cc",
"shape.cc",
"skip_gram.cc",
@@ -2059,4 +2060,16 @@ cc_test(
],
)
cc_test(
name = "segment_sum_test",
srcs = ["segment_sum_test.cc"],
deps = [
":builtin_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"@com_google_googletest//:gtest",
],
)
tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})

View File

@@ -118,6 +118,7 @@ TfLiteRegistration* Register_RNN();
TfLiteRegistration* Register_ROUND();
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SCATTER_ND();
TfLiteRegistration* Register_SEGMENT_SUM();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SHAPE();

View File

@@ -3033,6 +3033,25 @@ inline void HardSwish(const HardSwishParams& params,
}
}
template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& segment_ids_shape,
const int32_t* segment_ids_data,
const RuntimeShape& output_shape, T* output_data) {
const int segment_flat_size =
MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
for (int i = 0; i < input_shape.Dims(0); i++) {
int output_index = segment_ids_data[i];
for (int j = 0; j < segment_flat_size; ++j) {
output_data[output_index * segment_flat_size + j] +=
input_data[i * segment_flat_size + j];
}
}
}
} // namespace reference_ops
} // namespace tflite

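The reference kernel mirrors tf.math.segment_sum: rows of the input (dimension 0) are grouped by their segment id and summed, so sorted ids {0, 0, 1} on a 3x4 input yield a 2x4 output. Below is a small standalone sketch of the same loop in plain C++, independent of the TFLite headers above; it is only an illustration, reusing the data from the tests added later in this change.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative restatement of the reference SegmentSum loop on flat arrays:
// "rows" is dimension 0 of the input, "inner" is the product of the remaining
// dimensions (segment_flat_size above), and output row i accumulates every
// input row whose segment id equals i.
template <typename T>
void SegmentSumSketch(const T* input, const int32_t* segment_ids, int rows,
                      int inner, T* output, int output_rows) {
  std::memset(output, 0, sizeof(T) * output_rows * inner);
  for (int i = 0; i < rows; ++i) {
    const int out_row = segment_ids[i];
    for (int j = 0; j < inner; ++j) {
      output[out_row * inner + j] += input[i * inner + j];
    }
  }
}

int main() {
  // Same data as Int32Test_Simple below: a 3x4 input with ids {0, 0, 1}
  // collapses to a 2x4 output {{5, 5, 5, 5}, {5, 6, 7, 8}}.
  const int32_t input[] = {1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8};
  const int32_t ids[] = {0, 0, 1};
  int32_t output[8];
  SegmentSumSketch(input, ids, /*rows=*/3, /*inner=*/4, output,
                   /*output_rows=*/2);
  for (int32_t v : output) std::printf("%d ", static_cast<int>(v));
  std::printf("\n");  // prints: 5 5 5 5 5 6 7 8
  return 0;
}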
View File

@@ -281,6 +281,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
Register_NON_MAX_SUPPRESSION_V5());
AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@@ -133,6 +133,7 @@ TfLiteRegistration* Register_QUANTIZE();
TfLiteRegistration* Register_HARD_SWISH_REF();
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
TfLiteRegistration* Register_SELECT_V2();
TfLiteRegistration* Register_SEGMENT_SUM();
namespace {
@@ -286,6 +287,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF());
AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2());
AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@@ -0,0 +1,112 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace segment_sum {
static const int kInputDataTensor = 0;
static const int kInputSegmentIdsTensor = 1;
static const int kOutputTensor = 0;
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const TfLiteTensor* data,
const TfLiteTensor* segment_ids,
TfLiteTensor* output) {
int max_index = -1;
const int segment_id_size = segment_ids->dims->data[0];
if (segment_id_size > 0) {
max_index = segment_ids->data.i32[segment_id_size - 1];
}
const int data_rank = NumDimensions(data);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
output_shape->data[0] = max_index + 1;
for (int i = 1; i < data_rank; ++i) {
output_shape->data[i] = data->dims->data[i];
}
return context->ResizeTensor(context, output, output_shape);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context,
data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) {
SetTensorToDynamic(output);
return kTfLiteOk;
}
return ResizeOutputTensor(context, data, segment_ids, output);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* data = GetInput(context, node, kInputDataTensor);
const TfLiteTensor* segment_ids =
GetInput(context, node, kInputSegmentIdsTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
ResizeOutputTensor(context, data, segment_ids, output));
}
#define TF_LITE_SEGMENT_SUM(dtype) \
reference_ops::SegmentSum<dtype>( \
GetTensorShape(data), GetTensorData<dtype>(data), \
GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \
GetTensorShape(output), GetTensorData<dtype>(output));
switch (data->type) {
case kTfLiteInt32:
TF_LITE_SEGMENT_SUM(int32_t);
break;
case kTfLiteFloat32:
TF_LITE_SEGMENT_SUM(float);
break;
default:
context->ReportError(context,
"Currently SegmentSum doesn't support type: %s",
TfLiteTypeGetName(data->type));
return kTfLiteError;
}
#undef TF_LITE_SEGMENT_SUM
return kTfLiteOk;
}
} // namespace segment_sum
TfLiteRegistration* Register_SEGMENT_SUM() {
static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare,
segment_sum::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

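BuiltinOpResolver and BuiltinRefOpResolver (see the register.cc and register_ref.cc hunks above) already pick the kernel up, but builds that assemble their own resolver need to add it explicitly. A hedged sketch of that wiring with MutableOpResolver follows; the forward declaration stands in for the kernels registry header, whose path is not shown in this listing.

#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace ops {
namespace builtin {
// Declared in the builtin kernels registry header added in this change.
TfLiteRegistration* Register_SEGMENT_SUM();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Sketch: a selective resolver that registers only the ops a model needs,
// including the new SEGMENT_SUM builtin. The minimal-resolver pattern and
// header paths here are assumptions, not part of this commit.
tflite::MutableOpResolver BuildResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_SEGMENT_SUM,
                      tflite::ops::builtin::Register_SEGMENT_SUM());
  return resolver;
}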
View File

@@ -0,0 +1,112 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
template <typename T>
class SegmentSumOpModel : public SingleOpModel {
public:
SegmentSumOpModel(const TensorData& data, const TensorData& segment_ids) {
data_id_ = AddInput(data);
segment_ids_id_ = AddInput(segment_ids);
output_id_ = AddOutput(data.type);
SetBuiltinOp(BuiltinOperator_SEGMENT_SUM, BuiltinOptions_NONE, 0);
BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_)});
}
int data() const { return data_id_; }
int segment_ids() const { return segment_ids_id_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
std::vector<int32_t> GetOutputShape() { return GetTensorShape(output_id_); }
protected:
int data_id_;
int segment_ids_id_;
int output_id_;
};
TEST(SegmentSumOpModelTest, Int32Test_Simple) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 4}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(),
{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 6, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
}
TEST(SegmentSumOpModelTest, Int32Test_OneDimension) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 3}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
}
TEST(SegmentSumOpModelTest, Int32Test_ThreeDimensions) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2, 1}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 5, 6}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
}
TEST(SegmentSumOpModelTest, Float32Test_Simple) {
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3, 4}},
{TensorType_INT32, {3}});
model.PopulateTensor<float>(model.data(),
{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8});
model.PopulateTensor<int>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5.0f, 5.0f, 5.0f, 5.0f, 5.0f,
6.0f, 7.0f, 8.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
}
TEST(SegmentSumOpModelTest, Float32Test_OneDimension) {
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3}},
{TensorType_INT32, {3}});
model.PopulateTensor<float>(model.data(), {1, 2, 3});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({3.0f, 3.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
}
TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
SegmentSumOpModel<float> model({TensorType_FLOAT32, {3, 2, 1}},
{TensorType_INT32, {3}});
model.PopulateTensor<float>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAreArray({4.0f, 6.0f, 5.0f, 6.0f}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
}
} // namespace
} // namespace tflite

View File

@@ -317,7 +317,8 @@ enum BuiltinOperator : byte {
NON_MAX_SUPPRESSION_V5 = 121,
SCATTER_ND = 122,
SELECT_V2 = 123,
DENSIFY = 124
DENSIFY = 124,
SEGMENT_SUM = 125
}
@@ -421,7 +422,8 @@ union BuiltinOptions {
NonMaxSuppressionV5Options,
ScatterNdOptions,
SelectV2Options,
DensifyOptions
DensifyOptions,
SegmentSumOptions
}
enum Padding : byte { SAME, VALID }
@@ -911,6 +913,9 @@ table SelectV2Options {
table DensifyOptions {
}
table SegmentSumOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

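Since SegmentSumOptions is an empty table, serializing a SEGMENT_SUM operator only requires creating the table and tagging the builtin_options union with the new type, which is exactly what the generated helpers in the next file do. A minimal sketch against the generated FlatBuffers API, with the surrounding Operator/Model construction left out:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: build the (empty) SegmentSumOptions table and return the union
// value an Operator table would carry; builtin_options_type would be set to
// tflite::BuiltinOptions_SegmentSumOptions alongside it.
flatbuffers::Offset<void> MakeSegmentSumOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateSegmentSumOptions(fbb).Union();
}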
View File

@@ -334,6 +334,9 @@ struct SelectV2OptionsT;
struct DensifyOptions;
struct DensifyOptionsT;
struct SegmentSumOptions;
struct SegmentSumOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@@ -645,11 +648,12 @@ enum BuiltinOperator {
BuiltinOperator_SCATTER_ND = 122,
BuiltinOperator_SELECT_V2 = 123,
BuiltinOperator_DENSIFY = 124,
BuiltinOperator_SEGMENT_SUM = 125,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_DENSIFY
BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -775,7 +779,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] {
BuiltinOperator_NON_MAX_SUPPRESSION_V5,
BuiltinOperator_SCATTER_ND,
BuiltinOperator_SELECT_V2,
BuiltinOperator_DENSIFY
BuiltinOperator_DENSIFY,
BuiltinOperator_SEGMENT_SUM
};
return values;
}
@@ -907,13 +912,14 @@ inline const char * const *EnumNamesBuiltinOperator() {
"SCATTER_ND",
"SELECT_V2",
"DENSIFY",
"SEGMENT_SUM",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (e < BuiltinOperator_ADD || e > BuiltinOperator_DENSIFY) return "";
if (e < BuiltinOperator_ADD || e > BuiltinOperator_SEGMENT_SUM) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@@ -1019,11 +1025,12 @@ enum BuiltinOptions {
BuiltinOptions_ScatterNdOptions = 97,
BuiltinOptions_SelectV2Options = 98,
BuiltinOptions_DensifyOptions = 99,
BuiltinOptions_SegmentSumOptions = 100,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_DensifyOptions
BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -1124,7 +1131,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] {
BuiltinOptions_NonMaxSuppressionV5Options,
BuiltinOptions_ScatterNdOptions,
BuiltinOptions_SelectV2Options,
BuiltinOptions_DensifyOptions
BuiltinOptions_DensifyOptions,
BuiltinOptions_SegmentSumOptions
};
return values;
}
@@ -1231,13 +1239,14 @@ inline const char * const *EnumNamesBuiltinOptions() {
"ScatterNdOptions",
"SelectV2Options",
"DensifyOptions",
"SegmentSumOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (e < BuiltinOptions_NONE || e > BuiltinOptions_DensifyOptions) return "";
if (e < BuiltinOptions_NONE || e > BuiltinOptions_SegmentSumOptions) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@@ -1642,6 +1651,10 @@ template<> struct BuiltinOptionsTraits<DensifyOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions;
};
template<> struct BuiltinOptionsTraits<SegmentSumOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -2466,6 +2479,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_DensifyOptions ?
reinterpret_cast<const DensifyOptionsT *>(value) : nullptr;
}
SegmentSumOptionsT *AsSegmentSumOptions() {
return type == BuiltinOptions_SegmentSumOptions ?
reinterpret_cast<SegmentSumOptionsT *>(value) : nullptr;
}
const SegmentSumOptionsT *AsSegmentSumOptions() const {
return type == BuiltinOptions_SegmentSumOptions ?
reinterpret_cast<const SegmentSumOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -8659,6 +8680,46 @@ inline flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(
flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct SegmentSumOptionsT : public flatbuffers::NativeTable {
typedef SegmentSumOptions TableType;
SegmentSumOptionsT() {
}
};
struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef SegmentSumOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
SegmentSumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<SegmentSumOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct SegmentSumOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit SegmentSumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
SegmentSumOptionsBuilder &operator=(const SegmentSumOptionsBuilder &);
flatbuffers::Offset<SegmentSumOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<SegmentSumOptions>(end);
return o;
}
};
inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
SegmentSumOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -9092,6 +9153,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const DensifyOptions *builtin_options_as_DensifyOptions() const {
return builtin_options_type() == BuiltinOptions_DensifyOptions ? static_cast<const DensifyOptions *>(builtin_options()) : nullptr;
}
const SegmentSumOptions *builtin_options_as_SegmentSumOptions() const {
return builtin_options_type() == BuiltinOptions_SegmentSumOptions ? static_cast<const SegmentSumOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -9524,6 +9588,10 @@ template<> inline const DensifyOptions *Operator::builtin_options_as<DensifyOpti
return builtin_options_as_DensifyOptions();
}
template<> inline const SegmentSumOptions *Operator::builtin_options_as<SegmentSumOptions>() const {
return builtin_options_as_SegmentSumOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -12818,6 +12886,29 @@ inline flatbuffers::Offset<DensifyOptions> CreateDensifyOptions(flatbuffers::Fla
_fbb);
}
inline SegmentSumOptionsT *SegmentSumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new SegmentSumOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void SegmentSumOptions::UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<SegmentSumOptions> SegmentSumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateSegmentSumOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<SegmentSumOptions> CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SegmentSumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateSegmentSumOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -13507,6 +13598,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const DensifyOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_SegmentSumOptions: {
auto ptr = reinterpret_cast<const SegmentSumOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return true;
}
}
@@ -13921,6 +14016,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const DensifyOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_SegmentSumOptions: {
auto ptr = reinterpret_cast<const SegmentSumOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@@ -14323,6 +14422,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const DensifyOptionsT *>(value);
return CreateDensifyOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_SegmentSumOptions: {
auto ptr = reinterpret_cast<const SegmentSumOptionsT *>(value);
return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@@ -14725,6 +14828,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new DensifyOptionsT(*reinterpret_cast<DensifyOptionsT *>(u.value));
break;
}
case BuiltinOptions_SegmentSumOptions: {
value = new SegmentSumOptionsT(*reinterpret_cast<SegmentSumOptionsT *>(u.value));
break;
}
default:
break;
}
@@ -15227,6 +15334,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_SegmentSumOptions: {
auto ptr = reinterpret_cast<SegmentSumOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

View File

@@ -146,6 +146,7 @@ enum class OperatorType : uint8 {
// instead of being given as plain constant arrays. So we need to insert
// special nodes in the graph to shuffle axes.
kReorderAxes,
kSegmentSum,
kSelect,
kSelectV2,
kSparseToDense,

View File

@@ -197,6 +197,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kLess, 2}, "1.14.0"},
{{OperatorType::kLessEqual, 1}, "1.14.0"},
{{OperatorType::kLessEqual, 2}, "1.14.0"},
{{OperatorType::kSegmentSum, 1}, kPendingReleaseOpVersion},
{{OperatorType::kSelect, 1}, "1.14.0"},
{{OperatorType::kSelect, 2}, "1.14.0"},
{{OperatorType::kSelectV2, 1}, kPendingReleaseOpVersion},