From 0c4f5dfea4ceb3d7c0b46fc04828420a344f7598 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Feb 2019 12:56:47 -0800 Subject: [PATCH] Add new op schema for TFlite equivalent of tf.gather_nd. PiperOrigin-RevId: 232728420 --- tensorflow/lite/builtin_ops.h | 1 + .../lite/core/api/flatbuffer_conversions.cc | 1 + tensorflow/lite/nnapi_delegate.cc | 1 + tensorflow/lite/schema/schema.fbs | 5 + tensorflow/lite/schema/schema_generated.h | 124 +++++++++++++++++- 5 files changed, 126 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index fc871d35924..1915565f4be 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -132,6 +132,7 @@ typedef enum { kTfLiteBuiltinCeil = 104, kTfLiteBuiltinReverseV2 = 105, kTfLiteBuiltinAddN = 106, + kTfLiteBuiltinGatherNd = 107, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 175caa1afa6..72667d4260e 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -728,6 +728,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_SQUARED_DIFFERENCE: case BuiltinOperator_REVERSE_V2: case BuiltinOperator_ADD_N: + case BuiltinOperator_GATHER_ND: break; } return kTfLiteOk; diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc index 7c674688797..f69baf1ac87 100644 --- a/tensorflow/lite/nnapi_delegate.cc +++ b/tensorflow/lite/nnapi_delegate.cc @@ -665,6 +665,7 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_CEIL: case tflite::BuiltinOperator_REVERSE_V2: case tflite::BuiltinOperator_ADD_N: + case tflite::BuiltinOperator_GATHER_ND: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 098b7c71939..b9c6e988f9d 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -220,6 +220,7 @@ enum BuiltinOperator : byte { CEIL = 104, REVERSE_V2 = 105, ADD_N = 106, + GATHER_ND = 107, } // Options for the builtin operators. @@ -306,6 +307,7 @@ union BuiltinOptions { UniqueOptions, ReverseV2Options, AddNOptions, + GatherNdOptions, } enum Padding : byte { SAME, VALID } @@ -729,6 +731,9 @@ table ReverseV2Options { table AddNOptions { } +table GatherNdOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index f0a6b00f968..31177895290 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -277,6 +277,9 @@ struct ReverseV2OptionsT; struct AddNOptions; struct AddNOptionsT; +struct GatherNdOptions; +struct GatherNdOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -533,11 +536,12 @@ enum BuiltinOperator { BuiltinOperator_CEIL = 104, BuiltinOperator_REVERSE_V2 = 105, BuiltinOperator_ADD_N = 106, + BuiltinOperator_GATHER_ND = 107, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_ADD_N + BuiltinOperator_MAX = BuiltinOperator_GATHER_ND }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[106] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[107] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -644,7 +648,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[106] { BuiltinOperator_UNIQUE, BuiltinOperator_CEIL, BuiltinOperator_REVERSE_V2, - BuiltinOperator_ADD_N + BuiltinOperator_ADD_N, + BuiltinOperator_GATHER_ND }; return values; } @@ -758,6 +763,7 @@ inline const char * const *EnumNamesBuiltinOperator() { "CEIL", "REVERSE_V2", "ADD_N", + "GATHER_ND", nullptr }; return names; @@ -852,11 +858,12 @@ enum BuiltinOptions { BuiltinOptions_UniqueOptions = 80, BuiltinOptions_ReverseV2Options = 81, BuiltinOptions_AddNOptions = 82, + BuiltinOptions_GatherNdOptions = 83, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_AddNOptions + BuiltinOptions_MAX = BuiltinOptions_GatherNdOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[83] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[84] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -940,7 +947,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[83] { BuiltinOptions_SplitVOptions, BuiltinOptions_UniqueOptions, BuiltinOptions_ReverseV2Options, - BuiltinOptions_AddNOptions + BuiltinOptions_AddNOptions, + BuiltinOptions_GatherNdOptions }; return values; } @@ -1030,6 +1038,7 @@ inline const char * const *EnumNamesBuiltinOptions() { "UniqueOptions", "ReverseV2Options", "AddNOptions", + "GatherNdOptions", nullptr }; return names; @@ -1372,6 +1381,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2059,6 +2072,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_AddNOptions ? reinterpret_cast(value) : nullptr; } + GatherNdOptionsT *AsGatherNdOptions() { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } + const GatherNdOptionsT *AsGatherNdOptions() const { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -7235,6 +7256,46 @@ inline flatbuffers::Offset CreateAddNOptions( flatbuffers::Offset CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct GatherNdOptionsT : public flatbuffers::NativeTable { + typedef GatherNdOptions TableType; + GatherNdOptionsT() { + } +}; + +struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GatherNdOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GatherNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GatherNdOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit GatherNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GatherNdOptionsBuilder &operator=(const GatherNdOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateGatherNdOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GatherNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -7614,6 +7675,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const AddNOptions *builtin_options_as_AddNOptions() const { return builtin_options_type() == BuiltinOptions_AddNOptions ? static_cast(builtin_options()) : nullptr; } + const GatherNdOptions *builtin_options_as_GatherNdOptions() const { + return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -7973,6 +8037,10 @@ template<> inline const AddNOptions *Operator::builtin_options_as() return builtin_options_as_AddNOptions(); } +template<> inline const GatherNdOptions *Operator::builtin_options_as() const { + return builtin_options_as_GatherNdOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -10666,6 +10734,29 @@ inline flatbuffers::Offset CreateAddNOptions(flatbuffers::FlatBuffe _fbb); } +inline GatherNdOptionsT *GatherNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GatherNdOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GatherNdOptions::UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset GatherNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGatherNdOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGatherNdOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -11252,6 +11343,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -11598,6 +11693,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -11932,6 +12031,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateAddNOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -12266,6 +12369,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new AddNOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_GatherNdOptions: { + value = new GatherNdOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -12683,6 +12790,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr;