Add APIs to enable selective registration of the builtin parse function.

With this CL:
 * We have the hooks needed to register an operator-specific parse function
   with MicroMutableOpResolver and then retrieve it without going through
   ParseOpData.
 * This CL still passes in ParseOpData as the operator-specific parse
   function; that will be changed in a follow-on CL.

PiperOrigin-RevId: 314982707
Change-Id: I174259aabd66e97184a8a282832f6c71580366c9
commit d54578dba2 (parent 21715bfb30)
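As a usage sketch (not part of the diff itself): with the templated MicroMutableOpResolver changed below, an application registers only the builtins its model needs through the new Add* functions, and each call stores a parse function alongside the TfLiteRegistration. The op count of 3 and the RegisterOps helper are illustrative assumptions; the Add* names and GetOpDataParser come from the header changes in this CL, which for now still register ParseOpData.

// Illustrative sketch only; assumes the post-CL micro_mutable_op_resolver.h.
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

namespace {
// Room for exactly the builtins this hypothetical model uses.
tflite::MicroMutableOpResolver<3> op_resolver;
}  // namespace

TfLiteStatus RegisterOps() {
  // Each Add* call records the TfLiteRegistration and a parse function
  // (still ParseOpData in this CL; operator-specific parsers come later).
  TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
  TF_LITE_ENSURE_STATUS(op_resolver.AddSoftmax());
  TF_LITE_ENSURE_STATUS(op_resolver.AddDequantize());

  // The framework side can now look up whatever parser was registered:
  tflite::MicroOpResolver::BuiltinParseFunction parser =
      op_resolver.GetOpDataParser(tflite::BuiltinOperator_SOFTMAX);
  return (parser != nullptr) ? kTfLiteOk : kTfLiteError;
}

MicroAllocator performs exactly this lookup in PrepareNodeAndRegistrationDataFromFlatbuffer, as shown in the micro_allocator.cc hunks below.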
@@ -74,6 +74,8 @@ cc_library(
        ":micro_compatibility",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels/internal:compatibility",
        "//tensorflow/lite/micro/kernels:micro_ops",
        "//tensorflow/lite/schema:schema_fbs",
    ],
@@ -96,6 +98,8 @@ cc_library(
        ":micro_compatibility",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/kernels:op_macros",
        "//tensorflow/lite/kernels/internal:compatibility",
        "//tensorflow/lite/micro/kernels:portable_optimized_micro_ops",
        "//tensorflow/lite/schema:schema_fbs",
    ],
@@ -142,6 +142,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace dequantize

TfLiteRegistration* Register_DEQUANTIZE() {
  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
  // this struct no longer needs to be static and can be returned by value.
  static TfLiteRegistration r = {/*init=*/dequantize::Init,
                                 /*free=*/nullptr,
                                 /*prepare=*/dequantize::Prepare,
@@ -217,6 +217,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace fully_connected

TfLiteRegistration* Register_FULLY_CONNECTED() {
  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
  // this struct no longer needs to be static and can be returned by value.
  static TfLiteRegistration r = {/*init=*/fully_connected::Init,
                                 /*free=*/nullptr,
                                 /*prepare=*/fully_connected::Prepare,
@@ -137,6 +137,9 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace activations

TfLiteRegistration* Register_SOFTMAX() {
  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
  // this struct no longer needs to be static and can be returned by value.
  static TfLiteRegistration r = {/*init=*/nullptr,
                                 /*free=*/nullptr,
                                 /*prepare=*/activations::SoftmaxPrepare,
@@ -528,6 +528,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}  // namespace svdf

TfLiteRegistration* Register_SVDF() {
  // TODO(b/149408647): Once we remove AddBuiltin from MicroOpResolver and
  // completely switch to the templated AddBuiltin from MicroMutableOpResolver,
  // this struct no longer needs to be static and can be returned by value.
  static TfLiteRegistration r = {/*init=*/svdf::Init,
                                 /*free=*/nullptr,
                                 /*prepare=*/svdf::Prepare,
@@ -684,7 +684,19 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
    const char* custom_data = nullptr;
    size_t custom_data_size = 0;
    unsigned char* builtin_data = nullptr;

    if (op_type == BuiltinOperator_CUSTOM) {
      // Custom Ops may or may not have a non-null custom_options field.
      if (op->custom_options() != nullptr) {
        custom_data =
            reinterpret_cast<const char*>(op->custom_options()->data());
        custom_data_size = op->custom_options()->size();
      }
    } else {
      if (op->custom_options() != nullptr) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "Unsupported behavior: found builtin operator %s with custom "
@@ -693,16 +705,14 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
      return kTfLiteError;
    }

    const char* custom_data = nullptr;
    size_t custom_data_size = 0;
    unsigned char* builtin_data = nullptr;
    if (op->custom_options()) {
      custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
      custom_data_size = op->custom_options()->size();
    } else {
      MicroOpResolver::BuiltinParseFunction parser =
          op_resolver.GetOpDataParser(op_type);
      TFLITE_DCHECK(parser != nullptr);
      if (parser == nullptr) {
        TF_LITE_REPORT_ERROR(error_reporter_, "Did not find a parser for %s",
                             EnumNameBuiltinOperator(op_type));

        return kTfLiteError;
      }
      TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_,
                                   &builtin_data_allocator,
                                   (void**)(&builtin_data)));
@@ -724,6 +734,6 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
  }

  return kTfLiteOk;
}
}  // namespace tflite

}  // namespace tflite
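The parser retrieved above has to match MicroOpResolver::BuiltinParseFunction, i.e. the same signature as ParseOpData in core/api/flatbuffer_conversions.h. As a hypothetical illustration of the operator-specific parse functions that the follow-on CL is meant to register, a selective parser could look roughly like the sketch below; ParseSoftmaxOnly and its body are assumptions for illustration, not code from this change.

// Hypothetical sketch, not part of this CL: a selective parser that handles
// only SOFTMAX, so the linker can drop parsing logic for unused builtins.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/schema/schema_generated.h"

TfLiteStatus ParseSoftmaxOnly(const tflite::Operator* op,
                              tflite::BuiltinOperator op_type,
                              tflite::ErrorReporter* error_reporter,
                              tflite::BuiltinDataAllocator* allocator,
                              void** builtin_data) {
  if (op_type != tflite::BuiltinOperator_SOFTMAX) {
    TF_LITE_REPORT_ERROR(error_reporter,
                         "ParseSoftmaxOnly registered for the wrong operator.");
    return kTfLiteError;
  }
  // Allocate the params struct through the allocator that MicroAllocator
  // passes in (builtin_data_allocator in the hunk above).
  TfLiteSoftmaxParams* params = allocator->AllocatePOD<TfLiteSoftmaxParams>();
  if (params == nullptr) return kTfLiteError;
  params->beta = 0.0f;  // Sketch-level default if the model omits options.
  if (const auto* options = op->builtin_options_as_SoftmaxOptions()) {
    params->beta = options->beta();
  }
  *builtin_data = params;
  return kTfLiteOk;
}

Once such parsers exist, AddSoftmax() in micro_mutable_op_resolver.h would hand one of them to the private AddBuiltin overload instead of ParseOpData.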
@@ -15,12 +15,16 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_

#include <cstdio>
#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
@@ -56,39 +60,67 @@ class MicroMutableOpResolver : public MicroOpResolver {
  }

  MicroOpResolver::BuiltinParseFunction GetOpDataParser(
      tflite::BuiltinOperator) const override {
    // TODO(b/149408647): Replace with the more selective builtin parser.
    return ParseOpData;
      BuiltinOperator op) const override {
    TFLITE_DCHECK(num_buitin_ops_ <= tOpCount);
    for (unsigned int i = 0; i < num_buitin_ops_; ++i) {
      if (builtin_codes_[i] == op) return builtin_parsers_[i];
    }
    return nullptr;
  }

  // The Add* functions below add the various Builtin operators to the
  // MicroMutableOpResolver object.
  //
  // This API is currently experimental (and only supported for a small subset
  // of operators). It will soon be preferred over the AddBuiltin override of
  // the MicroOpResolver interface for the following reason:
  // * If all calls to AddBuiltin for an application use this API, the code
  //   size will be smaller by 5-8K (compared to the using the AddBuiltin
  //   override).

  TfLiteStatus AddDequantize() {
    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
    // function once cl/313453102 lands.
    return AddBuiltin(BuiltinOperator_DEQUANTIZE,
                      *tflite::ops::micro::Register_DEQUANTIZE(), ParseOpData);
  }

  TfLiteStatus AddFullyConnected() {
    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
    // function once cl/313453102 lands.
    return AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
                      *tflite::ops::micro::Register_FULLY_CONNECTED(),
                      ParseOpData);
  }

  TfLiteStatus AddQuantize() {
    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
    // function once cl/313453102 lands.
    return AddBuiltin(BuiltinOperator_QUANTIZE,
                      *tflite::ops::micro::Register_QUANTIZE(), ParseOpData);
  }

  TfLiteStatus AddSoftmax() {
    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
    // function once cl/313453102 lands.
    return AddBuiltin(BuiltinOperator_SOFTMAX,
                      *tflite::ops::micro::Register_SOFTMAX(), ParseOpData);
  }

  TfLiteStatus AddSvdf() {
    // TODO(b/149408647): Replace ParseOpData with the operator specific parse
    // function once cl/313453102 lands.
    return AddBuiltin(BuiltinOperator_SVDF,
                      *tflite::ops::micro::Register_SVDF(), ParseOpData);
  }

  TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                          TfLiteRegistration* registration) override {
    if (registrations_len_ >= tOpCount) {
      if (error_reporter_) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Couldn't register builtin op #%d, resolver size "
                             "is too small (%d)",
                             op, tOpCount);
      }
      return kTfLiteError;
    }

    if (FindOp(op) != nullptr) {
      if (error_reporter_ != nullptr) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Calling AddBuiltin with the same op more than "
                             "once is not supported (Op: #%d).",
                             op);
      }
      return kTfLiteError;
    }

    TfLiteRegistration* new_registration = &registrations_[registrations_len_];
    registrations_len_ += 1;

    *new_registration = *registration;
    new_registration->builtin_code = op;
    return kTfLiteOk;
    TFLITE_DCHECK(registration != nullptr);
    // For code that is not switched over to the new selective registration of
    // the parse function, we pass in ParseOpData. This allows for backwards
    // compatibility.
    return AddBuiltin(op, *registration, ParseOpData);
  }

  TfLiteStatus AddCustom(const char* name,
@@ -125,8 +157,60 @@ class MicroMutableOpResolver : public MicroOpResolver {
  unsigned int GetRegistrationLength() { return registrations_len_; }

 private:
  TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                          const TfLiteRegistration& registration,
                          MicroOpResolver::BuiltinParseFunction parser) {
    if (op == BuiltinOperator_CUSTOM) {
      if (error_reporter_ != nullptr) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Invalid parameter BuiltinOperator_CUSTOM to the "
                             "AddBuiltin function.");
      }
      return kTfLiteError;
    }

    if (FindOp(op) != nullptr) {
      if (error_reporter_ != nullptr) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Calling AddBuiltin with the same op more than "
                             "once is not supported (Op: #%d).",
                             op);
      }
      return kTfLiteError;
    }

    if (registrations_len_ >= tOpCount) {
      if (error_reporter_) {
        TF_LITE_REPORT_ERROR(error_reporter_,
                             "Couldn't register builtin op #%d, resolver size "
                             "is too small (%d).",
                             op, tOpCount);
      }
      return kTfLiteError;
    }

    registrations_[registrations_len_] = registration;
    // Strictly speaking, the builtin_code is not necessary for TFLM but filling
    // it in regardless.
    registrations_[registrations_len_].builtin_code = op;
    registrations_len_++;

    builtin_codes_[num_buitin_ops_] = op;
    builtin_parsers_[num_buitin_ops_] = parser;
    num_buitin_ops_++;

    return kTfLiteOk;
  }

  TfLiteRegistration registrations_[tOpCount];
  unsigned int registrations_len_ = 0;

  // Arrays (and counter) to store the builtin codes and their corresponding
  // parse functions as these are registered with the Op Resolver.
  BuiltinOperator builtin_codes_[tOpCount];
  MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount];
  unsigned int num_buitin_ops_ = 0;

  ErrorReporter* error_reporter_;

  TF_LITE_REMOVE_VIRTUAL_DELETE
@@ -50,6 +50,9 @@ class MicroOpResolver : public OpResolver {
  // i.e. if this function is called again for a previously added
  // BuiltinOperator, the MicroOpResolver will be unchanged and this function
  // will return kTfLiteError.
  //
  // TODO(b/149408647): remove this API once the templated AddBuiltin API in
  // MicroMutableOpResolver is properly implemented.
  virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
                                  TfLiteRegistration* registration) = 0;
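Because the pointer-based AddBuiltin documented above (and the private overload it now forwards to) rejects both duplicate registrations and a full resolver with kTfLiteError, callers should propagate the returned status. A small hedged sketch, with the operator choices and helper name arbitrary:

// Sketch assuming the pointer-based AddBuiltin override documented above.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

TfLiteStatus RegisterLegacyOps(tflite::MicroMutableOpResolver<2>* resolver) {
  // Both failure modes described above (duplicate op, resolver already full)
  // surface here as kTfLiteError, so the status is worth checking.
  TF_LITE_ENSURE_STATUS(resolver->AddBuiltin(
      tflite::BuiltinOperator_SOFTMAX, tflite::ops::micro::Register_SOFTMAX()));
  TF_LITE_ENSURE_STATUS(resolver->AddBuiltin(
      tflite::BuiltinOperator_SVDF, tflite::ops::micro::Register_SVDF()));
  return kTfLiteOk;
}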
@@ -17,6 +17,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <new>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/micro/memory_helpers.h"