Add APIs to enable selective registration of the builtin parse function.

With this CL:
 * We have the hooks needed to register an operator-specific parse function
   with MicroMutableOpResolver and then retrieve it without going through
   ParseOpData.

 * This CL still passes in ParseOpData as the operator-specific parse
   function; that will be changed in a follow-on CL (see the usage sketch
   below).

PiperOrigin-RevId: 314982707
Change-Id: I174259aabd66e97184a8a282832f6c71580366c9
Author: Advait Jain, 2020-06-05 13:09:00 -07:00 (committed by TensorFlower Gardener)
Commit: d54578dba2 (parent: 21715bfb30)
9 changed files with 157 additions and 43 deletions
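A minimal usage sketch of the API added here (not part of the commit): it registers a few builtins through the new Add* methods and then looks up the per-operator parse function the way MicroAllocator now does. The resolver size <4>, the RegisterOps wrapper, and the default-constructed resolver are illustrative assumptions.

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Size <4> because exactly four builtins are registered below.
tflite::MicroMutableOpResolver<4> resolver;

TfLiteStatus RegisterOps() {
  // Each Add* call stores both the kernel registration and (for now)
  // ParseOpData as the parse function for that builtin.
  TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());
  TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());
  TF_LITE_ENSURE_STATUS(resolver.AddQuantize());
  TF_LITE_ENSURE_STATUS(resolver.AddDequantize());

  // The allocator can now retrieve the registered parse function for a given
  // builtin instead of always falling back to the monolithic ParseOpData.
  tflite::MicroOpResolver::BuiltinParseFunction parser =
      resolver.GetOpDataParser(tflite::BuiltinOperator_SOFTMAX);
  return (parser != nullptr) ? kTfLiteOk : kTfLiteError;
}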

View File

@@ -74,6 +74,8 @@ cc_library(
         ":micro_compatibility",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
+        "//tensorflow/lite/kernels:op_macros",
+        "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/micro/kernels:micro_ops",
         "//tensorflow/lite/schema:schema_fbs",
     ],
@@ -96,6 +98,8 @@ cc_library(
         ":micro_compatibility",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/api",
+        "//tensorflow/lite/kernels:op_macros",
+        "//tensorflow/lite/kernels/internal:compatibility",
         "//tensorflow/lite/micro/kernels:portable_optimized_micro_ops",
         "//tensorflow/lite/schema:schema_fbs",
     ],

View File

@@ -142,6 +142,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace dequantize

 TfLiteRegistration* Register_DEQUANTIZE() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/dequantize::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/dequantize::Prepare,

View File

@@ -217,6 +217,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace fully_connected

 TfLiteRegistration* Register_FULLY_CONNECTED() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/fully_connected::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/fully_connected::Prepare,

View File

@@ -137,6 +137,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace activations

 TfLiteRegistration* Register_SOFTMAX() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/nullptr,
                                  /*free=*/nullptr,
                                  /*prepare=*/activations::SoftmaxPrepare,

View File

@@ -528,6 +528,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 }  // namespace svdf

 TfLiteRegistration* Register_SVDF() {
+  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
+  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
+  // this struct no longer needs to be static and can be returned by value.
   static TfLiteRegistration r = {/*init=*/svdf::Init,
                                  /*free=*/nullptr,
                                  /*prepare=*/svdf::Prepare,

View File

@@ -684,25 +684,35 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
   BuiltinOperator op_type =
       static_cast<BuiltinOperator>(registration->builtin_code);
-  if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
-    TF_LITE_REPORT_ERROR(
-        error_reporter_,
-        "Unsupported behavior: found builtin operator %s with custom "
-        "options.\n",
-        EnumNameBuiltinOperator(op_type));
-    return kTfLiteError;
-  }
   const char* custom_data = nullptr;
   size_t custom_data_size = 0;
   unsigned char* builtin_data = nullptr;
-  if (op->custom_options()) {
-    custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
-    custom_data_size = op->custom_options()->size();
+
+  if (op_type == BuiltinOperator_CUSTOM) {
+    // Custom Ops may or may not have a non-null custom_options field.
+    if (op->custom_options() != nullptr) {
+      custom_data =
+          reinterpret_cast<const char*>(op->custom_options()->data());
+      custom_data_size = op->custom_options()->size();
+    }
   } else {
+    if (op->custom_options() != nullptr) {
+      TF_LITE_REPORT_ERROR(
+          error_reporter_,
+          "Unsupported behavior: found builtin operator %s with custom "
+          "options.\n",
+          EnumNameBuiltinOperator(op_type));
+      return kTfLiteError;
+    }
     MicroOpResolver::BuiltinParseFunction parser =
         op_resolver.GetOpDataParser(op_type);
-    TFLITE_DCHECK(parser != nullptr);
+    if (parser == nullptr) {
+      TF_LITE_REPORT_ERROR(error_reporter_, "Did not find a parser for %s",
+                           EnumNameBuiltinOperator(op_type));
+      return kTfLiteError;
+    }
     TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_,
                                  &builtin_data_allocator,
                                  (void**)(&builtin_data)));
@@ -724,6 +734,6 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
   }

   return kTfLiteOk;
-}
+}  // namespace tflite
 }  // namespace tflite
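The parse function the allocator retrieves above has the same signature as ParseOpData, as the call site shows. A hypothetical operator-specific parser, of the kind the follow-on CL is expected to register, might look roughly like the sketch below; the function name, the default beta value, and the AllocatePOD<> helper on BuiltinDataAllocator are assumptions, not code from this commit.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Illustrative only: an operator-specific parse function matching the
// BuiltinParseFunction signature invoked by the allocator above.
TfLiteStatus ParseSoftmaxSketch(const tflite::Operator* op,
                                tflite::BuiltinOperator op_type,
                                tflite::ErrorReporter* error_reporter,
                                tflite::BuiltinDataAllocator* allocator,
                                void** builtin_data) {
  (void)op_type;  // A per-operator parser is already specific to SOFTMAX.
  // Allocate the params struct through the provided allocator so the memory
  // is owned by the caller's allocator rather than the heap. AllocatePOD<> is
  // an assumed helper on BuiltinDataAllocator.
  TfLiteSoftmaxParams* params = allocator->AllocatePOD<TfLiteSoftmaxParams>();
  if (params == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter,
                         "Failed to allocate TfLiteSoftmaxParams.");
    return kTfLiteError;
  }
  // Read only the fields that SOFTMAX actually needs from the flatbuffer.
  if (const tflite::SoftmaxOptions* options =
          op->builtin_options_as_SoftmaxOptions()) {
    params->beta = options->beta();
  } else {
    params->beta = 1.0f;  // Assumed default when no options are serialized.
  }
  *builtin_data = params;
  return kTfLiteOk;
}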

View File

@@ -15,12 +15,16 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_

+#include <cstdio>
 #include <cstring>

 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/micro_op_resolver.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -56,39 +60,67 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }

   MicroOpResolver::BuiltinParseFunction GetOpDataParser(
-      tflite::BuiltinOperator) const override {
-    // TODO(b/149408647): Replace with the more selective builtin parser.
-    return ParseOpData;
+      BuiltinOperator op) const override {
+    TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
+    for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
+      if (builtin_codes_[i] == op) return builtin_parsers_[i];
+    }
+    return nullptr;
+  }
+
+  // The Add* functions below add the various Builtin operators to the
+  // MicroMutableOpResolver object.
+  //
+  // This API is currently experimental (and only supported for a small subset
+  // of operators). It will soon be preferred over the AddBuiltin override of
+  // the MicroOpResolver interface for the following reason:
+  //  * If all calls to AddBuiltin for an application use this API, the code
+  //    size will be smaller by 5-8K (compared to the using the AddBuiltin
+  //    override).
+
+  TfLiteStatus AddDequantize() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_DEQUANTIZE,
+                      *tflite::ops::micro::Register_DEQUANTIZE(), ParseOpData);
+  }
+
+  TfLiteStatus AddFullyConnected() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
+                      *tflite::ops::micro::Register_FULLY_CONNECTED(),
+                      ParseOpData);
+  }
+
+  TfLiteStatus AddQuantize() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_QUANTIZE,
+                      *tflite::ops::micro::Register_QUANTIZE(), ParseOpData);
+  }
+
+  TfLiteStatus AddSoftmax() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_SOFTMAX,
+                      *tflite::ops::micro::Register_SOFTMAX(), ParseOpData);
+  }
+
+  TfLiteStatus AddSvdf() {
+    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
+    // function once cl/313453102 lands.
+    return AddBuiltin(BuiltinOperator_SVDF,
+                      *tflite::ops::micro::Register_SVDF(), ParseOpData);
   }

   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                           TfLiteRegistration* registration) override {
-    if (registrations_len_ >= tOpCount) {
-      if (error_reporter_) {
-        TF_LITE_REPORT_ERROR(error_reporter_,
-                             "Couldn't register builtin op #%d, resolver size "
-                             "is too small (%d)",
-                             op, tOpCount);
-      }
-      return kTfLiteError;
-    }
-
-    if (FindOp(op) != nullptr) {
-      if (error_reporter_ != nullptr) {
-        TF_LITE_REPORT_ERROR(error_reporter_,
-                             "Calling AddBuiltin with the same op more than "
-                             "once is not supported (Op: #%d).",
-                             op);
-      }
-      return kTfLiteError;
-    }
-
-    TfLiteRegistration* new_registration = &registrations_[registrations_len_];
-    registrations_len_ += 1;
-
-    *new_registration = *registration;
-    new_registration->builtin_code = op;
-    return kTfLiteOk;
+    TFLITE_DCHECK(registration != nullptr);
+    // For code that is not switched over to the new selective registration of
+    // the parse function, we pass in ParseOpData. This allows for backwards
+    // compatibility.
+    return AddBuiltin(op, *registration, ParseOpData);
   }

   TfLiteStatus AddCustom(const char* name,
@@ -125,8 +157,60 @@ class MicroMutableOpResolver : public MicroOpResolver {
   unsigned int GetRegistrationLength() { return registrations_len_; }

  private:
+  TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
+                          const TfLiteRegistration& registration,
+                          MicroOpResolver::BuiltinParseFunction parser) {
+    if (op == BuiltinOperator_CUSTOM) {
+      if (error_reporter_ != nullptr) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Invalid parameter BuiltinOperator_CUSTOM to the "
+                             "AddBuiltin function.");
+      }
+      return kTfLiteError;
+    }
+
+    if (FindOp(op) != nullptr) {
+      if (error_reporter_ != nullptr) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Calling AddBuiltin with the same op more than "
+                             "once is not supported (Op: #%d).",
+                             op);
+      }
+      return kTfLiteError;
+    }
+
+    if (registrations_len_ >= tOpCount) {
+      if (error_reporter_) {
+        TF_LITE_REPORT_ERROR(error_reporter_,
+                             "Couldn't register builtin op #%d, resolver size "
+                             "is too small (%d).",
+                             op, tOpCount);
+      }
+      return kTfLiteError;
+    }
+
+    registrations_[registrations_len_] = registration;
+    // Strictly speaking, the builtin_code is not necessary for TFLM but filling
+    // it in regardless.
+    registrations_[registrations_len_].builtin_code = op;
+    registrations_len_++;
+
+    builtin_codes_[num_buitin_ops_] = op;
+    builtin_parsers_[num_buitin_ops_] = parser;
+    num_buitin_ops_++;
+
+    return kTfLiteOk;
+  }
+
   TfLiteRegistration registrations_[tOpCount];
   unsigned int registrations_len_ = 0;

+  // Arrays (and counter) to store the builtin codes and their corresponding
+  // parse functions as these are registered with the Op Resolver.
+  BuiltinOperator builtin_codes_[tOpCount];
+  MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
+  unsigned int num_buitin_ops_ = 0;
+
   ErrorReporter* error_reporter_;

   TF_LITE_REMOVE_VIRTUAL_DELETE
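A short hedged sketch of how the two registration paths coexist after this change (the MixedRegistration wrapper and the resolver size <2> are illustrative assumptions): the old AddBuiltin override now forwards to the private helper with ParseOpData, so old-style and new-style registrations share the same storage, duplicate check, and capacity check.

#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

void MixedRegistration() {
  tflite::MicroMutableOpResolver<2> resolver;

  // Old-style path: still supported, internally registered with ParseOpData.
  resolver.AddBuiltin(tflite::BuiltinOperator_SVDF,
                      tflite::ops::micro::Register_SVDF());

  // New-style path for a different operator.
  resolver.AddSoftmax();

  // Registering SVDF a second time (via either path) is rejected by the
  // FindOp() duplicate check and returns kTfLiteError.
  TfLiteStatus duplicate = resolver.AddSvdf();
  (void)duplicate;
}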

View File

@@ -50,6 +50,9 @@ class MicroOpResolver : public OpResolver {
   // i.e. if this function is called again for a previously added
   // BuiltinOperator, the MicroOpResolver will be unchanged and this function
   // will return kTfLiteError.
+  //
+  // TODO(b/149408647): remove this API once the templated AddBuiltin API in
+  // MicroMutableOpResolver is properly implemented.
   virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                                   TfLiteRegistration* registration) = 0;

View File

@@ -17,6 +17,7 @@ limitations under the License.
 #include <cstddef>
 #include <cstdint>
+#include <new>

 #include "tensorflow/lite/core/api/error_reporter.h"
 #include "tensorflow/lite/micro/memory_helpers.h"