From eedf79ed3782dddd1c4787c72fc9804a20252245 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Sat, 20 Jul 2019 15:45:31 -0700 Subject: [PATCH] Graduate TFLite control flow ops from experimental to builtin PiperOrigin-RevId: 259150573 --- .../mlir/lite/flatbuffer_translate.cc | 153 ++++----- .../lite/tests/mlir2flatbuffer/if_op.mlir | 18 +- .../lite/tests/mlir2flatbuffer/while_op.mlir | 15 +- tensorflow/lite/builtin_ops.h | 2 + tensorflow/lite/c/builtin_op_data.h | 10 + .../lite/core/api/flatbuffer_conversions.cc | 18 ++ .../writer/option_writer_generator.cc | 2 + tensorflow/lite/kernels/if.cc | 11 +- tensorflow/lite/kernels/register.cc | 8 +- tensorflow/lite/kernels/subgraph_test_util.cc | 32 +- tensorflow/lite/kernels/while.cc | 11 +- tensorflow/lite/schema/schema.fbs | 16 +- tensorflow/lite/schema/schema_generated.h | 304 +++++++++++++++++- 13 files changed, 471 insertions(+), 129 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index c6a461d7414..ab17d62fa53 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -368,9 +368,14 @@ class Translator { const std::string& name, unsigned buffer_idx); - CustomOptionsOffset CreateIfOpCustomOptions(mlir::TF::IfOp op); - - CustomOptionsOffset CreateWhileOpCustomOptions(mlir::TF::WhileOp op); + // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove + // these 2 functions here. + BufferOffset BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results); + BufferOffset BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results); Optional CreateFlexOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); @@ -544,31 +549,36 @@ Optional> Translator::BuildTensor( builder_.CreateString(name), q_params, /*is_variable=*/false); } -CustomOptionsOffset Translator::CreateIfOpCustomOptions(mlir::TF::IfOp op) { +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); - - auto flex_builder = absl::make_unique(); - flex_builder->Map([&]() { - flex_builder->Int("then_subgraph_index", then_subgraph_index); - flex_builder->Int("else_subgraph_index", else_subgraph_index); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); } -CustomOptionsOffset Translator::CreateWhileOpCustomOptions( - mlir::TF::WhileOp op) { +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); int body_subgraph_index = subgraph_index_map_.at(op.body().str()); - - auto flex_builder = absl::make_unique(); - flex_builder->Map([&]() { - flex_builder->Int("cond_subgraph_index", cond_subgraph_index); - flex_builder->Int("body_subgraph_index", body_subgraph_index); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); } Optional Translator::CreateFlexOpCustomOptions( @@ -712,63 +722,60 @@ Optional> Translator::BuildOperator( if (dialect == tf_dialect_) { std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + CustomOptionsOffset custom_options; - if (auto ifOp = dyn_cast(inst)) { - op_name = "Experimental_If"; - custom_options = CreateIfOpCustomOptions(ifOp); - } else if (auto whileOp = dyn_cast(inst)) { - op_name = "Experimental_While"; - custom_options = CreateWhileOpCustomOptions(whileOp); - } else { - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = getTensorFlowNodeDef(inst); - if (!node_def) { + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = getTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { return llvm::None; } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = - CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. - op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. + op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; } else { - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; + return llvm::None; } + } else { + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; } uint32_t opcode_index = diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 7702045547e..03048bd640d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -1,12 +1,12 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure + // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { // CHECK-NEXT: builtin_code: LESS // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "Experimental_If" +// CHECK-NEXT: builtin_code: IF // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -52,8 +52,12 @@ // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 2, 0, 1 ], // CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: builtin_options_type: IfOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: then_subgraph_index: 1, +// CHECK-NEXT: else_subgraph_index: 2 +// CHECK-NEXT: } +// CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -88,7 +92,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: } -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond_true" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -123,7 +127,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: } -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond_false" // CHECK-NEXT: } ], // CHECK-NEXT: description: "MLIR Converted.", diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir index fd403aa72c5..117f97455cc 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -3,8 +3,7 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "Experimental_While" +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { // CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { @@ -49,8 +48,12 @@ // CHECK-NEXT: operators: [ { // CHECK-NEXT: inputs: [ 0, 1 ], // CHECK-NEXT: outputs: [ 2, 3 ], -// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: builtin_options_type: WhileOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: cond_subgraph_index: 1, +// CHECK-NEXT: body_subgraph_index: 2 +// CHECK-NEXT: } +// CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -91,7 +94,7 @@ // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 0, 2 ], // CHECK-NEXT: outputs: [ 3 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -151,7 +154,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: } -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "body" // CHECK-NEXT: } ], // CHECK-NEXT: description: "MLIR Converted.", diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 1ed7022fc02..785853f2db1 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -143,6 +143,8 @@ typedef enum { kTfLiteBuiltinMatrixSetDiag = 115, kTfLiteBuiltinRound = 116, kTfLiteBuiltinHardSwish = 117, + kTfLiteBuiltinIf = 118, + kTfLiteBuiltinWhile = 119, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h index 283d15de67b..00ed17d5a04 100644 --- a/tensorflow/lite/c/builtin_op_data.h +++ b/tensorflow/lite/c/builtin_op_data.h @@ -391,6 +391,16 @@ typedef struct { EmptyStructPlaceholder placeholder; } TfLiteMatrixSetDiagParams; +typedef struct { + int then_subgraph_index; + int else_subgraph_index; +} TfLiteIfParams; + +typedef struct { + int cond_subgraph_index; + int body_subgraph_index; +} TfLiteWhileParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index a0f97da58ce..53a4e8fcc5a 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -721,6 +721,24 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params.release()); break; } + case BuiltinOperator_IF: { + TfLiteIfParams* params = allocator->AllocatePOD(); + if (const auto* if_params = op->builtin_options_as_IfOptions()) { + params->then_subgraph_index = if_params->then_subgraph_index(); + params->else_subgraph_index = if_params->else_subgraph_index(); + } + *builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_WHILE: { + TfLiteWhileParams* params = allocator->AllocatePOD(); + if (const auto* while_params = op->builtin_options_as_WhileOptions()) { + params->cond_subgraph_index = while_params->cond_subgraph_index(); + params->body_subgraph_index = while_params->body_subgraph_index(); + } + *builtin_data = reinterpret_cast(params); + break; + } // Below are the ops with no builtin_data structure. case BuiltinOperator_ABS: case BuiltinOperator_BATCH_TO_SPACE_ND: diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 2ea105f4127..cdb1372b929 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -40,6 +40,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteFakeQuantParams", "TfLiteFullyConnectedParams", "TfLiteGatherParams", + "TfLiteIfParams", "TfLiteL2NormParams", "TfLiteLeakyReluParams", "TfLiteLocalResponseNormParams", @@ -76,6 +77,7 @@ static const char* param_structs[] = {"TfLiteAddParams", "TfLiteUniqueParams", "TfLiteUnpackParams", "TfLiteReverseSequenceParams", + "TfLiteWhileParams", nullptr}; } // namespace diff --git a/tensorflow/lite/kernels/if.cc b/tensorflow/lite/kernels/if.cc index 1bd394e9800..610af8cd4b9 100644 --- a/tensorflow/lite/kernels/if.cc +++ b/tensorflow/lite/kernels/if.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers + +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/core/subgraph.h" @@ -30,10 +32,9 @@ struct OpData { void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData; - const uint8_t* buffer_t = reinterpret_cast(buffer); - const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); - op_data->then_subgraph_index = m["then_subgraph_index"].AsInt32(); - op_data->else_subgraph_index = m["else_subgraph_index"].AsInt32(); + const auto* params = reinterpret_cast(buffer); + op_data->then_subgraph_index = params->then_subgraph_index; + op_data->else_subgraph_index = params->else_subgraph_index; return op_data; } diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index bd2643aaa64..6832ac73f6d 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -381,6 +381,10 @@ BuiltinOpResolver::BuiltinOpResolver() { /* max_version */ 2); AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); + // WARNING: Control flow ops are experimental and subject to change. + AddBuiltin(BuiltinOperator_IF, tflite::ops::custom::Register_IF()); + AddBuiltin(BuiltinOperator_WHILE, tflite::ops::custom::Register_WHILE()); + // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); @@ -388,10 +392,6 @@ BuiltinOpResolver::BuiltinOpResolver() { tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); - - // WARNING: Control flow ops are experimental and subject to change. - AddCustom("Experimental_If", tflite::ops::custom::Register_IF()); - AddCustom("Experimental_While", tflite::ops::custom::Register_WHILE()); } } // namespace builtin diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index e55965ecf94..b60bdab080d 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -170,18 +170,14 @@ void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) { SetupTensor(subgraph, kInput2, kTfLiteInt32); SetupTensor(subgraph, kOutput, kTfLiteInt32); - flexbuffers::Builder fbb; - fbb.Map([&]() { - fbb.Int("then_subgraph_index", 1); - fbb.Int("else_subgraph_index", 2); - }); - fbb.Finish(); - const auto& buffer = fbb.GetBuffer(); + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteIfParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; int node_index; subgraph->AddNodeWithParameters( - {kCondInput, kInput1, kInput2}, {kOutput}, {}, - reinterpret_cast(buffer.data()), buffer.size(), nullptr, + {kCondInput, kInput1, kInput2}, {kOutput}, {}, nullptr, 0, params, ::tflite::ops::custom::Register_IF(), &node_index); } @@ -333,19 +329,15 @@ void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) { SetupTensor(subgraph, kOutput1, kTfLiteInt32); SetupTensor(subgraph, kOutput2, kTfLiteInt32); - flexbuffers::Builder fbb; - fbb.Map([&]() { - fbb.Int("cond_subgraph_index", 1); - fbb.Int("body_subgraph_index", 2); - }); - fbb.Finish(); - const auto& buffer = fbb.GetBuffer(); + TfLiteWhileParams* params = + reinterpret_cast(malloc(sizeof(TfLiteWhileParams))); + params->cond_subgraph_index = 1; + params->body_subgraph_index = 2; int node_index; - subgraph->AddNodeWithParameters( - {0, 1}, {2, 3}, {}, reinterpret_cast(buffer.data()), - buffer.size(), nullptr, ::tflite::ops::custom::Register_WHILE(), - &node_index); + subgraph->AddNodeWithParameters({0, 1}, {2, 3}, {}, nullptr, 0, params, + ::tflite::ops::custom::Register_WHILE(), + &node_index); } void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph, diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index a6438558458..6ac1d4b1e91 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "flatbuffers/flexbuffers.h" // TF:flatbuffers + +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/context_util.h" @@ -107,10 +109,9 @@ struct OpData { void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData; - const uint8_t* buffer_t = reinterpret_cast(buffer); - const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); - op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32(); - op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32(); + const auto* params = reinterpret_cast(buffer); + op_data->cond_subgraph_index = params->cond_subgraph_index; + op_data->body_subgraph_index = params->body_subgraph_index; op_data->cond_has_dynamic_output_tensors = false; op_data->body_has_dynamic_output_tensors = false; return op_data; diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 65c7156f0d3..b82bbdfd103 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -231,6 +231,8 @@ enum BuiltinOperator : byte { MATRIX_SET_DIAG = 115, ROUND = 116, HARD_SWISH = 117, + IF = 118, + WHILE = 119, } // Options for the builtin operators. @@ -325,7 +327,9 @@ union BuiltinOptions { MatrixDiagOptions, QuantizeOptions, MatrixSetDiagOptions, - HardSwishOptions + HardSwishOptions, + IfOptions, + WhileOptions } enum Padding : byte { SAME, VALID } @@ -783,6 +787,16 @@ table QuantizeOptions { table MatrixSetDiagOptions { } +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index abe1f3f9a4a..07d554444b0 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -304,6 +304,12 @@ struct QuantizeOptionsT; struct MatrixSetDiagOptions; struct MatrixSetDiagOptionsT; +struct IfOptions; +struct IfOptionsT; + +struct WhileOptions; +struct WhileOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -577,11 +583,13 @@ enum BuiltinOperator { BuiltinOperator_MATRIX_SET_DIAG = 115, BuiltinOperator_ROUND = 116, BuiltinOperator_HARD_SWISH = 117, + BuiltinOperator_IF = 118, + BuiltinOperator_WHILE = 119, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_HARD_SWISH + BuiltinOperator_MAX = BuiltinOperator_WHILE }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[119] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -699,7 +707,9 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[117] { BuiltinOperator_QUANTIZE, BuiltinOperator_MATRIX_SET_DIAG, BuiltinOperator_ROUND, - BuiltinOperator_HARD_SWISH + BuiltinOperator_HARD_SWISH, + BuiltinOperator_IF, + BuiltinOperator_WHILE }; return values; } @@ -824,13 +834,15 @@ inline const char * const *EnumNamesBuiltinOperator() { "MATRIX_SET_DIAG", "ROUND", "HARD_SWISH", + "IF", + "WHILE", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (e < BuiltinOperator_ADD || e > BuiltinOperator_HARD_SWISH) return ""; + if (e < BuiltinOperator_ADD || e > BuiltinOperator_WHILE) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -928,11 +940,13 @@ enum BuiltinOptions { BuiltinOptions_QuantizeOptions = 89, BuiltinOptions_MatrixSetDiagOptions = 90, BuiltinOptions_HardSwishOptions = 91, + BuiltinOptions_IfOptions = 92, + BuiltinOptions_WhileOptions = 93, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_HardSwishOptions + BuiltinOptions_MAX = BuiltinOptions_WhileOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[94] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1025,7 +1039,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[92] { BuiltinOptions_MatrixDiagOptions, BuiltinOptions_QuantizeOptions, BuiltinOptions_MatrixSetDiagOptions, - BuiltinOptions_HardSwishOptions + BuiltinOptions_HardSwishOptions, + BuiltinOptions_IfOptions, + BuiltinOptions_WhileOptions }; return values; } @@ -1124,13 +1140,15 @@ inline const char * const *EnumNamesBuiltinOptions() { "QuantizeOptions", "MatrixSetDiagOptions", "HardSwishOptions", + "IfOptions", + "WhileOptions", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (e < BuiltinOptions_NONE || e > BuiltinOptions_HardSwishOptions) return ""; + if (e < BuiltinOptions_NONE || e > BuiltinOptions_WhileOptions) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; } @@ -1503,6 +1521,14 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_IfOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2263,6 +2289,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_HardSwishOptions ? reinterpret_cast(value) : nullptr; } + IfOptionsT *AsIfOptions() { + return type == BuiltinOptions_IfOptions ? + reinterpret_cast(value) : nullptr; + } + const IfOptionsT *AsIfOptions() const { + return type == BuiltinOptions_IfOptions ? + reinterpret_cast(value) : nullptr; + } + WhileOptionsT *AsWhileOptions() { + return type == BuiltinOptions_WhileOptions ? + reinterpret_cast(value) : nullptr; + } + const WhileOptionsT *AsWhileOptions() const { + return type == BuiltinOptions_WhileOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -7856,6 +7898,138 @@ inline flatbuffers::Offset CreateMatrixSetDiagOptions( flatbuffers::Offset CreateMatrixSetDiagOptions(flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct IfOptionsT : public flatbuffers::NativeTable { + typedef IfOptions TableType; + int32_t then_subgraph_index; + int32_t else_subgraph_index; + IfOptionsT() + : then_subgraph_index(0), + else_subgraph_index(0) { + } +}; + +struct IfOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IfOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_THEN_SUBGRAPH_INDEX = 4, + VT_ELSE_SUBGRAPH_INDEX = 6 + }; + int32_t then_subgraph_index() const { + return GetField(VT_THEN_SUBGRAPH_INDEX, 0); + } + int32_t else_subgraph_index() const { + return GetField(VT_ELSE_SUBGRAPH_INDEX, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_THEN_SUBGRAPH_INDEX) && + VerifyField(verifier, VT_ELSE_SUBGRAPH_INDEX) && + verifier.EndTable(); + } + IfOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct IfOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_then_subgraph_index(int32_t then_subgraph_index) { + fbb_.AddElement(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0); + } + void add_else_subgraph_index(int32_t else_subgraph_index) { + fbb_.AddElement(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0); + } + explicit IfOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + IfOptionsBuilder &operator=(const IfOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIfOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t then_subgraph_index = 0, + int32_t else_subgraph_index = 0) { + IfOptionsBuilder builder_(_fbb); + builder_.add_else_subgraph_index(else_subgraph_index); + builder_.add_then_subgraph_index(then_subgraph_index); + return builder_.Finish(); +} + +flatbuffers::Offset CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct WhileOptionsT : public flatbuffers::NativeTable { + typedef WhileOptions TableType; + int32_t cond_subgraph_index; + int32_t body_subgraph_index; + WhileOptionsT() + : cond_subgraph_index(0), + body_subgraph_index(0) { + } +}; + +struct WhileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef WhileOptionsT NativeTableType; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COND_SUBGRAPH_INDEX = 4, + VT_BODY_SUBGRAPH_INDEX = 6 + }; + int32_t cond_subgraph_index() const { + return GetField(VT_COND_SUBGRAPH_INDEX, 0); + } + int32_t body_subgraph_index() const { + return GetField(VT_BODY_SUBGRAPH_INDEX, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COND_SUBGRAPH_INDEX) && + VerifyField(verifier, VT_BODY_SUBGRAPH_INDEX) && + verifier.EndTable(); + } + WhileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct WhileOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_cond_subgraph_index(int32_t cond_subgraph_index) { + fbb_.AddElement(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0); + } + void add_body_subgraph_index(int32_t body_subgraph_index) { + fbb_.AddElement(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit WhileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + WhileOptionsBuilder &operator=(const WhileOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateWhileOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t cond_subgraph_index = 0, + int32_t body_subgraph_index = 0) { + WhileOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_cond_subgraph_index(cond_subgraph_index); + return builder_.Finish(); +} + +flatbuffers::Offset CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -8265,6 +8439,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const HardSwishOptions *builtin_options_as_HardSwishOptions() const { return builtin_options_type() == BuiltinOptions_HardSwishOptions ? static_cast(builtin_options()) : nullptr; } + const IfOptions *builtin_options_as_IfOptions() const { + return builtin_options_type() == BuiltinOptions_IfOptions ? static_cast(builtin_options()) : nullptr; + } + const WhileOptions *builtin_options_as_WhileOptions() const { + return builtin_options_type() == BuiltinOptions_WhileOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -8665,6 +8845,14 @@ template<> inline const HardSwishOptions *Operator::builtin_options_as inline const IfOptions *Operator::builtin_options_as() const { + return builtin_options_as_IfOptions(); +} + +template<> inline const WhileOptions *Operator::builtin_options_as() const { + return builtin_options_as_WhileOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -11690,6 +11878,64 @@ inline flatbuffers::Offset CreateMatrixSetDiagOptions(flat _fbb); } +inline IfOptionsT *IfOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new IfOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void IfOptions::UnPackTo(IfOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = then_subgraph_index(); _o->then_subgraph_index = _e; }; + { auto _e = else_subgraph_index(); _o->else_subgraph_index = _e; }; +} + +inline flatbuffers::Offset IfOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateIfOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateIfOptions(flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const IfOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _then_subgraph_index = _o->then_subgraph_index; + auto _else_subgraph_index = _o->else_subgraph_index; + return tflite::CreateIfOptions( + _fbb, + _then_subgraph_index, + _else_subgraph_index); +} + +inline WhileOptionsT *WhileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new WhileOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void WhileOptions::UnPackTo(WhileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = cond_subgraph_index(); _o->cond_subgraph_index = _e; }; + { auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; }; +} + +inline flatbuffers::Offset WhileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateWhileOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const WhileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _cond_subgraph_index = _o->cond_subgraph_index; + auto _body_subgraph_index = _o->body_subgraph_index; + return tflite::CreateWhileOptions( + _fbb, + _cond_subgraph_index, + _body_subgraph_index); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -12347,6 +12593,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -12729,6 +12983,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -13099,6 +13361,14 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(value); + return CreateIfOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(value); + return CreateWhileOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -13469,6 +13739,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new HardSwishOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_IfOptions: { + value = new IfOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_WhileOptions: { + value = new WhileOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -13931,6 +14209,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr;