From d54578dba2178cda3d6235a47e311015de62c3f1 Mon Sep 17 00:00:00 2001
From: Advait Jain
Date: Fri, 5 Jun 2020 13:09:00 -0700
Subject: [PATCH] Add APIs to enable selective registration of the builtin
 parse function.

With this CL:
 * We have the hooks needed to register an operator specific parse function
   with MicroMutableOpResolver and then retrieve it without ParseOpData being
   used.
 * This CL still passes in ParseOpData as the operator specific parse
   function; that will be changed in a follow-on CL.

PiperOrigin-RevId: 314982707
Change-Id: I174259aabd66e97184a8a282832f6c71580366c9
---
 tensorflow/lite/micro/BUILD                    |   4 +
 tensorflow/lite/micro/kernels/dequantize.cc    |   3 +
 .../lite/micro/kernels/fully_connected.cc      |   3 +
 tensorflow/lite/micro/kernels/softmax.cc       |   3 +
 tensorflow/lite/micro/kernels/svdf.cc          |   3 +
 tensorflow/lite/micro/micro_allocator.cc       |  38 +++--
 .../lite/micro/micro_mutable_op_resolver.h     | 142 ++++++++++++++----
 tensorflow/lite/micro/micro_op_resolver.h      |   3 +
 .../lite/micro/simple_memory_allocator.cc      |   1 +
 9 files changed, 157 insertions(+), 43 deletions(-)

diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD
index e8e489141ed..915e3b64272 100644
--- a/tensorflow/lite/micro/BUILD
+++ b/tensorflow/lite/micro/BUILD
@@ -74,6 +74,8 @@ cc_library(
         ":micro_compatibility",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
+        "//tensorflow/lite/kernels:op_macros",
+        "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/micro/kernels:micro_ops",
         "//tensorflow/lite/schema:schema_fbs",
     ],
@@ -96,6 +98,8 @@ cc_library(
         ":micro_compatibility",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
+        "//tensorflow/lite/kernels:op_macros",
+        "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/micro/kernels:portable_optimized_micro_ops",
         "//tensorflow/lite/schema:schema_fbs",
     ],
diff --git a/tensorflow/lite/micro/kernels/dequantize.cc b/tensorflow/lite/micro/kernels/dequantize.cc
index 4b87c0eb04c..1fa136ae117 100644
--- a/tensorflow/lite/micro/kernels/dequantize.cc
+++ b/tensorflow/lite/micro/kernels/dequantize.cc
@@ -142,6 +142,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace dequantize
 
 TfLiteRegistration* Register_DEQUANTIZE() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/dequantize::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/dequantize::Prepare,
diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index 66b8379739d..bd949e6f552 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -217,6 +217,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace fully_connected
 
 TfLiteRegistration* Register_FULLY_CONNECTED() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/fully_connected::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/fully_connected::Prepare,
diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc
index 3d8e7cd87f7..616017ecc5b 100644
--- a/tensorflow/lite/micro/kernels/softmax.cc
+++ b/tensorflow/lite/micro/kernels/softmax.cc
@@ -137,6 +137,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace activations
 
 TfLiteRegistration* Register_SOFTMAX() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/nullptr,
                                  /*free=*/nullptr,
                                  /*prepare=*/activations::SoftmaxPrepare,
diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc
index 8c33fde5a87..ba7cb05da57 100644
--- a/tensorflow/lite/micro/kernels/svdf.cc
+++ b/tensorflow/lite/micro/kernels/svdf.cc
@@ -528,6 +528,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace svdf
 
 TfLiteRegistration* Register_SVDF() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/svdf::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/svdf::Prepare,
diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc
index ad26483ec3c..5323476ec9f 100644
--- a/tensorflow/lite/micro/micro_allocator.cc
+++ b/tensorflow/lite/micro/micro_allocator.cc
@@ -684,25 +684,35 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(

   BuiltinOperator op_type =
       static_cast<BuiltinOperator>(registration->builtin_code);
-  if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
-    TF_LITE_REPORT_ERROR(
-        error_reporter_,
-        "Unsupported behavior: found builtin operator %s with custom "
-        "options.\n",
-        EnumNameBuiltinOperator(op_type));
-    return kTfLiteError;
-  }
-
   const char* custom_data = nullptr;
   size_t custom_data_size = 0;
   unsigned char* builtin_data = nullptr;
-  if (op->custom_options()) {
-    custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
-    custom_data_size = op->custom_options()->size();
+
+  if (op_type == BuiltinOperator_CUSTOM) {
+    // Custom Ops may or may not have a non-null custom_options field.
+    if (op->custom_options() != nullptr) {
+      custom_data =
+          reinterpret_cast<const char*>(op->custom_options()->data());
+      custom_data_size = op->custom_options()->size();
+    }
   } else {
+    if (op->custom_options() != nullptr) {
+      TF_LITE_REPORT_ERROR(
+          error_reporter_,
+          "Unsupported behavior: found builtin operator %s with custom "
+          "options.\n",
+          EnumNameBuiltinOperator(op_type));
+      return kTfLiteError;
+    }
+
     MicroOpResolver::BuiltinParseFunction parser =
         op_resolver.GetOpDataParser(op_type);
-    TFLITE_DCHECK(parser != nullptr);
+    if (parser == nullptr) {
+      TF_LITE_REPORT_ERROR(error_reporter_, "Did not find a parser for %s",
+                           EnumNameBuiltinOperator(op_type));
+
+      return kTfLiteError;
+    }
     TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_,
                                  &builtin_data_allocator,
                                  (void**)(&builtin_data)));
@@ -724,6 +734,6 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
   }

   return kTfLiteOk;
-}
+}  // namespace tflite

 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index dd0b593e201..34768ae3cb8 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -15,12 +15,16 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_

+#include
 #include

 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/micro_op_resolver.h"
 #include "tensorflow/lite/schema/schema_generated.h"

@@ -56,39 +60,67 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }

   MicroOpResolver::BuiltinParseFunction GetOpDataParser(
-      tflite::BuiltinOperator) const override {
-    // TODO(b/149408647): Replace with the more selective builtin parser.
-    return ParseOpData;
+      BuiltinOperator op) const override {
+    TFLITE_DCHECK(num_builtin_ops_ <= tOpCount);
+    for (unsigned int i = 0; i < num_builtin_ops_; ++i) {
+      if (builtin_codes_[i] == op) return builtin_parsers_[i];
+    }
+    return nullptr;
+  }
+
+  // The Add* functions below add the various Builtin operators to the
+  // MicroMutableOpResolver object.
+  //
+  // This API is currently experimental (and only supported for a small subset
+  // of operators). It will soon be preferred over the AddBuiltin override of
+  // the MicroOpResolver interface for the following reason:
+  //  * If all calls to AddBuiltin for an application use this API, the code
+  //    size will be smaller by 5-8K (compared to using the AddBuiltin
+  //    override).
+
+  TfLiteStatus AddDequantize() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_DEQUANTIZE,
+                      *tflite::ops::micro::Register_DEQUANTIZE(), ParseOpData);
+  }
+
+  TfLiteStatus AddFullyConnected() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
+                      *tflite::ops::micro::Register_FULLY_CONNECTED(),
+                      ParseOpData);
+  }
+
+  TfLiteStatus AddQuantize() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_QUANTIZE,
+                      *tflite::ops::micro::Register_QUANTIZE(), ParseOpData);
+  }
+
+  TfLiteStatus AddSoftmax() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_SOFTMAX,
+                      *tflite::ops::micro::Register_SOFTMAX(), ParseOpData);
+  }
+
+  TfLiteStatus AddSvdf() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_SVDF,
+                      *tflite::ops::micro::Register_SVDF(), ParseOpData);
   }

   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                           TfLiteRegistration* registration) override {
-    if (registrations_len_ >= tOpCount) {
-      if (error_reporter_) {
-        TF_LITE_REPORT_ERROR(error_reporter_,
-                             "Couldn't register builtin op #%d, resolver size "
-                             "is too small (%d)",
-                             op, tOpCount);
-      }
-      return kTfLiteError;
-    }
-
-    if (FindOp(op) != nullptr) {
-      if (error_reporter_ != nullptr) {
-        TF_LITE_REPORT_ERROR(error_reporter_,
-                             "Calling AddBuiltin with the same op more than "
-                             "once is not supported (Op: #%d).",
-                             op);
-      }
-      return kTfLiteError;
-    }
-
-    TfLiteRegistration* new_registration = &registrations_[registrations_len_];
-    registrations_len_ += 1;
-
-    *new_registration = *registration;
-    new_registration->builtin_code = op;
-    return kTfLiteOk;
+    TFLITE_DCHECK(registration != nullptr);
+    // For code that is not switched over to the new selective registration of
+    // the parse function, we pass in ParseOpData. This allows for backwards
+    // compatibility.
+    return AddBuiltin(op, *registration, ParseOpData);
   }

   TfLiteStatus AddCustom(const char* name,
@@ -125,8 +157,60 @@ class MicroMutableOpResolver : public MicroOpResolver {
   unsigned int GetRegistrationLength() { return registrations_len_; }

  private:
+  TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
+                          const TfLiteRegistration& registration,
+                          MicroOpResolver::BuiltinParseFunction parser) {
+    if (op == BuiltinOperator_CUSTOM) {
+      if (error_reporter_ != nullptr) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Invalid parameter BuiltinOperator_CUSTOM to the "
+                             "AddBuiltin function.");
+      }
+      return kTfLiteError;
+    }
+
+    if (FindOp(op) != nullptr) {
+      if (error_reporter_ != nullptr) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Calling AddBuiltin with the same op more than "
+                             "once is not supported (Op: #%d).",
+                             op);
+      }
+      return kTfLiteError;
+    }
+
+    if (registrations_len_ >= tOpCount) {
+      if (error_reporter_) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Couldn't register builtin op #%d, resolver size "
+                             "is too small (%d).",
+                             op, tOpCount);
+      }
+      return kTfLiteError;
+    }
+
+    registrations_[registrations_len_] = registration;
+    // Strictly speaking, the builtin_code is not necessary for TFLM but we
+    // fill it in regardless.
+    registrations_[registrations_len_].builtin_code = op;
+    registrations_len_++;
+
+    builtin_codes_[num_builtin_ops_] = op;
+    builtin_parsers_[num_builtin_ops_] = parser;
+    num_builtin_ops_++;
+
+    return kTfLiteOk;
+  }
+
   TfLiteRegistration registrations_[tOpCount];
   unsigned int registrations_len_ = 0;
+
+  // Arrays (and counter) to store the builtin codes and their corresponding
+  // parse functions as these are registered with the Op Resolver.
+  BuiltinOperator builtin_codes_[tOpCount];
+  MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
+  unsigned int num_builtin_ops_ = 0;
+
   ErrorReporter* error_reporter_;

   TF_LITE_REMOVE_VIRTUAL_DELETE
diff --git a/tensorflow/lite/micro/micro_op_resolver.h b/tensorflow/lite/micro/micro_op_resolver.h
index 49022cd70fa..0f5528d7b70 100644
--- a/tensorflow/lite/micro/micro_op_resolver.h
+++ b/tensorflow/lite/micro/micro_op_resolver.h
@@ -50,6 +50,9 @@ class MicroOpResolver : public OpResolver {
   // i.e. if this function is called again for a previously added
   // BuiltinOperator, the MicroOpResolver will be unchanged and this function
   // will return kTfLiteError.
+  //
+  // TODO(b/149408647): remove this API once the templated AddBuiltin API in
+  // MicroMutableOpResolver is properly implemented.
   virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                                   TfLiteRegistration* registration) = 0;

diff --git a/tensorflow/lite/micro/simple_memory_allocator.cc b/tensorflow/lite/micro/simple_memory_allocator.cc
index a637b60b983..3b416047c8f 100644
--- a/tensorflow/lite/micro/simple_memory_allocator.cc
+++ b/tensorflow/lite/micro/simple_memory_allocator.cc
@@ -17,6 +17,7 @@ limitations under the License.

 #include
 #include
+#include

 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/micro/memory_helpers.h"
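
Usage sketch (illustrative, not part of the diff above). The snippet below shows how an application might use the templated Add* helpers introduced by this change, which register a builtin operator together with the parse function that MicroOpResolver::GetOpDataParser() later hands back to MicroAllocator. It is a minimal sketch under stated assumptions: RunModelOnce, g_model_data, tensor_arena, and kTensorArenaSize are hypothetical application-provided names, and the MicroErrorReporter/MicroInterpreter wiring is assumed to match the micro API of this era rather than being defined by this CL.

// Illustrative sketch only; assumes the model uses exactly the four builtin
// operators registered below. g_model_data, tensor_arena and kTensorArenaSize
// are hypothetical application symbols.
#include <cstdint>

#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
constexpr int kTensorArenaSize = 10 * 1024;  // hypothetical arena size
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace

TfLiteStatus RunModelOnce(const uint8_t* g_model_data) {
  static tflite::MicroErrorReporter micro_error_reporter;
  const tflite::Model* model = tflite::GetModel(g_model_data);

  // The template argument is the resolver capacity (tOpCount). Each Add* call
  // records the registration, the builtin code, and the parse function that
  // GetOpDataParser() will return for that op.
  static tflite::MicroMutableOpResolver<4> resolver(&micro_error_reporter);
  if (resolver.AddQuantize() != kTfLiteOk) return kTfLiteError;
  if (resolver.AddFullyConnected() != kTfLiteOk) return kTfLiteError;
  if (resolver.AddSoftmax() != kTfLiteOk) return kTfLiteError;
  if (resolver.AddDequantize() != kTfLiteOk) return kTfLiteError;

  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kTensorArenaSize,
                                       &micro_error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) return kTfLiteError;
  return interpreter.Invoke();
}

Because only the operators the application actually registers are referenced, the linker can discard the rest; once the follow-on CL swaps ParseOpData for operator specific parse functions, the same selectivity applies to the parsing code, which is the 5-8K code-size saving referred to in the micro_mutable_op_resolver.h comment.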