Graduate TFLite control flow ops from experimental to builtin

PiperOrigin-RevId: 259150573
This commit is contained in:
Yu-Cheng Ling 2019-07-20 15:45:31 -07:00 committed by TensorFlower Gardener
parent 488b385a7d
commit eedf79ed37
13 changed files with 471 additions and 129 deletions

View File

@ -368,9 +368,14 @@ class Translator {
const std::string& name,
unsigned buffer_idx);
CustomOptionsOffset CreateIfOpCustomOptions(mlir::TF::IfOp op);
CustomOptionsOffset CreateWhileOpCustomOptions(mlir::TF::WhileOp op);
// TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove
// these 2 functions here.
BufferOffset<tflite::Operator> BuildIfOperator(
mlir::TF::IfOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildWhileOperator(
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
@ -544,31 +549,36 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
builder_.CreateString(name), q_params, /*is_variable=*/false);
}
CustomOptionsOffset Translator::CreateIfOpCustomOptions(mlir::TF::IfOp op) {
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
mlir::TF::IfOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
flex_builder->Map([&]() {
flex_builder->Int("then_subgraph_index", then_subgraph_index);
flex_builder->Int("else_subgraph_index", else_subgraph_index);
});
flex_builder->Finish();
return builder_.CreateVector(flex_builder->GetBuffer());
auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
else_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_IfOptions,
builtin_options);
}
CustomOptionsOffset Translator::CreateWhileOpCustomOptions(
mlir::TF::WhileOp op) {
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
int cond_subgraph_index = subgraph_index_map_.at(op.cond().str());
int body_subgraph_index = subgraph_index_map_.at(op.body().str());
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
flex_builder->Map([&]() {
flex_builder->Int("cond_subgraph_index", cond_subgraph_index);
flex_builder->Int("body_subgraph_index", body_subgraph_index);
});
flex_builder->Finish();
return builder_.CreateVector(flex_builder->GetBuffer());
auto builtin_options = tflite::CreateWhileOptions(
builder_, cond_subgraph_index, body_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_WhileOptions,
builtin_options);
}
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
@ -712,63 +722,60 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (dialect == tf_dialect_) {
std::string op_name;
if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
return BuildIfOperator(ifOp, operands, results);
} else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
return BuildWhileOperator(whileOp, operands, results);
}
CustomOptionsOffset custom_options;
if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
op_name = "Experimental_If";
custom_options = CreateIfOpCustomOptions(ifOp);
} else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
op_name = "Experimental_While";
custom_options = CreateWhileOpCustomOptions(whileOp);
} else {
// Ops in TF dialect can either be custom ops or flex ops.
// The reason we go directly from TensorFlow dialect MLIR to tensorflow
// node instead of going to TF table gen'd ops via generated code is that
// we do not want to restrict custom and flex op conversion support to
// only those TF ops that are currently registered in MLIR. The current
// model is of an open op system.
//
// The following algorithm is followed:
// if flex is enabled and the op is whitelisted as flex
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst);
if (!node_def) {
// Ops in TF dialect can either be custom ops or flex ops.
// The reason we go directly from TensorFlow dialect MLIR to tensorflow
// node instead of going to TF table gen'd ops via generated code is that
// we do not want to restrict custom and flex op conversion support to
// only those TF ops that are currently registered in MLIR. The current
// model is of an open op system.
//
// The following algorithm is followed:
// if flex is enabled and the op is whitelisted as flex
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst);
if (!node_def) {
return llvm::None;
}
// Flex op case
// Eventually, the whitelist will go away and we will rely on some TF op
// trait (e.g. No side effect) to determine if it is a supported "Flex"
// op or not.
if (enabled_op_types_.contains(OpType::kSelectTf) &&
IsWhitelistedFlexOp(node_def->op())) {
// Construct ops as flex op encoding TensorFlow node definition
// as custom options.
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
// TF op name.
op_name = std::string(kFlexOpNamePrefix) + node_def->op();
if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
custom_options = *options;
} else {
return llvm::None;
}
// Flex op case
// Eventually, the whitelist will go away and we will rely on some TF op
// trait (e.g. No side effect) to determine if it is a supported "Flex"
// op or not.
if (enabled_op_types_.contains(OpType::kSelectTf) &&
IsWhitelistedFlexOp(node_def->op())) {
// Construct ops as flex op encoding TensorFlow node definition
// as custom options.
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
// TF op name.
op_name = std::string(kFlexOpNamePrefix) + node_def->op();
if (auto options =
CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
custom_options = *options;
} else {
return llvm::None;
}
} else if (enabled_op_types_.contains(OpType::kCustomOp)) {
// Generic case of custom ops - write using flex buffers since that
// is the only custom options supported by TFLite today.
op_name = node_def->op();
if (auto options =
CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
custom_options = *options;
} else {
return llvm::None;
}
} else if (enabled_op_types_.contains(OpType::kCustomOp)) {
// Generic case of custom ops - write using flex buffers since that
// is the only custom options supported by TFLite today.
op_name = node_def->op();
if (auto options =
CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
custom_options = *options;
} else {
return inst->emitOpError("is neither a custom op nor a flex op"),
llvm::None;
return llvm::None;
}
} else {
return inst->emitOpError("is neither a custom op nor a flex op"),
llvm::None;
}
uint32_t opcode_index =

View File

@ -1,12 +1,12 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LESS
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Experimental_If"
// CHECK-NEXT: builtin_code: IF
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
@ -52,8 +52,12 @@
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: builtin_options_type: IfOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: then_subgraph_index: 1,
// CHECK-NEXT: else_subgraph_index: 2
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
@ -88,7 +92,7 @@
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "cond_true"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
@ -123,7 +127,7 @@
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "cond_false"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",

View File

@ -3,8 +3,7 @@
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "Experimental_While"
// CHECK-NEXT: builtin_code: WHILE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER
// CHECK-NEXT: }, {
@ -49,8 +48,12 @@
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: builtin_options_type: WhileOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: cond_subgraph_index: 1,
// CHECK-NEXT: body_subgraph_index: 2
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
@ -91,7 +94,7 @@
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ]
// CHECK-NEXT: } ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "cond"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
@ -151,7 +154,7 @@
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "body"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",

View File

@ -143,6 +143,8 @@ typedef enum {
kTfLiteBuiltinMatrixSetDiag = 115,
kTfLiteBuiltinRound = 116,
kTfLiteBuiltinHardSwish = 117,
kTfLiteBuiltinIf = 118,
kTfLiteBuiltinWhile = 119,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -391,6 +391,16 @@ typedef struct {
EmptyStructPlaceholder placeholder;
} TfLiteMatrixSetDiagParams;
typedef struct {
int then_subgraph_index;
int else_subgraph_index;
} TfLiteIfParams;
typedef struct {
int cond_subgraph_index;
int body_subgraph_index;
} TfLiteWhileParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

View File

@ -721,6 +721,24 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params.release());
break;
}
case BuiltinOperator_IF: {
TfLiteIfParams* params = allocator->AllocatePOD<TfLiteIfParams>();
if (const auto* if_params = op->builtin_options_as_IfOptions()) {
params->then_subgraph_index = if_params->then_subgraph_index();
params->else_subgraph_index = if_params->else_subgraph_index();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_WHILE: {
TfLiteWhileParams* params = allocator->AllocatePOD<TfLiteWhileParams>();
if (const auto* while_params = op->builtin_options_as_WhileOptions()) {
params->cond_subgraph_index = while_params->cond_subgraph_index();
params->body_subgraph_index = while_params->body_subgraph_index();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
}
// Below are the ops with no builtin_data structure.
case BuiltinOperator_ABS:
case BuiltinOperator_BATCH_TO_SPACE_ND:

View File

@ -40,6 +40,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteFakeQuantParams",
"TfLiteFullyConnectedParams",
"TfLiteGatherParams",
"TfLiteIfParams",
"TfLiteL2NormParams",
"TfLiteLeakyReluParams",
"TfLiteLocalResponseNormParams",
@ -76,6 +77,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
"TfLiteUniqueParams",
"TfLiteUnpackParams",
"TfLiteReverseSequenceParams",
"TfLiteWhileParams",
nullptr};
} // namespace

View File

@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include <cstring>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
@ -30,10 +32,9 @@ struct OpData {
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData;
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
op_data->then_subgraph_index = m["then_subgraph_index"].AsInt32();
op_data->else_subgraph_index = m["else_subgraph_index"].AsInt32();
const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer);
op_data->then_subgraph_index = params->then_subgraph_index;
op_data->else_subgraph_index = params->else_subgraph_index;
return op_data;
}

View File

@ -381,6 +381,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
// WARNING: Control flow ops are experimental and subject to change.
AddBuiltin(BuiltinOperator_IF, tflite::ops::custom::Register_IF());
AddBuiltin(BuiltinOperator_WHILE, tflite::ops::custom::Register_WHILE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
@ -388,10 +392,6 @@ BuiltinOpResolver::BuiltinOpResolver() {
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
// WARNING: Control flow ops are experimental and subject to change.
AddCustom("Experimental_If", tflite::ops::custom::Register_IF());
AddCustom("Experimental_While", tflite::ops::custom::Register_WHILE());
}
} // namespace builtin

View File

@ -170,18 +170,14 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteInt32);
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.Int("then_subgraph_index", 1);
fbb.Int("else_subgraph_index", 2);
});
fbb.Finish();
const auto& buffer = fbb.GetBuffer();
TfLiteIfParams* params =
reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
params->then_subgraph_index = 1;
params->else_subgraph_index = 2;
int node_index;
subgraph->AddNodeWithParameters(
{kCondInput, kInput1, kInput2}, {kOutput}, {},
reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr,
{kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
::tflite::ops::custom::Register_IF(), &node_index);
}
@ -333,19 +329,15 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
SetupTensor(subgraph, kOutput1, kTfLiteInt32);
SetupTensor(subgraph, kOutput2, kTfLiteInt32);
flexbuffers::Builder fbb;
fbb.Map([&]() {
fbb.Int("cond_subgraph_index", 1);
fbb.Int("body_subgraph_index", 2);
});
fbb.Finish();
const auto& buffer = fbb.GetBuffer();
TfLiteWhileParams* params =
reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
params->cond_subgraph_index = 1;
params->body_subgraph_index = 2;
int node_index;
subgraph->AddNodeWithParameters(
{0, 1}, {2, 3}, {}, reinterpret_cast<const char*>(buffer.data()),
buffer.size(), nullptr, ::tflite::ops::custom::Register_WHILE(),
&node_index);
subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
::tflite::ops::custom::Register_WHILE(),
&node_index);
}
void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph,

View File

@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include <cstring>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
@ -107,10 +109,9 @@ struct OpData {
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData;
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
const auto* params = reinterpret_cast<const TfLiteWhileParams*>(buffer);
op_data->cond_subgraph_index = params->cond_subgraph_index;
op_data->body_subgraph_index = params->body_subgraph_index;
op_data->cond_has_dynamic_output_tensors = false;
op_data->body_has_dynamic_output_tensors = false;
return op_data;

View File

@ -231,6 +231,8 @@ enum BuiltinOperator : byte {
MATRIX_SET_DIAG = 115,
ROUND = 116,
HARD_SWISH = 117,
IF = 118,
WHILE = 119,
}
// Options for the builtin operators.
@ -325,7 +327,9 @@ union BuiltinOptions {
MatrixDiagOptions,
QuantizeOptions,
MatrixSetDiagOptions,
HardSwishOptions
HardSwishOptions,
IfOptions,
WhileOptions
}
enum Padding : byte { SAME, VALID }
@ -783,6 +787,16 @@ table QuantizeOptions {
table MatrixSetDiagOptions {
}
table IfOptions {
then_subgraph_index:int;
else_subgraph_index:int;
}
table WhileOptions {
cond_subgraph_index:int;
body_subgraph_index:int;
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

View File

@ -304,6 +304,12 @@ struct QuantizeOptionsT;
struct MatrixSetDiagOptions;
struct MatrixSetDiagOptionsT;
struct IfOptions;
struct IfOptionsT;
struct WhileOptions;
struct WhileOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@ -577,11 +583,13 @@ enum BuiltinOperator {
BuiltinOperator_MATRIX_SET_DIAG = 115,
BuiltinOperator_ROUND = 116,
BuiltinOperator_HARD_SWISH = 117,
BuiltinOperator_IF = 118,
BuiltinOperator_WHILE = 119,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_HARD_SWISH
BuiltinOperator_MAX = BuiltinOperator_WHILE
};
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] {
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[119] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@ -699,7 +707,9 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] {
BuiltinOperator_QUANTIZE,
BuiltinOperator_MATRIX_SET_DIAG,
BuiltinOperator_ROUND,
BuiltinOperator_HARD_SWISH
BuiltinOperator_HARD_SWISH,
BuiltinOperator_IF,
BuiltinOperator_WHILE
};
return values;
}
@ -824,13 +834,15 @@ inline const char * const *EnumNamesBuiltinOperator() {
"MATRIX_SET_DIAG",
"ROUND",
"HARD_SWISH",
"IF",
"WHILE",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
if (e < BuiltinOperator_ADD || e > BuiltinOperator_HARD_SWISH) return "";
if (e < BuiltinOperator_ADD || e > BuiltinOperator_WHILE) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOperator()[index];
}
@ -928,11 +940,13 @@ enum BuiltinOptions {
BuiltinOptions_QuantizeOptions = 89,
BuiltinOptions_MatrixSetDiagOptions = 90,
BuiltinOptions_HardSwishOptions = 91,
BuiltinOptions_IfOptions = 92,
BuiltinOptions_WhileOptions = 93,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_HardSwishOptions
BuiltinOptions_MAX = BuiltinOptions_WhileOptions
};
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] {
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[94] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@ -1025,7 +1039,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] {
BuiltinOptions_MatrixDiagOptions,
BuiltinOptions_QuantizeOptions,
BuiltinOptions_MatrixSetDiagOptions,
BuiltinOptions_HardSwishOptions
BuiltinOptions_HardSwishOptions,
BuiltinOptions_IfOptions,
BuiltinOptions_WhileOptions
};
return values;
}
@ -1124,13 +1140,15 @@ inline const char * const *EnumNamesBuiltinOptions() {
"QuantizeOptions",
"MatrixSetDiagOptions",
"HardSwishOptions",
"IfOptions",
"WhileOptions",
nullptr
};
return names;
}
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
if (e < BuiltinOptions_NONE || e > BuiltinOptions_HardSwishOptions) return "";
if (e < BuiltinOptions_NONE || e > BuiltinOptions_WhileOptions) return "";
const size_t index = static_cast<size_t>(e);
return EnumNamesBuiltinOptions()[index];
}
@ -1503,6 +1521,14 @@ template<> struct BuiltinOptionsTraits<HardSwishOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions;
};
template<> struct BuiltinOptionsTraits<IfOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_IfOptions;
};
template<> struct BuiltinOptionsTraits<WhileOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@ -2263,6 +2289,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_HardSwishOptions ?
reinterpret_cast<const HardSwishOptionsT *>(value) : nullptr;
}
IfOptionsT *AsIfOptions() {
return type == BuiltinOptions_IfOptions ?
reinterpret_cast<IfOptionsT *>(value) : nullptr;
}
const IfOptionsT *AsIfOptions() const {
return type == BuiltinOptions_IfOptions ?
reinterpret_cast<const IfOptionsT *>(value) : nullptr;
}
WhileOptionsT *AsWhileOptions() {
return type == BuiltinOptions_WhileOptions ?
reinterpret_cast<WhileOptionsT *>(value) : nullptr;
}
const WhileOptionsT *AsWhileOptions() const {
return type == BuiltinOptions_WhileOptions ?
reinterpret_cast<const WhileOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@ -7856,6 +7898,138 @@ inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(
flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct IfOptionsT : public flatbuffers::NativeTable {
typedef IfOptions TableType;
int32_t then_subgraph_index;
int32_t else_subgraph_index;
IfOptionsT()
: then_subgraph_index(0),
else_subgraph_index(0) {
}
};
struct IfOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef IfOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_THEN_SUBGRAPH_INDEX = 4,
VT_ELSE_SUBGRAPH_INDEX = 6
};
int32_t then_subgraph_index() const {
return GetField<int32_t>(VT_THEN_SUBGRAPH_INDEX, 0);
}
int32_t else_subgraph_index() const {
return GetField<int32_t>(VT_ELSE_SUBGRAPH_INDEX, 0);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_THEN_SUBGRAPH_INDEX) &&
VerifyField<int32_t>(verifier, VT_ELSE_SUBGRAPH_INDEX) &&
verifier.EndTable();
}
IfOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<IfOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct IfOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_then_subgraph_index(int32_t then_subgraph_index) {
fbb_.AddElement<int32_t>(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0);
}
void add_else_subgraph_index(int32_t else_subgraph_index) {
fbb_.AddElement<int32_t>(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0);
}
explicit IfOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
IfOptionsBuilder &operator=(const IfOptionsBuilder &);
flatbuffers::Offset<IfOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<IfOptions>(end);
return o;
}
};
inline flatbuffers::Offset<IfOptions> CreateIfOptions(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t then_subgraph_index = 0,
int32_t else_subgraph_index = 0) {
IfOptionsBuilder builder_(_fbb);
builder_.add_else_subgraph_index(else_subgraph_index);
builder_.add_then_subgraph_index(then_subgraph_index);
return builder_.Finish();
}
flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct WhileOptionsT : public flatbuffers::NativeTable {
typedef WhileOptions TableType;
int32_t cond_subgraph_index;
int32_t body_subgraph_index;
WhileOptionsT()
: cond_subgraph_index(0),
body_subgraph_index(0) {
}
};
struct WhileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef WhileOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_COND_SUBGRAPH_INDEX = 4,
VT_BODY_SUBGRAPH_INDEX = 6
};
int32_t cond_subgraph_index() const {
return GetField<int32_t>(VT_COND_SUBGRAPH_INDEX, 0);
}
int32_t body_subgraph_index() const {
return GetField<int32_t>(VT_BODY_SUBGRAPH_INDEX, 0);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_COND_SUBGRAPH_INDEX) &&
VerifyField<int32_t>(verifier, VT_BODY_SUBGRAPH_INDEX) &&
verifier.EndTable();
}
WhileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<WhileOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct WhileOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_cond_subgraph_index(int32_t cond_subgraph_index) {
fbb_.AddElement<int32_t>(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0);
}
void add_body_subgraph_index(int32_t body_subgraph_index) {
fbb_.AddElement<int32_t>(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0);
}
explicit WhileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
WhileOptionsBuilder &operator=(const WhileOptionsBuilder &);
flatbuffers::Offset<WhileOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<WhileOptions>(end);
return o;
}
};
inline flatbuffers::Offset<WhileOptions> CreateWhileOptions(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t cond_subgraph_index = 0,
int32_t body_subgraph_index = 0) {
WhileOptionsBuilder builder_(_fbb);
builder_.add_body_subgraph_index(body_subgraph_index);
builder_.add_cond_subgraph_index(cond_subgraph_index);
return builder_.Finish();
}
flatbuffers::Offset<WhileOptions> CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@ -8265,6 +8439,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const HardSwishOptions *builtin_options_as_HardSwishOptions() const {
return builtin_options_type() == BuiltinOptions_HardSwishOptions ? static_cast<const HardSwishOptions *>(builtin_options()) : nullptr;
}
const IfOptions *builtin_options_as_IfOptions() const {
return builtin_options_type() == BuiltinOptions_IfOptions ? static_cast<const IfOptions *>(builtin_options()) : nullptr;
}
const WhileOptions *builtin_options_as_WhileOptions() const {
return builtin_options_type() == BuiltinOptions_WhileOptions ? static_cast<const WhileOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@ -8665,6 +8845,14 @@ template<> inline const HardSwishOptions *Operator::builtin_options_as<HardSwish
return builtin_options_as_HardSwishOptions();
}
template<> inline const IfOptions *Operator::builtin_options_as<IfOptions>() const {
return builtin_options_as_IfOptions();
}
template<> inline const WhileOptions *Operator::builtin_options_as<WhileOptions>() const {
return builtin_options_as_WhileOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@ -11690,6 +11878,64 @@ inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flat
_fbb);
}
inline IfOptionsT *IfOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new IfOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void IfOptions::UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = then_subgraph_index(); _o->then_subgraph_index = _e; };
{ auto _e = else_subgraph_index(); _o->else_subgraph_index = _e; };
}
inline flatbuffers::Offset<IfOptions> IfOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateIfOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const IfOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _then_subgraph_index = _o->then_subgraph_index;
auto _else_subgraph_index = _o->else_subgraph_index;
return tflite::CreateIfOptions(
_fbb,
_then_subgraph_index,
_else_subgraph_index);
}
inline WhileOptionsT *WhileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new WhileOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void WhileOptions::UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = cond_subgraph_index(); _o->cond_subgraph_index = _e; };
{ auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; };
}
inline flatbuffers::Offset<WhileOptions> WhileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateWhileOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<WhileOptions> CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const WhileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _cond_subgraph_index = _o->cond_subgraph_index;
auto _body_subgraph_index = _o->body_subgraph_index;
return tflite::CreateWhileOptions(
_fbb,
_cond_subgraph_index,
_body_subgraph_index);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@ -12347,6 +12593,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_IfOptions: {
auto ptr = reinterpret_cast<const IfOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_WhileOptions: {
auto ptr = reinterpret_cast<const WhileOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@ -12729,6 +12983,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_IfOptions: {
auto ptr = reinterpret_cast<const IfOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_WhileOptions: {
auto ptr = reinterpret_cast<const WhileOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@ -13099,6 +13361,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const HardSwishOptionsT *>(value);
return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_IfOptions: {
auto ptr = reinterpret_cast<const IfOptionsT *>(value);
return CreateIfOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_WhileOptions: {
auto ptr = reinterpret_cast<const WhileOptionsT *>(value);
return CreateWhileOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@ -13469,6 +13739,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new HardSwishOptionsT(*reinterpret_cast<HardSwishOptionsT *>(u.value));
break;
}
case BuiltinOptions_IfOptions: {
value = new IfOptionsT(*reinterpret_cast<IfOptionsT *>(u.value));
break;
}
case BuiltinOptions_WhileOptions: {
value = new WhileOptionsT(*reinterpret_cast<WhileOptionsT *>(u.value));
break;
}
default:
break;
}
@ -13931,6 +14209,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_IfOptions: {
auto ptr = reinterpret_cast<IfOptionsT *>(value);
delete ptr;
break;
}
case BuiltinOptions_WhileOptions: {
auto ptr = reinterpret_cast<WhileOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;