From e07863b456abff13b9499335b7cd8f0379c06c32 Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Sun, 12 Jan 2020 22:00:06 -0800 Subject: [PATCH] Add segment_sum op to Tensorflow Lite PiperOrigin-RevId: 289377531 Change-Id: Ie8aa95ca9d6b32eb2c5eb8a11c96d6ed3b3464d9 --- tensorflow/lite/builtin_ops.h | 1 + .../lite/core/api/flatbuffer_conversions.cc | 1 + tensorflow/lite/kernels/BUILD | 13 ++ tensorflow/lite/kernels/builtin_op_kernels.h | 1 + .../internal/reference/reference_ops.h | 19 +++ tensorflow/lite/kernels/register.cc | 1 + tensorflow/lite/kernels/register_ref.cc | 2 + tensorflow/lite/kernels/segment_sum.cc | 112 +++++++++++++++ tensorflow/lite/kernels/segment_sum_test.cc | 112 +++++++++++++++ tensorflow/lite/schema/schema.fbs | 9 +- tensorflow/lite/schema/schema_generated.h | 128 ++++++++++++++++-- tensorflow/lite/toco/model.h | 1 + tensorflow/lite/toco/tflite/op_version.cc | 1 + 13 files changed, 391 insertions(+), 10 deletions(-) create mode 100644 tensorflow/lite/kernels/segment_sum.cc create mode 100644 tensorflow/lite/kernels/segment_sum_test.cc diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index ad5f6112baa..c4e2907ffa9 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -151,6 +151,7 @@ typedef enum { kTfLiteBuiltinScatterNd = 122, kTfLiteBuiltinSelectV2 = 123, kTfLiteBuiltinDensify = 124, + kTfLiteBuiltinSegmentSum = 125, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 7f30665cffe..90f06781d92 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -826,6 +826,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_NON_MAX_SUPPRESSION_V5: case BuiltinOperator_SCATTER_ND: case BuiltinOperator_DENSIFY: + case BuiltinOperator_SEGMENT_SUM: break; } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 2327534c159..fd7b5362790 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -481,6 +481,7 @@ cc_library( "reverse_sequence.cc", "round.cc", "scatter_nd.cc", + "segment_sum.cc", "select.cc", "shape.cc", "skip_gram.cc", @@ -2059,4 +2060,16 @@ cc_test( ], ) +cc_test( + name = "segment_sum_test", + srcs = ["segment_sum_test.cc"], + deps = [ + ":builtin_ops", + ":test_main", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/kernels/builtin_op_kernels.h b/tensorflow/lite/kernels/builtin_op_kernels.h index 67669f85d0e..e5f00ddd229 100644 --- a/tensorflow/lite/kernels/builtin_op_kernels.h +++ b/tensorflow/lite/kernels/builtin_op_kernels.h @@ -118,6 +118,7 @@ TfLiteRegistration* Register_RNN(); TfLiteRegistration* Register_ROUND(); TfLiteRegistration* Register_RSQRT(); TfLiteRegistration* Register_SCATTER_ND(); +TfLiteRegistration* Register_SEGMENT_SUM(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SELECT_V2(); TfLiteRegistration* Register_SHAPE(); diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index b3969d24381..3b581fab519 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -3033,6 +3033,25 @@ inline void HardSwish(const HardSwishParams& params, } } +template +inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& segment_ids_shape, + const int32_t* segment_ids_data, + const RuntimeShape& output_shape, T* output_data) { + const int segment_flat_size = + MatchingFlatSizeSkipDim(input_shape, 0, output_shape); + + memset(output_data, 0, sizeof(T) * output_shape.FlatSize()); + + for (int i = 0; i < input_shape.Dims(0); i++) { + int output_index = segment_ids_data[i]; + for (int j = 0; j < segment_flat_size; ++j) { + output_data[output_index * segment_flat_size + j] += + input_data[i * segment_flat_size + j]; + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index f8ffedbfc02..4435008b653 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -281,6 +281,7 @@ BuiltinOpResolver::BuiltinOpResolver() { Register_NON_MAX_SUPPRESSION_V5()); AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND()); AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY()); + AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM()); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index e40ba896e7a..2381e8f8c9d 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -133,6 +133,7 @@ TfLiteRegistration* Register_QUANTIZE(); TfLiteRegistration* Register_HARD_SWISH_REF(); TfLiteRegistration* Register_DEPTH_TO_SPACE_REF(); TfLiteRegistration* Register_SELECT_V2(); +TfLiteRegistration* Register_SEGMENT_SUM(); namespace { @@ -286,6 +287,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE()); AddBuiltin(BuiltinOperator_HARD_SWISH, Register_HARD_SWISH_REF()); AddBuiltin(BuiltinOperator_SELECT_V2, Register_SELECT_V2()); + AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/segment_sum.cc b/tensorflow/lite/kernels/segment_sum.cc new file mode 100644 index 00000000000..db8aa688ebe --- /dev/null +++ b/tensorflow/lite/kernels/segment_sum.cc @@ -0,0 +1,112 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace segment_sum { + +static const int kInputDataTensor = 0; +static const int kInputSegmentIdsTensor = 1; +static const int kOutputTensor = 0; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const TfLiteTensor* data, + const TfLiteTensor* segment_ids, + TfLiteTensor* output) { + int max_index = -1; + const int segment_id_size = segment_ids->dims->data[0]; + if (segment_id_size > 0) { + max_index = segment_ids->data.i32[segment_id_size - 1]; + } + const int data_rank = NumDimensions(data); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data)); + output_shape->data[0] = max_index + 1; + for (int i = 1; i < data_rank; ++i) { + output_shape->data[i] = data->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* data = GetInput(context, node, kInputDataTensor); + const TfLiteTensor* segment_ids = + GetInput(context, node, kInputSegmentIdsTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE(context, + data->type == kTfLiteInt32 || data->type == kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32); + + if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + + return ResizeOutputTensor(context, data, segment_ids, output); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* data = GetInput(context, node, kInputDataTensor); + const TfLiteTensor* segment_ids = + GetInput(context, node, kInputSegmentIdsTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, data, segment_ids, output)); + } + +#define TF_LITE_SEGMENT_SUM(dtype) \ + reference_ops::SegmentSum( \ + GetTensorShape(data), GetTensorData(data), \ + GetTensorShape(segment_ids), GetTensorData(segment_ids), \ + GetTensorShape(output), GetTensorData(output)); + switch (data->type) { + case kTfLiteInt32: + TF_LITE_SEGMENT_SUM(int32_t); + break; + case kTfLiteFloat32: + TF_LITE_SEGMENT_SUM(float); + break; + default: + context->ReportError(context, + "Currently SegmentSum doesn't support type: %s", + TfLiteTypeGetName(data->type)); + return kTfLiteError; + } +#undef TF_LITE_SEGMENT_SUM + return kTfLiteOk; +} + +} // namespace segment_sum + +TfLiteRegistration* Register_SEGMENT_SUM() { + static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare, + segment_sum::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/segment_sum_test.cc b/tensorflow/lite/kernels/segment_sum_test.cc new file mode 100644 index 00000000000..d083feb44aa --- /dev/null +++ b/tensorflow/lite/kernels/segment_sum_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SegmentSumOpModel : public SingleOpModel { + public: + SegmentSumOpModel(const TensorData& data, const TensorData& segment_ids) { + data_id_ = AddInput(data); + segment_ids_id_ = AddInput(segment_ids); + output_id_ = AddOutput(data.type); + SetBuiltinOp(BuiltinOperator_SEGMENT_SUM, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(data_id_), GetShape(segment_ids_id_)}); + } + + int data() const { return data_id_; } + int segment_ids() const { return segment_ids_id_; } + std::vector GetOutput() { return ExtractVector(output_id_); } + std::vector GetOutputShape() { return GetTensorShape(output_id_); } + + protected: + int data_id_; + int segment_ids_id_; + int output_id_; +}; + +TEST(SegmentSumOpModelTest, Int32Test_Simple) { + SegmentSumOpModel model({TensorType_INT32, {3, 4}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), + {1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(SegmentSumOpModelTest, Int32Test_OneDimension) { + SegmentSumOpModel model({TensorType_INT32, {3}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(SegmentSumOpModelTest, Int32Test_ThreeDimensions) { + SegmentSumOpModel model({TensorType_INT32, {3, 2, 1}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 6, 5, 6})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} + +TEST(SegmentSumOpModelTest, Float32Test_Simple) { + SegmentSumOpModel model({TensorType_FLOAT32, {3, 4}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), + {1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5.0f, 5.0f, 5.0f, 5.0f, 5.0f, + 6.0f, 7.0f, 8.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); +} + +TEST(SegmentSumOpModelTest, Float32Test_OneDimension) { + SegmentSumOpModel model({TensorType_FLOAT32, {3}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({3.0f, 3.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) { + SegmentSumOpModel model({TensorType_FLOAT32, {3, 2, 1}}, + {TensorType_INT32, {3}}); + model.PopulateTensor(model.data(), {1, 2, 3, 4, 5, 6}); + model.PopulateTensor(model.segment_ids(), {0, 0, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({4.0f, 6.0f, 5.0f, 6.0f})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index ea310734525..9793a02eb9f 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -317,7 +317,8 @@ enum BuiltinOperator : byte { NON_MAX_SUPPRESSION_V5 = 121, SCATTER_ND = 122, SELECT_V2 = 123, - DENSIFY = 124 + DENSIFY = 124, + SEGMENT_SUM = 125 } @@ -421,7 +422,8 @@ union BuiltinOptions { NonMaxSuppressionV5Options, ScatterNdOptions, SelectV2Options, - DensifyOptions + DensifyOptions, + SegmentSumOptions } enum Padding : byte { SAME, VALID } @@ -911,6 +913,9 @@ table SelectV2Options { table DensifyOptions { } +table SegmentSumOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 51ae63a5441..fc1708f8703 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -334,6 +334,9 @@ struct SelectV2OptionsT; struct DensifyOptions; struct DensifyOptionsT; +struct SegmentSumOptions; +struct SegmentSumOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -645,11 +648,12 @@ enum BuiltinOperator { BuiltinOperator_SCATTER_ND = 122, BuiltinOperator_SELECT_V2 = 123, BuiltinOperator_DENSIFY = 124, + BuiltinOperator_SEGMENT_SUM = 125, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_DENSIFY + BuiltinOperator_MAX = BuiltinOperator_SEGMENT_SUM }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[126] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -775,7 +779,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[125] { BuiltinOperator_NON_MAX_SUPPRESSION_V5, BuiltinOperator_SCATTER_ND, BuiltinOperator_SELECT_V2, - BuiltinOperator_DENSIFY + BuiltinOperator_DENSIFY, + BuiltinOperator_SEGMENT_SUM }; return values; } @@ -907,13 +912,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "SCATTER_ND", "SELECT_V2", "DENSIFY", + "SEGMENT_SUM", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (e < BuiltinOperator_ADD || e > BuiltinOperator_DENSIFY) return ""; + if (e < BuiltinOperator_ADD || e > BuiltinOperator_SEGMENT_SUM) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -1019,11 +1025,12 @@ enum BuiltinOptions { BuiltinOptions_ScatterNdOptions = 97, BuiltinOptions_SelectV2Options = 98, BuiltinOptions_DensifyOptions = 99, + BuiltinOptions_SegmentSumOptions = 100, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_DensifyOptions + BuiltinOptions_MAX = BuiltinOptions_SegmentSumOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[101] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1124,7 +1131,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[100] { BuiltinOptions_NonMaxSuppressionV5Options, BuiltinOptions_ScatterNdOptions, BuiltinOptions_SelectV2Options, - BuiltinOptions_DensifyOptions + BuiltinOptions_DensifyOptions, + BuiltinOptions_SegmentSumOptions }; return values; } @@ -1231,13 +1239,14 @@ inline const char * const *EnumNamesBuiltinOptions() { "ScatterNdOptions", "SelectV2Options", "DensifyOptions", + "SegmentSumOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (e < BuiltinOptions_NONE || e > BuiltinOptions_DensifyOptions) return ""; + if (e < BuiltinOptions_NONE || e > BuiltinOptions_SegmentSumOptions) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1642,6 +1651,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2466,6 +2479,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_DensifyOptions ? reinterpret_cast(value) : nullptr; } + SegmentSumOptionsT *AsSegmentSumOptions() { + return type == BuiltinOptions_SegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } + const SegmentSumOptionsT *AsSegmentSumOptions() const { + return type == BuiltinOptions_SegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -8659,6 +8680,46 @@ inline flatbuffers::Offset CreateDensifyOptions( flatbuffers::Offset CreateDensifyOptions(flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SegmentSumOptionsT : public flatbuffers::NativeTable { + typedef SegmentSumOptions TableType; + SegmentSumOptionsT() { + } +}; + +struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SegmentSumOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SegmentSumOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SegmentSumOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit SegmentSumOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SegmentSumOptionsBuilder &operator=(const SegmentSumOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSegmentSumOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + SegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -9092,6 +9153,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const DensifyOptions *builtin_options_as_DensifyOptions() const { return builtin_options_type() == BuiltinOptions_DensifyOptions ? static_cast(builtin_options()) : nullptr; } + const SegmentSumOptions *builtin_options_as_SegmentSumOptions() const { + return builtin_options_type() == BuiltinOptions_SegmentSumOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -9524,6 +9588,10 @@ template<> inline const DensifyOptions *Operator::builtin_options_as inline const SegmentSumOptions *Operator::builtin_options_as() const { + return builtin_options_as_SegmentSumOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -12818,6 +12886,29 @@ inline flatbuffers::Offset CreateDensifyOptions(flatbuffers::Fla _fbb); } +inline SegmentSumOptionsT *SegmentSumOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SegmentSumOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SegmentSumOptions::UnPackTo(SegmentSumOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset SegmentSumOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSegmentSumOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSegmentSumOptions(flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SegmentSumOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSegmentSumOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -13507,6 +13598,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } } @@ -13921,6 +14016,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -14323,6 +14422,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateDensifyOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(value); + return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -14725,6 +14828,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new DensifyOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SegmentSumOptions: { + value = new SegmentSumOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -15227,6 +15334,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h index d225915bf74..7b07b1b8d43 100644 --- a/tensorflow/lite/toco/model.h +++ b/tensorflow/lite/toco/model.h @@ -146,6 +146,7 @@ enum class OperatorType : uint8 { // instead of being given as plain constant arrays. So we need to insert // special nodes in the graph to shuffle axes. kReorderAxes, + kSegmentSum, kSelect, kSelectV2, kSparseToDense, diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index 1a01d501152..2e27c1d8a0f 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -197,6 +197,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kLess, 2}, "1.14.0"}, {{OperatorType::kLessEqual, 1}, "1.14.0"}, {{OperatorType::kLessEqual, 2}, "1.14.0"}, + {{OperatorType::kSegmentSum, 1}, kPendingReleaseOpVersion}, {{OperatorType::kSelect, 1}, "1.14.0"}, {{OperatorType::kSelect, 2}, "1.14.0"}, {{OperatorType::kSelectV2, 1}, kPendingReleaseOpVersion},