Graduate TFLite control flow ops from experimental to builtin
PiperOrigin-RevId: 259150573
This commit is contained in:
parent
488b385a7d
commit
eedf79ed37
@ -368,9 +368,14 @@ class Translator {
|
|||||||
const std::string& name,
|
const std::string& name,
|
||||||
unsigned buffer_idx);
|
unsigned buffer_idx);
|
||||||
|
|
||||||
CustomOptionsOffset CreateIfOpCustomOptions(mlir::TF::IfOp op);
|
// TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove
|
||||||
|
// these 2 functions here.
|
||||||
CustomOptionsOffset CreateWhileOpCustomOptions(mlir::TF::WhileOp op);
|
BufferOffset<tflite::Operator> BuildIfOperator(
|
||||||
|
mlir::TF::IfOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
BufferOffset<tflite::Operator> BuildWhileOperator(
|
||||||
|
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results);
|
||||||
|
|
||||||
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
||||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||||
@ -544,31 +549,36 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
builder_.CreateString(name), q_params, /*is_variable=*/false);
|
builder_.CreateString(name), q_params, /*is_variable=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
CustomOptionsOffset Translator::CreateIfOpCustomOptions(mlir::TF::IfOp op) {
|
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
|
||||||
|
mlir::TF::IfOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results) {
|
||||||
|
auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
|
||||||
int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
|
int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
|
||||||
int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
|
int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
|
||||||
|
auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
|
||||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
else_subgraph_index)
|
||||||
flex_builder->Map([&]() {
|
.Union();
|
||||||
flex_builder->Int("then_subgraph_index", then_subgraph_index);
|
auto inputs = builder_.CreateVector(operands);
|
||||||
flex_builder->Int("else_subgraph_index", else_subgraph_index);
|
auto outputs = builder_.CreateVector(results);
|
||||||
});
|
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
|
||||||
flex_builder->Finish();
|
tflite::BuiltinOptions_IfOptions,
|
||||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
builtin_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
CustomOptionsOffset Translator::CreateWhileOpCustomOptions(
|
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||||
mlir::TF::WhileOp op) {
|
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||||
|
const std::vector<int32_t>& results) {
|
||||||
|
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
|
||||||
int cond_subgraph_index = subgraph_index_map_.at(op.cond().str());
|
int cond_subgraph_index = subgraph_index_map_.at(op.cond().str());
|
||||||
int body_subgraph_index = subgraph_index_map_.at(op.body().str());
|
int body_subgraph_index = subgraph_index_map_.at(op.body().str());
|
||||||
|
auto builtin_options = tflite::CreateWhileOptions(
|
||||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
builder_, cond_subgraph_index, body_subgraph_index)
|
||||||
flex_builder->Map([&]() {
|
.Union();
|
||||||
flex_builder->Int("cond_subgraph_index", cond_subgraph_index);
|
auto inputs = builder_.CreateVector(operands);
|
||||||
flex_builder->Int("body_subgraph_index", body_subgraph_index);
|
auto outputs = builder_.CreateVector(results);
|
||||||
});
|
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
|
||||||
flex_builder->Finish();
|
tflite::BuiltinOptions_WhileOptions,
|
||||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
builtin_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||||
@ -712,15 +722,14 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
|
|
||||||
if (dialect == tf_dialect_) {
|
if (dialect == tf_dialect_) {
|
||||||
std::string op_name;
|
std::string op_name;
|
||||||
|
if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
|
||||||
|
return BuildIfOperator(ifOp, operands, results);
|
||||||
|
} else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
|
||||||
|
return BuildWhileOperator(whileOp, operands, results);
|
||||||
|
}
|
||||||
|
|
||||||
CustomOptionsOffset custom_options;
|
CustomOptionsOffset custom_options;
|
||||||
|
|
||||||
if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
|
|
||||||
op_name = "Experimental_If";
|
|
||||||
custom_options = CreateIfOpCustomOptions(ifOp);
|
|
||||||
} else if (auto whileOp = dyn_cast<mlir::TF::WhileOp>(inst)) {
|
|
||||||
op_name = "Experimental_While";
|
|
||||||
custom_options = CreateWhileOpCustomOptions(whileOp);
|
|
||||||
} else {
|
|
||||||
// Ops in TF dialect can either be custom ops or flex ops.
|
// Ops in TF dialect can either be custom ops or flex ops.
|
||||||
// The reason we go directly from TensorFlow dialect MLIR to tensorflow
|
// The reason we go directly from TensorFlow dialect MLIR to tensorflow
|
||||||
// node instead of going to TF table gen'd ops via generated code is that
|
// node instead of going to TF table gen'd ops via generated code is that
|
||||||
@ -749,8 +758,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
||||||
// TF op name.
|
// TF op name.
|
||||||
op_name = std::string(kFlexOpNamePrefix) + node_def->op();
|
op_name = std::string(kFlexOpNamePrefix) + node_def->op();
|
||||||
if (auto options =
|
if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
|
||||||
CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
|
|
||||||
custom_options = *options;
|
custom_options = *options;
|
||||||
} else {
|
} else {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
@ -769,7 +777,6 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
return inst->emitOpError("is neither a custom op nor a flex op"),
|
return inst->emitOpError("is neither a custom op nor a flex op"),
|
||||||
llvm::None;
|
llvm::None;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t opcode_index =
|
uint32_t opcode_index =
|
||||||
GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
|
GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure
|
||||||
|
|
||||||
|
|
||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: LESS
|
// CHECK-NEXT: builtin_code: LESS
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
// CHECK-NEXT: builtin_code: IF
|
||||||
// CHECK-NEXT: custom_code: "Experimental_If"
|
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-EMPTY:
|
// CHECK-EMPTY:
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
@ -52,8 +52,12 @@
|
|||||||
// CHECK-NEXT: opcode_index: 1,
|
// CHECK-NEXT: opcode_index: 1,
|
||||||
// CHECK-NEXT: inputs: [ 2, 0, 1 ],
|
// CHECK-NEXT: inputs: [ 2, 0, 1 ],
|
||||||
// CHECK-NEXT: outputs: [ 3 ],
|
// CHECK-NEXT: outputs: [ 3 ],
|
||||||
// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
|
// CHECK-NEXT: builtin_options_type: IfOptions,
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: builtin_options: {
|
||||||
|
// CHECK-NEXT: then_subgraph_index: 1,
|
||||||
|
// CHECK-NEXT: else_subgraph_index: 2
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "main"
|
// CHECK-NEXT: name: "main"
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
@ -88,7 +92,7 @@
|
|||||||
// CHECK-NEXT: builtin_options: {
|
// CHECK-NEXT: builtin_options: {
|
||||||
// CHECK-EMPTY:
|
// CHECK-EMPTY:
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "cond_true"
|
// CHECK-NEXT: name: "cond_true"
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
@ -123,7 +127,7 @@
|
|||||||
// CHECK-NEXT: builtin_options: {
|
// CHECK-NEXT: builtin_options: {
|
||||||
// CHECK-EMPTY:
|
// CHECK-EMPTY:
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "cond_false"
|
// CHECK-NEXT: name: "cond_false"
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: description: "MLIR Converted.",
|
// CHECK-NEXT: description: "MLIR Converted.",
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
// CHECK-NEXT: builtin_code: WHILE
|
||||||
// CHECK-NEXT: custom_code: "Experimental_While"
|
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: GREATER
|
// CHECK-NEXT: builtin_code: GREATER
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
@ -49,8 +48,12 @@
|
|||||||
// CHECK-NEXT: operators: [ {
|
// CHECK-NEXT: operators: [ {
|
||||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||||
// CHECK-NEXT: outputs: [ 2, 3 ],
|
// CHECK-NEXT: outputs: [ 2, 3 ],
|
||||||
// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ]
|
// CHECK-NEXT: builtin_options_type: WhileOptions,
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: builtin_options: {
|
||||||
|
// CHECK-NEXT: cond_subgraph_index: 1,
|
||||||
|
// CHECK-NEXT: body_subgraph_index: 2
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "main"
|
// CHECK-NEXT: name: "main"
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
@ -91,7 +94,7 @@
|
|||||||
// CHECK-NEXT: opcode_index: 1,
|
// CHECK-NEXT: opcode_index: 1,
|
||||||
// CHECK-NEXT: inputs: [ 0, 2 ],
|
// CHECK-NEXT: inputs: [ 0, 2 ],
|
||||||
// CHECK-NEXT: outputs: [ 3 ]
|
// CHECK-NEXT: outputs: [ 3 ]
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "cond"
|
// CHECK-NEXT: name: "cond"
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
@ -151,7 +154,7 @@
|
|||||||
// CHECK-NEXT: builtin_options: {
|
// CHECK-NEXT: builtin_options: {
|
||||||
// CHECK-EMPTY:
|
// CHECK-EMPTY:
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: } ]
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: name: "body"
|
// CHECK-NEXT: name: "body"
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: description: "MLIR Converted.",
|
// CHECK-NEXT: description: "MLIR Converted.",
|
||||||
|
@ -143,6 +143,8 @@ typedef enum {
|
|||||||
kTfLiteBuiltinMatrixSetDiag = 115,
|
kTfLiteBuiltinMatrixSetDiag = 115,
|
||||||
kTfLiteBuiltinRound = 116,
|
kTfLiteBuiltinRound = 116,
|
||||||
kTfLiteBuiltinHardSwish = 117,
|
kTfLiteBuiltinHardSwish = 117,
|
||||||
|
kTfLiteBuiltinIf = 118,
|
||||||
|
kTfLiteBuiltinWhile = 119,
|
||||||
} TfLiteBuiltinOperator;
|
} TfLiteBuiltinOperator;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
@ -391,6 +391,16 @@ typedef struct {
|
|||||||
EmptyStructPlaceholder placeholder;
|
EmptyStructPlaceholder placeholder;
|
||||||
} TfLiteMatrixSetDiagParams;
|
} TfLiteMatrixSetDiagParams;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int then_subgraph_index;
|
||||||
|
int else_subgraph_index;
|
||||||
|
} TfLiteIfParams;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int cond_subgraph_index;
|
||||||
|
int body_subgraph_index;
|
||||||
|
} TfLiteWhileParams;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -721,6 +721,24 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params.release());
|
*builtin_data = reinterpret_cast<void*>(params.release());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOperator_IF: {
|
||||||
|
TfLiteIfParams* params = allocator->AllocatePOD<TfLiteIfParams>();
|
||||||
|
if (const auto* if_params = op->builtin_options_as_IfOptions()) {
|
||||||
|
params->then_subgraph_index = if_params->then_subgraph_index();
|
||||||
|
params->else_subgraph_index = if_params->else_subgraph_index();
|
||||||
|
}
|
||||||
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BuiltinOperator_WHILE: {
|
||||||
|
TfLiteWhileParams* params = allocator->AllocatePOD<TfLiteWhileParams>();
|
||||||
|
if (const auto* while_params = op->builtin_options_as_WhileOptions()) {
|
||||||
|
params->cond_subgraph_index = while_params->cond_subgraph_index();
|
||||||
|
params->body_subgraph_index = while_params->body_subgraph_index();
|
||||||
|
}
|
||||||
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
|
break;
|
||||||
|
}
|
||||||
// Below are the ops with no builtin_data structure.
|
// Below are the ops with no builtin_data structure.
|
||||||
case BuiltinOperator_ABS:
|
case BuiltinOperator_ABS:
|
||||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||||
|
@ -40,6 +40,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
|
|||||||
"TfLiteFakeQuantParams",
|
"TfLiteFakeQuantParams",
|
||||||
"TfLiteFullyConnectedParams",
|
"TfLiteFullyConnectedParams",
|
||||||
"TfLiteGatherParams",
|
"TfLiteGatherParams",
|
||||||
|
"TfLiteIfParams",
|
||||||
"TfLiteL2NormParams",
|
"TfLiteL2NormParams",
|
||||||
"TfLiteLeakyReluParams",
|
"TfLiteLeakyReluParams",
|
||||||
"TfLiteLocalResponseNormParams",
|
"TfLiteLocalResponseNormParams",
|
||||||
@ -76,6 +77,7 @@ static const char* param_structs[] = {"TfLiteAddParams",
|
|||||||
"TfLiteUniqueParams",
|
"TfLiteUniqueParams",
|
||||||
"TfLiteUnpackParams",
|
"TfLiteUnpackParams",
|
||||||
"TfLiteReverseSequenceParams",
|
"TfLiteReverseSequenceParams",
|
||||||
|
"TfLiteWhileParams",
|
||||||
nullptr};
|
nullptr};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/core/subgraph.h"
|
#include "tensorflow/lite/core/subgraph.h"
|
||||||
@ -30,10 +32,9 @@ struct OpData {
|
|||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData;
|
auto* op_data = new OpData;
|
||||||
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
const auto* params = reinterpret_cast<const TfLiteIfParams*>(buffer);
|
||||||
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
|
op_data->then_subgraph_index = params->then_subgraph_index;
|
||||||
op_data->then_subgraph_index = m["then_subgraph_index"].AsInt32();
|
op_data->else_subgraph_index = params->else_subgraph_index;
|
||||||
op_data->else_subgraph_index = m["else_subgraph_index"].AsInt32();
|
|
||||||
return op_data;
|
return op_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,6 +381,10 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
|
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
|
||||||
|
|
||||||
|
// WARNING: Control flow ops are experimental and subject to change.
|
||||||
|
AddBuiltin(BuiltinOperator_IF, tflite::ops::custom::Register_IF());
|
||||||
|
AddBuiltin(BuiltinOperator_WHILE, tflite::ops::custom::Register_WHILE());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
|
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
|
||||||
@ -388,10 +392,6 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
|
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
|
||||||
AddCustom("TFLite_Detection_PostProcess",
|
AddCustom("TFLite_Detection_PostProcess",
|
||||||
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
|
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
|
||||||
|
|
||||||
// WARNING: Control flow ops are experimental and subject to change.
|
|
||||||
AddCustom("Experimental_If", tflite::ops::custom::Register_IF());
|
|
||||||
AddCustom("Experimental_While", tflite::ops::custom::Register_WHILE());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace builtin
|
} // namespace builtin
|
||||||
|
@ -170,18 +170,14 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
|
|||||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||||
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
||||||
|
|
||||||
flexbuffers::Builder fbb;
|
TfLiteIfParams* params =
|
||||||
fbb.Map([&]() {
|
reinterpret_cast<TfLiteIfParams*>(malloc(sizeof(TfLiteIfParams)));
|
||||||
fbb.Int("then_subgraph_index", 1);
|
params->then_subgraph_index = 1;
|
||||||
fbb.Int("else_subgraph_index", 2);
|
params->else_subgraph_index = 2;
|
||||||
});
|
|
||||||
fbb.Finish();
|
|
||||||
const auto& buffer = fbb.GetBuffer();
|
|
||||||
|
|
||||||
int node_index;
|
int node_index;
|
||||||
subgraph->AddNodeWithParameters(
|
subgraph->AddNodeWithParameters(
|
||||||
{kCondInput, kInput1, kInput2}, {kOutput}, {},
|
{kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params,
|
||||||
reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr,
|
|
||||||
::tflite::ops::custom::Register_IF(), &node_index);
|
::tflite::ops::custom::Register_IF(), &node_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,18 +329,14 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
|
|||||||
SetupTensor(subgraph, kOutput1, kTfLiteInt32);
|
SetupTensor(subgraph, kOutput1, kTfLiteInt32);
|
||||||
SetupTensor(subgraph, kOutput2, kTfLiteInt32);
|
SetupTensor(subgraph, kOutput2, kTfLiteInt32);
|
||||||
|
|
||||||
flexbuffers::Builder fbb;
|
TfLiteWhileParams* params =
|
||||||
fbb.Map([&]() {
|
reinterpret_cast<TfLiteWhileParams*>(malloc(sizeof(TfLiteWhileParams)));
|
||||||
fbb.Int("cond_subgraph_index", 1);
|
params->cond_subgraph_index = 1;
|
||||||
fbb.Int("body_subgraph_index", 2);
|
params->body_subgraph_index = 2;
|
||||||
});
|
|
||||||
fbb.Finish();
|
|
||||||
const auto& buffer = fbb.GetBuffer();
|
|
||||||
|
|
||||||
int node_index;
|
int node_index;
|
||||||
subgraph->AddNodeWithParameters(
|
subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params,
|
||||||
{0, 1}, {2, 3}, {}, reinterpret_cast<const char*>(buffer.data()),
|
::tflite::ops::custom::Register_WHILE(),
|
||||||
buffer.size(), nullptr, ::tflite::ops::custom::Register_WHILE(),
|
|
||||||
&node_index);
|
&node_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/context_util.h"
|
#include "tensorflow/lite/context_util.h"
|
||||||
@ -107,10 +109,9 @@ struct OpData {
|
|||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
auto* op_data = new OpData;
|
auto* op_data = new OpData;
|
||||||
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
const auto* params = reinterpret_cast<const TfLiteWhileParams*>(buffer);
|
||||||
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
|
op_data->cond_subgraph_index = params->cond_subgraph_index;
|
||||||
op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
|
op_data->body_subgraph_index = params->body_subgraph_index;
|
||||||
op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
|
|
||||||
op_data->cond_has_dynamic_output_tensors = false;
|
op_data->cond_has_dynamic_output_tensors = false;
|
||||||
op_data->body_has_dynamic_output_tensors = false;
|
op_data->body_has_dynamic_output_tensors = false;
|
||||||
return op_data;
|
return op_data;
|
||||||
|
@ -231,6 +231,8 @@ enum BuiltinOperator : byte {
|
|||||||
MATRIX_SET_DIAG = 115,
|
MATRIX_SET_DIAG = 115,
|
||||||
ROUND = 116,
|
ROUND = 116,
|
||||||
HARD_SWISH = 117,
|
HARD_SWISH = 117,
|
||||||
|
IF = 118,
|
||||||
|
WHILE = 119,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options for the builtin operators.
|
// Options for the builtin operators.
|
||||||
@ -325,7 +327,9 @@ union BuiltinOptions {
|
|||||||
MatrixDiagOptions,
|
MatrixDiagOptions,
|
||||||
QuantizeOptions,
|
QuantizeOptions,
|
||||||
MatrixSetDiagOptions,
|
MatrixSetDiagOptions,
|
||||||
HardSwishOptions
|
HardSwishOptions,
|
||||||
|
IfOptions,
|
||||||
|
WhileOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Padding : byte { SAME, VALID }
|
enum Padding : byte { SAME, VALID }
|
||||||
@ -783,6 +787,16 @@ table QuantizeOptions {
|
|||||||
table MatrixSetDiagOptions {
|
table MatrixSetDiagOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table IfOptions {
|
||||||
|
then_subgraph_index:int;
|
||||||
|
else_subgraph_index:int;
|
||||||
|
}
|
||||||
|
|
||||||
|
table WhileOptions {
|
||||||
|
cond_subgraph_index:int;
|
||||||
|
body_subgraph_index:int;
|
||||||
|
}
|
||||||
|
|
||||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||||
// builtin, or a string if the operator is custom.
|
// builtin, or a string if the operator is custom.
|
||||||
table OperatorCode {
|
table OperatorCode {
|
||||||
|
@ -304,6 +304,12 @@ struct QuantizeOptionsT;
|
|||||||
struct MatrixSetDiagOptions;
|
struct MatrixSetDiagOptions;
|
||||||
struct MatrixSetDiagOptionsT;
|
struct MatrixSetDiagOptionsT;
|
||||||
|
|
||||||
|
struct IfOptions;
|
||||||
|
struct IfOptionsT;
|
||||||
|
|
||||||
|
struct WhileOptions;
|
||||||
|
struct WhileOptionsT;
|
||||||
|
|
||||||
struct OperatorCode;
|
struct OperatorCode;
|
||||||
struct OperatorCodeT;
|
struct OperatorCodeT;
|
||||||
|
|
||||||
@ -577,11 +583,13 @@ enum BuiltinOperator {
|
|||||||
BuiltinOperator_MATRIX_SET_DIAG = 115,
|
BuiltinOperator_MATRIX_SET_DIAG = 115,
|
||||||
BuiltinOperator_ROUND = 116,
|
BuiltinOperator_ROUND = 116,
|
||||||
BuiltinOperator_HARD_SWISH = 117,
|
BuiltinOperator_HARD_SWISH = 117,
|
||||||
|
BuiltinOperator_IF = 118,
|
||||||
|
BuiltinOperator_WHILE = 119,
|
||||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||||
BuiltinOperator_MAX = BuiltinOperator_HARD_SWISH
|
BuiltinOperator_MAX = BuiltinOperator_WHILE
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] {
|
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[119] {
|
||||||
static const BuiltinOperator values[] = {
|
static const BuiltinOperator values[] = {
|
||||||
BuiltinOperator_ADD,
|
BuiltinOperator_ADD,
|
||||||
BuiltinOperator_AVERAGE_POOL_2D,
|
BuiltinOperator_AVERAGE_POOL_2D,
|
||||||
@ -699,7 +707,9 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] {
|
|||||||
BuiltinOperator_QUANTIZE,
|
BuiltinOperator_QUANTIZE,
|
||||||
BuiltinOperator_MATRIX_SET_DIAG,
|
BuiltinOperator_MATRIX_SET_DIAG,
|
||||||
BuiltinOperator_ROUND,
|
BuiltinOperator_ROUND,
|
||||||
BuiltinOperator_HARD_SWISH
|
BuiltinOperator_HARD_SWISH,
|
||||||
|
BuiltinOperator_IF,
|
||||||
|
BuiltinOperator_WHILE
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -824,13 +834,15 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
|||||||
"MATRIX_SET_DIAG",
|
"MATRIX_SET_DIAG",
|
||||||
"ROUND",
|
"ROUND",
|
||||||
"HARD_SWISH",
|
"HARD_SWISH",
|
||||||
|
"IF",
|
||||||
|
"WHILE",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
inline const char *EnumNameBuiltinOperator(BuiltinOperator e) {
|
||||||
if (e < BuiltinOperator_ADD || e > BuiltinOperator_HARD_SWISH) return "";
|
if (e < BuiltinOperator_ADD || e > BuiltinOperator_WHILE) return "";
|
||||||
const size_t index = static_cast<size_t>(e);
|
const size_t index = static_cast<size_t>(e);
|
||||||
return EnumNamesBuiltinOperator()[index];
|
return EnumNamesBuiltinOperator()[index];
|
||||||
}
|
}
|
||||||
@ -928,11 +940,13 @@ enum BuiltinOptions {
|
|||||||
BuiltinOptions_QuantizeOptions = 89,
|
BuiltinOptions_QuantizeOptions = 89,
|
||||||
BuiltinOptions_MatrixSetDiagOptions = 90,
|
BuiltinOptions_MatrixSetDiagOptions = 90,
|
||||||
BuiltinOptions_HardSwishOptions = 91,
|
BuiltinOptions_HardSwishOptions = 91,
|
||||||
|
BuiltinOptions_IfOptions = 92,
|
||||||
|
BuiltinOptions_WhileOptions = 93,
|
||||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||||
BuiltinOptions_MAX = BuiltinOptions_HardSwishOptions
|
BuiltinOptions_MAX = BuiltinOptions_WhileOptions
|
||||||
};
|
};
|
||||||
|
|
||||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] {
|
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[94] {
|
||||||
static const BuiltinOptions values[] = {
|
static const BuiltinOptions values[] = {
|
||||||
BuiltinOptions_NONE,
|
BuiltinOptions_NONE,
|
||||||
BuiltinOptions_Conv2DOptions,
|
BuiltinOptions_Conv2DOptions,
|
||||||
@ -1025,7 +1039,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] {
|
|||||||
BuiltinOptions_MatrixDiagOptions,
|
BuiltinOptions_MatrixDiagOptions,
|
||||||
BuiltinOptions_QuantizeOptions,
|
BuiltinOptions_QuantizeOptions,
|
||||||
BuiltinOptions_MatrixSetDiagOptions,
|
BuiltinOptions_MatrixSetDiagOptions,
|
||||||
BuiltinOptions_HardSwishOptions
|
BuiltinOptions_HardSwishOptions,
|
||||||
|
BuiltinOptions_IfOptions,
|
||||||
|
BuiltinOptions_WhileOptions
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -1124,13 +1140,15 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
|||||||
"QuantizeOptions",
|
"QuantizeOptions",
|
||||||
"MatrixSetDiagOptions",
|
"MatrixSetDiagOptions",
|
||||||
"HardSwishOptions",
|
"HardSwishOptions",
|
||||||
|
"IfOptions",
|
||||||
|
"WhileOptions",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
inline const char *EnumNameBuiltinOptions(BuiltinOptions e) {
|
||||||
if (e < BuiltinOptions_NONE || e > BuiltinOptions_HardSwishOptions) return "";
|
if (e < BuiltinOptions_NONE || e > BuiltinOptions_WhileOptions) return "";
|
||||||
const size_t index = static_cast<size_t>(e);
|
const size_t index = static_cast<size_t>(e);
|
||||||
return EnumNamesBuiltinOptions()[index];
|
return EnumNamesBuiltinOptions()[index];
|
||||||
}
|
}
|
||||||
@ -1503,6 +1521,14 @@ template<> struct BuiltinOptionsTraits<HardSwishOptions> {
|
|||||||
static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions;
|
static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<> struct BuiltinOptionsTraits<IfOptions> {
|
||||||
|
static const BuiltinOptions enum_value = BuiltinOptions_IfOptions;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct BuiltinOptionsTraits<WhileOptions> {
|
||||||
|
static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions;
|
||||||
|
};
|
||||||
|
|
||||||
struct BuiltinOptionsUnion {
|
struct BuiltinOptionsUnion {
|
||||||
BuiltinOptions type;
|
BuiltinOptions type;
|
||||||
void *value;
|
void *value;
|
||||||
@ -2263,6 +2289,22 @@ struct BuiltinOptionsUnion {
|
|||||||
return type == BuiltinOptions_HardSwishOptions ?
|
return type == BuiltinOptions_HardSwishOptions ?
|
||||||
reinterpret_cast<const HardSwishOptionsT *>(value) : nullptr;
|
reinterpret_cast<const HardSwishOptionsT *>(value) : nullptr;
|
||||||
}
|
}
|
||||||
|
IfOptionsT *AsIfOptions() {
|
||||||
|
return type == BuiltinOptions_IfOptions ?
|
||||||
|
reinterpret_cast<IfOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
|
const IfOptionsT *AsIfOptions() const {
|
||||||
|
return type == BuiltinOptions_IfOptions ?
|
||||||
|
reinterpret_cast<const IfOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
|
WhileOptionsT *AsWhileOptions() {
|
||||||
|
return type == BuiltinOptions_WhileOptions ?
|
||||||
|
reinterpret_cast<WhileOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
|
const WhileOptionsT *AsWhileOptions() const {
|
||||||
|
return type == BuiltinOptions_WhileOptions ?
|
||||||
|
reinterpret_cast<const WhileOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||||
@ -7856,6 +7898,138 @@ inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(
|
|||||||
|
|
||||||
flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
|
struct IfOptionsT : public flatbuffers::NativeTable {
|
||||||
|
typedef IfOptions TableType;
|
||||||
|
int32_t then_subgraph_index;
|
||||||
|
int32_t else_subgraph_index;
|
||||||
|
IfOptionsT()
|
||||||
|
: then_subgraph_index(0),
|
||||||
|
else_subgraph_index(0) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct IfOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
typedef IfOptionsT NativeTableType;
|
||||||
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
|
VT_THEN_SUBGRAPH_INDEX = 4,
|
||||||
|
VT_ELSE_SUBGRAPH_INDEX = 6
|
||||||
|
};
|
||||||
|
int32_t then_subgraph_index() const {
|
||||||
|
return GetField<int32_t>(VT_THEN_SUBGRAPH_INDEX, 0);
|
||||||
|
}
|
||||||
|
int32_t else_subgraph_index() const {
|
||||||
|
return GetField<int32_t>(VT_ELSE_SUBGRAPH_INDEX, 0);
|
||||||
|
}
|
||||||
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
|
return VerifyTableStart(verifier) &&
|
||||||
|
VerifyField<int32_t>(verifier, VT_THEN_SUBGRAPH_INDEX) &&
|
||||||
|
VerifyField<int32_t>(verifier, VT_ELSE_SUBGRAPH_INDEX) &&
|
||||||
|
verifier.EndTable();
|
||||||
|
}
|
||||||
|
IfOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
void UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
static flatbuffers::Offset<IfOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct IfOptionsBuilder {
|
||||||
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
|
flatbuffers::uoffset_t start_;
|
||||||
|
void add_then_subgraph_index(int32_t then_subgraph_index) {
|
||||||
|
fbb_.AddElement<int32_t>(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0);
|
||||||
|
}
|
||||||
|
void add_else_subgraph_index(int32_t else_subgraph_index) {
|
||||||
|
fbb_.AddElement<int32_t>(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0);
|
||||||
|
}
|
||||||
|
explicit IfOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
|
: fbb_(_fbb) {
|
||||||
|
start_ = fbb_.StartTable();
|
||||||
|
}
|
||||||
|
IfOptionsBuilder &operator=(const IfOptionsBuilder &);
|
||||||
|
flatbuffers::Offset<IfOptions> Finish() {
|
||||||
|
const auto end = fbb_.EndTable(start_);
|
||||||
|
auto o = flatbuffers::Offset<IfOptions>(end);
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<IfOptions> CreateIfOptions(
|
||||||
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
|
int32_t then_subgraph_index = 0,
|
||||||
|
int32_t else_subgraph_index = 0) {
|
||||||
|
IfOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_else_subgraph_index(else_subgraph_index);
|
||||||
|
builder_.add_then_subgraph_index(then_subgraph_index);
|
||||||
|
return builder_.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
|
struct WhileOptionsT : public flatbuffers::NativeTable {
|
||||||
|
typedef WhileOptions TableType;
|
||||||
|
int32_t cond_subgraph_index;
|
||||||
|
int32_t body_subgraph_index;
|
||||||
|
WhileOptionsT()
|
||||||
|
: cond_subgraph_index(0),
|
||||||
|
body_subgraph_index(0) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WhileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
typedef WhileOptionsT NativeTableType;
|
||||||
|
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
|
||||||
|
VT_COND_SUBGRAPH_INDEX = 4,
|
||||||
|
VT_BODY_SUBGRAPH_INDEX = 6
|
||||||
|
};
|
||||||
|
int32_t cond_subgraph_index() const {
|
||||||
|
return GetField<int32_t>(VT_COND_SUBGRAPH_INDEX, 0);
|
||||||
|
}
|
||||||
|
int32_t body_subgraph_index() const {
|
||||||
|
return GetField<int32_t>(VT_BODY_SUBGRAPH_INDEX, 0);
|
||||||
|
}
|
||||||
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
|
return VerifyTableStart(verifier) &&
|
||||||
|
VerifyField<int32_t>(verifier, VT_COND_SUBGRAPH_INDEX) &&
|
||||||
|
VerifyField<int32_t>(verifier, VT_BODY_SUBGRAPH_INDEX) &&
|
||||||
|
verifier.EndTable();
|
||||||
|
}
|
||||||
|
WhileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
void UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
static flatbuffers::Offset<WhileOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WhileOptionsBuilder {
|
||||||
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
|
flatbuffers::uoffset_t start_;
|
||||||
|
void add_cond_subgraph_index(int32_t cond_subgraph_index) {
|
||||||
|
fbb_.AddElement<int32_t>(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0);
|
||||||
|
}
|
||||||
|
void add_body_subgraph_index(int32_t body_subgraph_index) {
|
||||||
|
fbb_.AddElement<int32_t>(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0);
|
||||||
|
}
|
||||||
|
explicit WhileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
|
: fbb_(_fbb) {
|
||||||
|
start_ = fbb_.StartTable();
|
||||||
|
}
|
||||||
|
WhileOptionsBuilder &operator=(const WhileOptionsBuilder &);
|
||||||
|
flatbuffers::Offset<WhileOptions> Finish() {
|
||||||
|
const auto end = fbb_.EndTable(start_);
|
||||||
|
auto o = flatbuffers::Offset<WhileOptions>(end);
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<WhileOptions> CreateWhileOptions(
|
||||||
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
|
int32_t cond_subgraph_index = 0,
|
||||||
|
int32_t body_subgraph_index = 0) {
|
||||||
|
WhileOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_body_subgraph_index(body_subgraph_index);
|
||||||
|
builder_.add_cond_subgraph_index(cond_subgraph_index);
|
||||||
|
return builder_.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<WhileOptions> CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||||
typedef OperatorCode TableType;
|
typedef OperatorCode TableType;
|
||||||
BuiltinOperator builtin_code;
|
BuiltinOperator builtin_code;
|
||||||
@ -8265,6 +8439,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
const HardSwishOptions *builtin_options_as_HardSwishOptions() const {
|
const HardSwishOptions *builtin_options_as_HardSwishOptions() const {
|
||||||
return builtin_options_type() == BuiltinOptions_HardSwishOptions ? static_cast<const HardSwishOptions *>(builtin_options()) : nullptr;
|
return builtin_options_type() == BuiltinOptions_HardSwishOptions ? static_cast<const HardSwishOptions *>(builtin_options()) : nullptr;
|
||||||
}
|
}
|
||||||
|
const IfOptions *builtin_options_as_IfOptions() const {
|
||||||
|
return builtin_options_type() == BuiltinOptions_IfOptions ? static_cast<const IfOptions *>(builtin_options()) : nullptr;
|
||||||
|
}
|
||||||
|
const WhileOptions *builtin_options_as_WhileOptions() const {
|
||||||
|
return builtin_options_type() == BuiltinOptions_WhileOptions ? static_cast<const WhileOptions *>(builtin_options()) : nullptr;
|
||||||
|
}
|
||||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||||
}
|
}
|
||||||
@ -8665,6 +8845,14 @@ template<> inline const HardSwishOptions *Operator::builtin_options_as<HardSwish
|
|||||||
return builtin_options_as_HardSwishOptions();
|
return builtin_options_as_HardSwishOptions();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline const IfOptions *Operator::builtin_options_as<IfOptions>() const {
|
||||||
|
return builtin_options_as_IfOptions();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> inline const WhileOptions *Operator::builtin_options_as<WhileOptions>() const {
|
||||||
|
return builtin_options_as_WhileOptions();
|
||||||
|
}
|
||||||
|
|
||||||
struct OperatorBuilder {
|
struct OperatorBuilder {
|
||||||
flatbuffers::FlatBufferBuilder &fbb_;
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
flatbuffers::uoffset_t start_;
|
flatbuffers::uoffset_t start_;
|
||||||
@ -11690,6 +11878,64 @@ inline flatbuffers::Offset<MatrixSetDiagOptions> CreateMatrixSetDiagOptions(flat
|
|||||||
_fbb);
|
_fbb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline IfOptionsT *IfOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
auto _o = new IfOptionsT();
|
||||||
|
UnPackTo(_o, _resolver);
|
||||||
|
return _o;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void IfOptions::UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
(void)_o;
|
||||||
|
(void)_resolver;
|
||||||
|
{ auto _e = then_subgraph_index(); _o->then_subgraph_index = _e; };
|
||||||
|
{ auto _e = else_subgraph_index(); _o->else_subgraph_index = _e; };
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<IfOptions> IfOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
return CreateIfOptions(_fbb, _o, _rehasher);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<IfOptions> CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
(void)_rehasher;
|
||||||
|
(void)_o;
|
||||||
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const IfOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
|
auto _then_subgraph_index = _o->then_subgraph_index;
|
||||||
|
auto _else_subgraph_index = _o->else_subgraph_index;
|
||||||
|
return tflite::CreateIfOptions(
|
||||||
|
_fbb,
|
||||||
|
_then_subgraph_index,
|
||||||
|
_else_subgraph_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline WhileOptionsT *WhileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
auto _o = new WhileOptionsT();
|
||||||
|
UnPackTo(_o, _resolver);
|
||||||
|
return _o;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void WhileOptions::UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
(void)_o;
|
||||||
|
(void)_resolver;
|
||||||
|
{ auto _e = cond_subgraph_index(); _o->cond_subgraph_index = _e; };
|
||||||
|
{ auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; };
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<WhileOptions> WhileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
return CreateWhileOptions(_fbb, _o, _rehasher);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<WhileOptions> CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
(void)_rehasher;
|
||||||
|
(void)_o;
|
||||||
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const WhileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
|
auto _cond_subgraph_index = _o->cond_subgraph_index;
|
||||||
|
auto _body_subgraph_index = _o->body_subgraph_index;
|
||||||
|
return tflite::CreateWhileOptions(
|
||||||
|
_fbb,
|
||||||
|
_cond_subgraph_index,
|
||||||
|
_body_subgraph_index);
|
||||||
|
}
|
||||||
|
|
||||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
auto _o = new OperatorCodeT();
|
auto _o = new OperatorCodeT();
|
||||||
UnPackTo(_o, _resolver);
|
UnPackTo(_o, _resolver);
|
||||||
@ -12347,6 +12593,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
|||||||
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
|
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
|
||||||
return verifier.VerifyTable(ptr);
|
return verifier.VerifyTable(ptr);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_IfOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const IfOptions *>(obj);
|
||||||
|
return verifier.VerifyTable(ptr);
|
||||||
|
}
|
||||||
|
case BuiltinOptions_WhileOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const WhileOptions *>(obj);
|
||||||
|
return verifier.VerifyTable(ptr);
|
||||||
|
}
|
||||||
default: return false;
|
default: return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -12729,6 +12983,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
|||||||
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
|
auto ptr = reinterpret_cast<const HardSwishOptions *>(obj);
|
||||||
return ptr->UnPack(resolver);
|
return ptr->UnPack(resolver);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_IfOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const IfOptions *>(obj);
|
||||||
|
return ptr->UnPack(resolver);
|
||||||
|
}
|
||||||
|
case BuiltinOptions_WhileOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const WhileOptions *>(obj);
|
||||||
|
return ptr->UnPack(resolver);
|
||||||
|
}
|
||||||
default: return nullptr;
|
default: return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13099,6 +13361,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
|||||||
auto ptr = reinterpret_cast<const HardSwishOptionsT *>(value);
|
auto ptr = reinterpret_cast<const HardSwishOptionsT *>(value);
|
||||||
return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union();
|
return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union();
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_IfOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const IfOptionsT *>(value);
|
||||||
|
return CreateIfOptions(_fbb, ptr, _rehasher).Union();
|
||||||
|
}
|
||||||
|
case BuiltinOptions_WhileOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const WhileOptionsT *>(value);
|
||||||
|
return CreateWhileOptions(_fbb, ptr, _rehasher).Union();
|
||||||
|
}
|
||||||
default: return 0;
|
default: return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13469,6 +13739,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
|||||||
value = new HardSwishOptionsT(*reinterpret_cast<HardSwishOptionsT *>(u.value));
|
value = new HardSwishOptionsT(*reinterpret_cast<HardSwishOptionsT *>(u.value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_IfOptions: {
|
||||||
|
value = new IfOptionsT(*reinterpret_cast<IfOptionsT *>(u.value));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BuiltinOptions_WhileOptions: {
|
||||||
|
value = new WhileOptionsT(*reinterpret_cast<WhileOptionsT *>(u.value));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -13931,6 +14209,16 @@ inline void BuiltinOptionsUnion::Reset() {
|
|||||||
delete ptr;
|
delete ptr;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_IfOptions: {
|
||||||
|
auto ptr = reinterpret_cast<IfOptionsT *>(value);
|
||||||
|
delete ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case BuiltinOptions_WhileOptions: {
|
||||||
|
auto ptr = reinterpret_cast<WhileOptionsT *>(value);
|
||||||
|
delete ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
value = nullptr;
|
value = nullptr;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user