STT-tensorflow/tensorflow/lite/micro/micro_mutable_op_resolver.h
Advait Jain 33689c48ad Add MicroOpResolver interface class.
This will allow us to implement selective registration of the builtin parse
functions without changing the OpResolver base class in TFLite.

* MicroOpResolver is now an interface (matching the OpResolver name in TFLite).
* MicroMutableOpResolver is the implementation of the MicroOpResolver
  interface that should be used by applications that do not want to use
  AllOpsResolver.

PiperOrigin-RevId: 313691276
Change-Id: I0a9f51f6584326a3b3dd645cde083ba42116083d
2020-05-28 17:38:37 -07:00

154 lines
5.3 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
#include <cstring>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// Op versions discussed in this file are enumerated here:
// tensorflow/lite/tools/versioning/op_version.cc
// Returns the sentinel version number (0) that acts as a wildcard during op
// lookup: a registration stored with this version matches any requested
// version, and a lookup requesting this version matches any registration.
//
// Declared constexpr (which also makes it implicitly inline) so the sentinel
// can be used in constant expressions as well as at run time.
constexpr int MicroOpResolverAnyVersion() { return 0; }
template <unsigned int tOpCount>
class MicroMutableOpResolver : public MicroOpResolver {
public:
explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
: error_reporter_(error_reporter) {}
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override {
for (unsigned int i = 0; i < registrations_len_; ++i) {
const TfLiteRegistration& registration = registrations_[i];
if ((registration.builtin_code == op) &&
(registration.version == MicroOpResolverAnyVersion() ||
version == MicroOpResolverAnyVersion() ||
registration.version == version)) {
return &registration;
}
}
return nullptr;
}
const TfLiteRegistration* FindOp(const char* op, int version) const override {
for (unsigned int i = 0; i < registrations_len_; ++i) {
const TfLiteRegistration& registration = registrations_[i];
if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
(strcmp(registration.custom_name, op) == 0) &&
(registration.version == MicroOpResolverAnyVersion() ||
version == MicroOpResolverAnyVersion() ||
registration.version == version)) {
return &registration;
}
}
return nullptr;
}
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
tflite::BuiltinOperator) const override {
// TODO(b/149408647): Replace with the more selective builtin parser.
return ParseOpData;
}
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration,
int version = 1) override {
if (registrations_len_ >= tOpCount) {
if (error_reporter_) {
TF_LITE_REPORT_ERROR(error_reporter_,
"Couldn't register builtin op #%d, resolver size "
"is too small (%d)",
op, tOpCount);
}
return kTfLiteError;
}
TfLiteRegistration* new_registration = &registrations_[registrations_len_];
registrations_len_ += 1;
*new_registration = *registration;
new_registration->builtin_code = op;
new_registration->version = version;
return kTfLiteOk;
}
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration, int min_version,
int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteStatus add_status = AddBuiltin(op, registration, version);
if (add_status != kTfLiteOk) {
return add_status;
}
}
return kTfLiteOk;
}
TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration,
int version = 1) {
if (registrations_len_ >= tOpCount) {
if (error_reporter_) {
TF_LITE_REPORT_ERROR(
error_reporter_,
"Couldn't register custom op '%s', resolver size is too small (%d)",
name, tOpCount);
}
return kTfLiteError;
}
TfLiteRegistration* new_registration = &registrations_[registrations_len_];
registrations_len_ += 1;
*new_registration = *registration;
new_registration->builtin_code = BuiltinOperator_CUSTOM;
new_registration->custom_name = name;
new_registration->version = version;
return kTfLiteOk;
}
TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteStatus add_status = AddCustom(name, registration, version);
if (add_status != kTfLiteOk) {
return add_status;
}
}
return kTfLiteOk;
}
unsigned int GetRegistrationLength() { return registrations_len_; }
private:
TfLiteRegistration registrations_[tOpCount];
unsigned int registrations_len_ = 0;
ErrorReporter* error_reporter_;
TF_LITE_REMOVE_VIRTUAL_DELETE
};
}; // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_