Add new op schema for TFlite equivalent of tf.gather_nd.

PiperOrigin-RevId: 232728420
This commit is contained in:
A. Unique TensorFlower 2019-02-06 12:56:47 -08:00 committed by TensorFlower Gardener
parent ca62689feb
commit 0c4f5dfea4
5 changed files with 126 additions and 6 deletions

View File

@ -132,6 +132,7 @@ typedef enum {
kTfLiteBuiltinCeil = 104,
kTfLiteBuiltinReverseV2 = 105,
kTfLiteBuiltinAddN = 106,
kTfLiteBuiltinGatherNd = 107,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -728,6 +728,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
case BuiltinOperator_ADD_N:
case BuiltinOperator_GATHER_ND:
break;
}
return kTfLiteOk;

View File

@ -665,6 +665,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_CEIL:
case tflite::BuiltinOperator_REVERSE_V2:
case tflite::BuiltinOperator_ADD_N:
case tflite::BuiltinOperator_GATHER_ND:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

View File

@ -220,6 +220,7 @@ enum BuiltinOperator : byte {
CEIL = 104,
REVERSE_V2 = 105,
ADD_N = 106,
GATHER_ND = 107,
}
// Options for the builtin operators.
@ -306,6 +307,7 @@ union BuiltinOptions {
UniqueOptions,
ReverseV2Options,
AddNOptions,
GatherNdOptions,
}
enum Padding : byte { SAME, VALID }
@ -729,6 +731,9 @@ table ReverseV2Options {
table AddNOptions {
}
table GatherNdOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -277,6 +277,9 @@ struct ReverseV2OptionsT;
struct AddNOptions;
struct AddNOptionsT;
struct GatherNdOptions;
struct GatherNdOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -533,11 +536,12 @@ enum BuiltinOperator {
BuiltinOperator_CEIL = 104,
BuiltinOperator_REVERSE_V2 = 105,
BuiltinOperator_ADD_N = 106,
BuiltinOperator_GATHER_ND = 107,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_ADD_N
BuiltinOperator_MAX = BuiltinOperator_GATHER_ND
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[106] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[107] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -644,7 +648,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[106] {
BuiltinOperator_UNIQUE,
BuiltinOperator_CEIL,
BuiltinOperator_REVERSE_V2,
BuiltinOperator_ADD_N
BuiltinOperator_ADD_N,
BuiltinOperator_GATHER_ND
};
return values;
}
@ -758,6 +763,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
"CEIL",
"REVERSE_V2",
"ADD_N",
"GATHER_ND",
nullptr
};
return names;
@ -852,11 +858,12 @@ enum BuiltinOptions {
BuiltinOptions_UniqueOptions = 80,
BuiltinOptions_ReverseV2Options = 81,
BuiltinOptions_AddNOptions = 82,
BuiltinOptions_GatherNdOptions = 83,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_AddNOptions
BuiltinOptions_MAX = BuiltinOptions_GatherNdOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[83] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[84] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -940,7 +947,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[83] {
BuiltinOptions_SplitVOptions,
BuiltinOptions_UniqueOptions,
BuiltinOptions_ReverseV2Options,
BuiltinOptions_AddNOptions
BuiltinOptions_AddNOptions,
BuiltinOptions_GatherNdOptions
};
return values;
}
@ -1030,6 +1038,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"UniqueOptions",
"ReverseV2Options",
"AddNOptions",
"GatherNdOptions",
nullptr
};
return names;
@ -1372,6 +1381,10 @@ template<> struct BuiltinOptionsTraits<AddNOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions;
};
template<> struct BuiltinOptionsTraits<GatherNdOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2059,6 +2072,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_AddNOptions ?
reinterpret_cast<const AddNOptionsT *>(value) : nullptr;
}
GatherNdOptionsT *AsGatherNdOptions() {
return type == BuiltinOptions_GatherNdOptions ?
reinterpret_cast<GatherNdOptionsT *>(value) : nullptr;
}
const GatherNdOptionsT *AsGatherNdOptions() const {
return type == BuiltinOptions_GatherNdOptions ?
reinterpret_cast<const GatherNdOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7235,6 +7256,46 @@ inline flatbuffers::Offset<AddNOptions> CreateAddNOptions(
flatbuffers::Offset<AddNOptions> CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct GatherNdOptionsT : public flatbuffers::NativeTable {
typedef GatherNdOptions TableType;
GatherNdOptionsT() {
}
};
struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef GatherNdOptionsT NativeTableType;
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
verifier.EndTable();
}
GatherNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<GatherNdOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct GatherNdOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
explicit GatherNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
GatherNdOptionsBuilder &operator=(const GatherNdOptionsBuilder &);
flatbuffers::Offset<GatherNdOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<GatherNdOptions>(end);
return o;
}
};
inline flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(
flatbuffers::FlatBufferBuilder &_fbb) {
GatherNdOptionsBuilder builder_(_fbb);
return builder_.Finish();
}
flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -7614,6 +7675,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const AddNOptions *builtin_options_as_AddNOptions() const {
return builtin_options_type() == BuiltinOptions_AddNOptions ? static_cast<const AddNOptions *>(builtin_options()) : nullptr;
}
const GatherNdOptions *builtin_options_as_GatherNdOptions() const {
return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast<const GatherNdOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -7973,6 +8037,10 @@ template<> inline const AddNOptions *Operator::builtin_options_as<AddNOptions>()
return builtin_options_as_AddNOptions();
}
template<> inline const GatherNdOptions *Operator::builtin_options_as<GatherNdOptions>() const {
return builtin_options_as_GatherNdOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -10666,6 +10734,29 @@ inline flatbuffers::Offset<AddNOptions> CreateAddNOptions(flatbuffers::FlatBuffe
_fbb);
}
inline GatherNdOptionsT *GatherNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new GatherNdOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void GatherNdOptions::UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
}
inline flatbuffers::Offset<GatherNdOptions> GatherNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateGatherNdOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
return tflite::CreateGatherNdOptions(
_fbb);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -11252,6 +11343,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const AddNOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_GatherNdOptions: {
auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -11598,6 +11693,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const AddNOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_GatherNdOptions: {
auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -11932,6 +12031,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const AddNOptionsT *>(value);
return CreateAddNOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_GatherNdOptions: {
auto ptr = reinterpret_cast<const GatherNdOptionsT *>(value);
return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -12266,6 +12369,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new AddNOptionsT(*reinterpret_cast<AddNOptionsT *>(u.value));
break;
}
case BuiltinOptions_GatherNdOptions: {
value = new GatherNdOptionsT(*reinterpret_cast<GatherNdOptionsT *>(u.value));
break;
}
default:
break;
}
@ -12683,6 +12790,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_GatherNdOptions: {
auto ptr = reinterpret_cast<GatherNdOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;