Add MicroOpResolver interface class.

This will allow us to implement selective registration of the builtin parse
functions without changing the OpResolver base class in TFLite.

* MicroOpResolver is now an interface (matching the OpResolver name in TFLite).
* MicroMutableOpResolver is the implementation of the MicroOpResolver
  interface that should be used by applications that do not want to use
  AllOpsResolver.

PiperOrigin-RevId: 313691276
Change-Id: I0a9f51f6584326a3b3dd645cde083ba42116083d
This commit is contained in:
Advait Jain 2020-05-28 17:28:40 -07:00 committed by TensorFlower Gardener
parent f9fb66cdb7
commit 33689c48ad
23 changed files with 150 additions and 63 deletions

View File

@ -37,6 +37,7 @@ cc_library(
"micro_allocator.h",
"micro_interpreter.h",
"micro_mutable_op_resolver.h",
"micro_op_resolver.h",
"micro_optional_debug_tools.h",
"simple_memory_allocator.h",
"test_helpers.h",
@ -159,6 +160,7 @@ tflite_micro_cc_test(
deps = [
":micro_framework",
":micro_utils",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/micro/testing:micro_test",
],

View File

@ -84,7 +84,7 @@ class KeywordRunner {
const tflite::Model* keyword_spotting_model_;
tflite::MicroErrorReporter micro_reporter_;
tflite::ErrorReporter* reporter_;
tflite::MicroOpResolver<6> resolver_;
tflite::MicroMutableOpResolver<6> resolver_;
tflite::MicroInterpreter interpreter_;
};

View File

@ -98,7 +98,7 @@ class PersonDetectionRunner {
const tflite::Model* person_detection_model_;
tflite::MicroErrorReporter micro_reporter_;
tflite::ErrorReporter* reporter_;
tflite::MicroOpResolver<6> resolver_;
tflite::MicroMutableOpResolver<6> resolver_;
tflite::MicroInterpreter interpreter_;
};

View File

@ -42,7 +42,7 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) {
model->version(), TFLITE_SCHEMA_VERSION);
}
tflite::MicroOpResolver<4> micro_op_resolver;
tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());

View File

@ -56,7 +56,7 @@ int main(int argc, char** argv) {
return 1;
}
tflite::MicroOpResolver<4> micro_op_resolver;
tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
tflite::ops::micro::Register_CONV_2D());

View File

@ -46,7 +46,7 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());

View File

@ -65,7 +65,7 @@ void setup() {
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());

View File

@ -74,7 +74,7 @@ void setup() {
//
// tflite::ops::micro::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroOpResolver<4> micro_op_resolver(error_reporter);
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
if (micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),

View File

@ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// needed by this graph.
//
// tflite::ops::micro::AllOpsResolver resolver;
tflite::MicroOpResolver<4> micro_op_resolver;
tflite::MicroMutableOpResolver<4> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
tflite::MicroOpResolverAnyVersion());

View File

@ -65,7 +65,7 @@ void setup() {
//
// tflite::ops::micro::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroOpResolver<3> micro_op_resolver;
static tflite::MicroMutableOpResolver<3> micro_op_resolver;
micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());

View File

@ -54,7 +54,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// needed by this graph.
//
// tflite::ops::micro::AllOpsResolver resolver;
tflite::MicroOpResolver<3> micro_op_resolver;
tflite::MicroMutableOpResolver<3> micro_op_resolver;
micro_op_resolver.AddBuiltin(
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());

View File

@ -72,7 +72,7 @@ void setup() {
//
// tflite::ops::micro::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroOpResolver<12> micro_op_resolver;
static tflite::MicroMutableOpResolver<12> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
1, 3);

View File

@ -52,7 +52,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
tflite::MicroOpResolver<11> micro_op_resolver;
tflite::MicroMutableOpResolver<11> micro_op_resolver;
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
1, 3);

View File

@ -19,7 +19,12 @@ namespace tflite {
namespace ops {
namespace micro {
class AllOpsResolver : public MicroMutableOpResolver {
// The magic number in the template parameter is the maximum number of ops that
// can be added to AllOpsResolver. It can be increased if needed. And most
// applications that care about the memory footprint will want to directly use
// MicroMutableOpResolver and have an application specific template parameter.
// The examples directory has sample code for this.
class AllOpsResolver : public MicroMutableOpResolver<128> {
public:
AllOpsResolver();

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {

View File

@ -24,11 +24,14 @@ limitations under the License.
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h"
#include "tensorflow/lite/micro/memory_planner/memory_planner.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
@ -431,7 +434,7 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model,
}
TfLiteStatus MicroAllocator::InitializeFromFlatbuffer(
const OpResolver& op_resolver,
const MicroOpResolver& op_resolver,
NodeAndRegistration** node_and_registrations) {
if (!active_) {
return kTfLiteError;
@ -649,7 +652,7 @@ TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations(
}
TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
const OpResolver& op_resolver,
const MicroOpResolver& op_resolver,
NodeAndRegistration* node_and_registrations) {
TfLiteStatus status = kTfLiteOk;
auto* opcodes = model_->operator_codes();
@ -697,9 +700,12 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
custom_data_size = op->custom_options()->size();
} else {
TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
&builtin_data_allocator,
(void**)(&builtin_data)));
MicroOpResolver::BuiltinParseFunction parser =
op_resolver.GetOpDataParser(op_type);
TFLITE_DCHECK(parser != nullptr);
TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_,
&builtin_data_allocator,
(void**)(&builtin_data)));
}
// Disregard const qualifier to workaround with existing API.

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/simple_memory_allocator.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -92,7 +92,7 @@ class MicroAllocator {
// needs to be called before FinishTensorAllocation method. This method also
// allocates any internal Op data that is required from the flatbuffer.
TfLiteStatus InitializeFromFlatbuffer(
const OpResolver& op_resolver,
const MicroOpResolver& op_resolver,
NodeAndRegistration** node_and_registrations);
// Runs through the model and allocates all necessary input, output and
@ -145,7 +145,7 @@ class MicroAllocator {
// instance). Persistent data (e.g. operator data) is allocated from the
// arena.
TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer(
const OpResolver& op_resolver,
const MicroOpResolver& op_resolver,
NodeAndRegistration* node_and_registrations);
private:

View File

@ -21,9 +21,10 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
namespace {
@ -71,7 +72,7 @@ void ContextHelper::ReportOpError(struct TfLiteContext* context,
} // namespace internal
MicroInterpreter::MicroInterpreter(const Model* model,
const OpResolver& op_resolver,
const MicroOpResolver& op_resolver,
uint8_t* tensor_arena,
size_t tensor_arena_size,
ErrorReporter* error_reporter)

View File

@ -21,9 +21,9 @@ limitations under the License.
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/type_to_tflitetype.h"
@ -72,7 +72,7 @@ class MicroInterpreter {
// function.
// The interpreter doesn't do any deallocation of any of the pointed-to
// objects, ownership remains with the caller.
MicroInterpreter(const Model* model, const OpResolver& op_resolver,
MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
uint8_t* tensor_arena, size_t tensor_arena_size,
ErrorReporter* error_reporter);
@ -160,7 +160,7 @@ class MicroInterpreter {
NodeAndRegistration* node_and_registrations_ = nullptr;
const Model* model_;
const OpResolver& op_resolver_;
const MicroOpResolver& op_resolver_;
ErrorReporter* error_reporter_;
TfLiteContext context_ = {};
MicroAllocator allocator_;

View File

@ -17,7 +17,9 @@ limitations under the License.
#include <cstdint>
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
@ -147,7 +149,7 @@ class MockCustom {
}
};
class MockOpResolver : public OpResolver {
class MockOpResolver : public MicroOpResolver {
public:
const TfLiteRegistration* FindOp(BuiltinOperator op,
int version) const override {
@ -162,6 +164,22 @@ class MockOpResolver : public OpResolver {
return nullptr;
}
}
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
tflite::BuiltinOperator) const override {
// TODO(b/149408647): Figure out an alternative so that we do not have any
// references to ParseOpData in the micro code and the signature for
// MicroOpResolver::BuiltinParseFunction can be changed to be different from
// ParseOpData.
return ParseOpData;
}
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration,
int version) override {
// This function is currently not used in the tests.
return kTfLiteError;
}
};
} // namespace

View File

@ -1,4 +1,4 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -19,14 +19,11 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
#ifndef TFLITE_REGISTRATIONS_MAX
#define TFLITE_REGISTRATIONS_MAX (128)
#endif
namespace tflite {
// Op versions discussed in this file are enumerated here:
@ -34,10 +31,10 @@ namespace tflite {
inline int MicroOpResolverAnyVersion() { return 0; }
template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX>
class MicroOpResolver : public OpResolver {
template <unsigned int tOpCount>
class MicroMutableOpResolver : public MicroOpResolver {
public:
explicit MicroOpResolver(ErrorReporter* error_reporter = nullptr)
explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
: error_reporter_(error_reporter) {}
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
@ -68,8 +65,15 @@ class MicroOpResolver : public OpResolver {
return nullptr;
}
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
tflite::BuiltinOperator) const override {
// TODO(b/149408647): Replace with the more selective builtin parser.
return ParseOpData;
}
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration, int version = 1) {
TfLiteRegistration* registration,
int version = 1) override {
if (registrations_len_ >= tOpCount) {
if (error_reporter_) {
TF_LITE_REPORT_ERROR(error_reporter_,
@ -144,14 +148,6 @@ class MicroOpResolver : public OpResolver {
TF_LITE_REMOVE_VIRTUAL_DELETE
};
// TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to
// MicroOpResolver.
class MicroMutableOpResolver
: public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> {
private:
TF_LITE_REMOVE_VIRTUAL_DELETE
};
}; // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
@ -58,7 +60,7 @@ TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestOperations) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
using tflite::OpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
@ -66,7 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) {
// We need space for 7 operators because of 2 ops, one with 3 versions, one
// with 4 versions.
MicroOpResolver<7> micro_op_resolver;
MicroMutableOpResolver<7> micro_op_resolver;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
BuiltinOperator_CONV_2D, &r, 1, 3));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
@ -104,32 +106,30 @@ TF_LITE_MICRO_TEST(TestOperations) {
TF_LITE_MICRO_TEST(TestOpRegistrationOverflow) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::OpResolver;
using tflite::MicroMutableOpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
MicroOpResolver<4> micro_op_resolver;
MicroMutableOpResolver<4> micro_op_resolver;
// Register 7 ops, but only 4 is expected because the class is created with
// that limit..
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
BuiltinOperator_CONV_2D, &r, 0, 2));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
micro_op_resolver.AddCustom("mock_custom", &r, 0, 3));
OpResolver* resolver = &micro_op_resolver;
TF_LITE_MICRO_EXPECT_EQ(4, micro_op_resolver.GetRegistrationLength());
}
TF_LITE_MICRO_TEST(TestZeroVersionRegistration) {
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
using tflite::OpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
MicroOpResolver<1> micro_op_resolver;
MicroMutableOpResolver<1> micro_op_resolver;
micro_op_resolver.AddCustom("mock_custom", &r,
tflite::MicroOpResolverAnyVersion());
@ -157,13 +157,13 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) {
}
TF_LITE_MICRO_TEST(TestZeroModelVersion) {
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
using tflite::OpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
MicroOpResolver<2> micro_op_resolver;
MicroMutableOpResolver<2> micro_op_resolver;
micro_op_resolver.AddCustom("mock_custom", &r, 1, 2);
TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength());
OpResolver* resolver = &micro_op_resolver;
@ -196,13 +196,13 @@ TF_LITE_MICRO_TEST(TestZeroModelVersion) {
TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
tflite::MockErrorReporter mock_reporter;
MicroOpResolver<1> micro_op_resolver(&mock_reporter);
MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r));
@ -215,13 +215,13 @@ TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) {
TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
tflite::MockErrorReporter mock_reporter;
MicroOpResolver<1> micro_op_resolver(&mock_reporter);
MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
micro_op_resolver.AddCustom("mock_custom_0", &r));
@ -234,13 +234,13 @@ TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) {
TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
tflite::MockErrorReporter mock_reporter;
MicroOpResolver<2> micro_op_resolver(&mock_reporter);
MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter);
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
BuiltinOperator_CONV_2D, &r, 1, 2));
@ -253,13 +253,13 @@ TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) {
TF_LITE_MICRO_TEST(TestCustomVersionRegistrationErrorReporting) {
using tflite::BuiltinOperator_CONV_2D;
using tflite::BuiltinOperator_RELU;
using tflite::MicroOpResolver;
using tflite::MicroMutableOpResolver;
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
tflite::MockPrepare, tflite::MockInvoke};
tflite::MockErrorReporter mock_reporter;
MicroOpResolver<2> micro_op_resolver(&mock_reporter);
MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter);
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, micro_op_resolver.AddCustom("mock_custom_0", &r, 1, 2));

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
#define TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace tflite {
// This is an interface for the OpResolver for TFLiteMicro. The differences from
// the TFLite OpResolver base class are to allow for finer grained registration
// of the Builtin Ops to reduce code size for TFLiteMicro. We need an interface
// class instead of directly using MicroMutableOpResolver because
// MicroMutableOpResolver is a class template with the number of registered Ops
// as the template parameter.
class MicroOpResolver : public OpResolver {
public:
// TODO(b/149408647): The op_type parameter enables a gradual transfer to
// selective registration of the parse function. It should be removed once we
// no longer need to use ParseOpData (from flatbuffer_conversions.h) as part
// of the MicroMutableOpResolver.
typedef TfLiteStatus (*BuiltinParseFunction)(const Operator* op,
BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
// Returns the operator specific parsing function for the OpData for a
// BuiltinOperator (if registered), else nullptr.
virtual BuiltinParseFunction GetOpDataParser(
tflite::BuiltinOperator op) const = 0;
virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
TfLiteRegistration* registration,
int version) = 0;
~MicroOpResolver() override {}
};
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_