Add MicroOpResolver interface class.
This will allow us to implement selective registration of the builtin parse functions without changing the OpResolver base class in TFLite. * MicroOpResolver is now an interface (matching the OpResolver name in TFLite). * MicroMutableOpResolver is the implementation of the MicroOpResolver interface that should be used by applications that do not want to use AllOpsResolver. PiperOrigin-RevId: 313691276 Change-Id: I0a9f51f6584326a3b3dd645cde083ba42116083d
This commit is contained in:
parent
f9fb66cdb7
commit
33689c48ad
tensorflow/lite/micro
BUILD
benchmarks
examples
image_recognition_experimental
magic_wand
micro_speech
person_detection
person_detection_experimental
kernels
memory_helpers.ccmicro_allocator.ccmicro_allocator.hmicro_interpreter.ccmicro_interpreter.hmicro_interpreter_test.ccmicro_mutable_op_resolver.hmicro_mutable_op_resolver_test.ccmicro_op_resolver.h@ -37,6 +37,7 @@ cc_library(
|
||||
"micro_allocator.h",
|
||||
"micro_interpreter.h",
|
||||
"micro_mutable_op_resolver.h",
|
||||
"micro_op_resolver.h",
|
||||
"micro_optional_debug_tools.h",
|
||||
"simple_memory_allocator.h",
|
||||
"test_helpers.h",
|
||||
@ -159,6 +160,7 @@ tflite_micro_cc_test(
|
||||
deps = [
|
||||
":micro_framework",
|
||||
":micro_utils",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/micro/testing:micro_test",
|
||||
],
|
||||
|
@ -84,7 +84,7 @@ class KeywordRunner {
|
||||
const tflite::Model* keyword_spotting_model_;
|
||||
tflite::MicroErrorReporter micro_reporter_;
|
||||
tflite::ErrorReporter* reporter_;
|
||||
tflite::MicroOpResolver<6> resolver_;
|
||||
tflite::MicroMutableOpResolver<6> resolver_;
|
||||
tflite::MicroInterpreter interpreter_;
|
||||
};
|
||||
|
||||
|
@ -98,7 +98,7 @@ class PersonDetectionRunner {
|
||||
const tflite::Model* person_detection_model_;
|
||||
tflite::MicroErrorReporter micro_reporter_;
|
||||
tflite::ErrorReporter* reporter_;
|
||||
tflite::MicroOpResolver<6> resolver_;
|
||||
tflite::MicroMutableOpResolver<6> resolver_;
|
||||
tflite::MicroInterpreter interpreter_;
|
||||
};
|
||||
|
||||
|
@ -42,7 +42,7 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) {
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
}
|
||||
|
||||
tflite::MicroOpResolver<4> micro_op_resolver;
|
||||
tflite::MicroMutableOpResolver<4> micro_op_resolver;
|
||||
|
||||
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
|
||||
tflite::ops::micro::Register_CONV_2D());
|
||||
|
@ -56,7 +56,7 @@ int main(int argc, char** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
tflite::MicroOpResolver<4> micro_op_resolver;
|
||||
tflite::MicroMutableOpResolver<4> micro_op_resolver;
|
||||
|
||||
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
|
||||
tflite::ops::micro::Register_CONV_2D());
|
||||
|
@ -46,7 +46,7 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
|
||||
// An easier approach is to just use the AllOpsResolver, but this will
|
||||
// incur some penalty in code space for op implementations that are not
|
||||
// needed by this graph.
|
||||
static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT
|
||||
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
|
||||
micro_op_resolver.AddBuiltin(
|
||||
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
|
||||
|
@ -65,7 +65,7 @@ void setup() {
|
||||
// An easier approach is to just use the AllOpsResolver, but this will
|
||||
// incur some penalty in code space for op implementations that are not
|
||||
// needed by this graph.
|
||||
static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT
|
||||
static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT
|
||||
micro_op_resolver.AddBuiltin(
|
||||
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
|
||||
|
@ -74,7 +74,7 @@ void setup() {
|
||||
//
|
||||
// tflite::ops::micro::AllOpsResolver resolver;
|
||||
// NOLINTNEXTLINE(runtime-global-variables)
|
||||
static tflite::MicroOpResolver<4> micro_op_resolver(error_reporter);
|
||||
static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter);
|
||||
if (micro_op_resolver.AddBuiltin(
|
||||
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
|
||||
|
@ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
|
||||
// needed by this graph.
|
||||
//
|
||||
// tflite::ops::micro::AllOpsResolver resolver;
|
||||
tflite::MicroOpResolver<4> micro_op_resolver;
|
||||
tflite::MicroMutableOpResolver<4> micro_op_resolver;
|
||||
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
|
||||
tflite::MicroOpResolverAnyVersion());
|
||||
|
@ -65,7 +65,7 @@ void setup() {
|
||||
//
|
||||
// tflite::ops::micro::AllOpsResolver resolver;
|
||||
// NOLINTNEXTLINE(runtime-global-variables)
|
||||
static tflite::MicroOpResolver<3> micro_op_resolver;
|
||||
static tflite::MicroMutableOpResolver<3> micro_op_resolver;
|
||||
micro_op_resolver.AddBuiltin(
|
||||
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
|
||||
|
@ -54,7 +54,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
|
||||
// needed by this graph.
|
||||
//
|
||||
// tflite::ops::micro::AllOpsResolver resolver;
|
||||
tflite::MicroOpResolver<3> micro_op_resolver;
|
||||
tflite::MicroMutableOpResolver<3> micro_op_resolver;
|
||||
micro_op_resolver.AddBuiltin(
|
||||
tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
|
||||
|
@ -72,7 +72,7 @@ void setup() {
|
||||
//
|
||||
// tflite::ops::micro::AllOpsResolver resolver;
|
||||
// NOLINTNEXTLINE(runtime-global-variables)
|
||||
static tflite::MicroOpResolver<12> micro_op_resolver;
|
||||
static tflite::MicroMutableOpResolver<12> micro_op_resolver;
|
||||
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
|
||||
1, 3);
|
||||
|
@ -52,7 +52,7 @@ TF_LITE_MICRO_TEST(TestInvoke) {
|
||||
// An easier approach is to just use the AllOpsResolver, but this will
|
||||
// incur some penalty in code space for op implementations that are not
|
||||
// needed by this graph.
|
||||
tflite::MicroOpResolver<11> micro_op_resolver;
|
||||
tflite::MicroMutableOpResolver<11> micro_op_resolver;
|
||||
micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
tflite::ops::micro::Register_DEPTHWISE_CONV_2D(),
|
||||
1, 3);
|
||||
|
@ -19,7 +19,12 @@ namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
|
||||
class AllOpsResolver : public MicroMutableOpResolver {
|
||||
// The magic number in the template parameter is the maximum number of ops that
|
||||
// can be added to AllOpsResolver. It can be increased if needed. And most
|
||||
// applications that care about the memory footprint will want to directly use
|
||||
// MicroMutableOpResolver and have an application specific template parameter.
|
||||
// The examples directory has sample code for this.
|
||||
class AllOpsResolver : public MicroMutableOpResolver<128> {
|
||||
public:
|
||||
AllOpsResolver();
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
|
@ -24,11 +24,14 @@ limitations under the License.
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/api/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/compatibility.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h"
|
||||
#include "tensorflow/lite/micro/memory_planner/memory_planner.h"
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
@ -431,7 +434,7 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model,
|
||||
}
|
||||
|
||||
TfLiteStatus MicroAllocator::InitializeFromFlatbuffer(
|
||||
const OpResolver& op_resolver,
|
||||
const MicroOpResolver& op_resolver,
|
||||
NodeAndRegistration** node_and_registrations) {
|
||||
if (!active_) {
|
||||
return kTfLiteError;
|
||||
@ -649,7 +652,7 @@ TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations(
|
||||
}
|
||||
|
||||
TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
|
||||
const OpResolver& op_resolver,
|
||||
const MicroOpResolver& op_resolver,
|
||||
NodeAndRegistration* node_and_registrations) {
|
||||
TfLiteStatus status = kTfLiteOk;
|
||||
auto* opcodes = model_->operator_codes();
|
||||
@ -697,9 +700,12 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer(
|
||||
custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
|
||||
custom_data_size = op->custom_options()->size();
|
||||
} else {
|
||||
TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
|
||||
&builtin_data_allocator,
|
||||
(void**)(&builtin_data)));
|
||||
MicroOpResolver::BuiltinParseFunction parser =
|
||||
op_resolver.GetOpDataParser(op_type);
|
||||
TFLITE_DCHECK(parser != nullptr);
|
||||
TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_,
|
||||
&builtin_data_allocator,
|
||||
(void**)(&builtin_data)));
|
||||
}
|
||||
|
||||
// Disregard const qualifier to workaround with existing API.
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/micro/simple_memory_allocator.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@ -92,7 +92,7 @@ class MicroAllocator {
|
||||
// needs to be called before FinishTensorAllocation method. This method also
|
||||
// allocates any internal Op data that is required from the flatbuffer.
|
||||
TfLiteStatus InitializeFromFlatbuffer(
|
||||
const OpResolver& op_resolver,
|
||||
const MicroOpResolver& op_resolver,
|
||||
NodeAndRegistration** node_and_registrations);
|
||||
|
||||
// Runs through the model and allocates all necessary input, output and
|
||||
@ -145,7 +145,7 @@ class MicroAllocator {
|
||||
// instance). Persistent data (e.g. operator data) is allocated from the
|
||||
// arena.
|
||||
TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer(
|
||||
const OpResolver& op_resolver,
|
||||
const MicroOpResolver& op_resolver,
|
||||
NodeAndRegistration* node_and_registrations);
|
||||
|
||||
private:
|
||||
|
@ -21,9 +21,10 @@ limitations under the License.
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/api/tensor_utils.h"
|
||||
#include "tensorflow/lite/micro/micro_allocator.h"
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@ -71,7 +72,7 @@ void ContextHelper::ReportOpError(struct TfLiteContext* context,
|
||||
} // namespace internal
|
||||
|
||||
MicroInterpreter::MicroInterpreter(const Model* model,
|
||||
const OpResolver& op_resolver,
|
||||
const MicroOpResolver& op_resolver,
|
||||
uint8_t* tensor_arena,
|
||||
size_t tensor_arena_size,
|
||||
ErrorReporter* error_reporter)
|
||||
|
@ -21,9 +21,9 @@ limitations under the License.
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/micro/micro_allocator.h"
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/type_to_tflitetype.h"
|
||||
|
||||
@ -72,7 +72,7 @@ class MicroInterpreter {
|
||||
// function.
|
||||
// The interpreter doesn't do any deallocation of any of the pointed-to
|
||||
// objects, ownership remains with the caller.
|
||||
MicroInterpreter(const Model* model, const OpResolver& op_resolver,
|
||||
MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
|
||||
uint8_t* tensor_arena, size_t tensor_arena_size,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
@ -160,7 +160,7 @@ class MicroInterpreter {
|
||||
NodeAndRegistration* node_and_registrations_ = nullptr;
|
||||
|
||||
const Model* model_;
|
||||
const OpResolver& op_resolver_;
|
||||
const MicroOpResolver& op_resolver_;
|
||||
ErrorReporter* error_reporter_;
|
||||
TfLiteContext context_ = {};
|
||||
MicroAllocator allocator_;
|
||||
|
@ -17,7 +17,9 @@ limitations under the License.
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/micro/micro_optional_debug_tools.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
@ -147,7 +149,7 @@ class MockCustom {
|
||||
}
|
||||
};
|
||||
|
||||
class MockOpResolver : public OpResolver {
|
||||
class MockOpResolver : public MicroOpResolver {
|
||||
public:
|
||||
const TfLiteRegistration* FindOp(BuiltinOperator op,
|
||||
int version) const override {
|
||||
@ -162,6 +164,22 @@ class MockOpResolver : public OpResolver {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
|
||||
tflite::BuiltinOperator) const override {
|
||||
// TODO(b/149408647): Figure out an alternative so that we do not have any
|
||||
// references to ParseOpData in the micro code and the signature for
|
||||
// MicroOpResolver::BuiltinParseFunction can be changed to be different from
|
||||
// ParseOpData.
|
||||
return ParseOpData;
|
||||
}
|
||||
|
||||
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
|
||||
TfLiteRegistration* registration,
|
||||
int version) override {
|
||||
// This function is currently not used in the tests.
|
||||
return kTfLiteError;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -19,14 +19,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/micro/compatibility.h"
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
#ifndef TFLITE_REGISTRATIONS_MAX
|
||||
#define TFLITE_REGISTRATIONS_MAX (128)
|
||||
#endif
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Op versions discussed in this file are enumerated here:
|
||||
@ -34,10 +31,10 @@ namespace tflite {
|
||||
|
||||
inline int MicroOpResolverAnyVersion() { return 0; }
|
||||
|
||||
template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX>
|
||||
class MicroOpResolver : public OpResolver {
|
||||
template <unsigned int tOpCount>
|
||||
class MicroMutableOpResolver : public MicroOpResolver {
|
||||
public:
|
||||
explicit MicroOpResolver(ErrorReporter* error_reporter = nullptr)
|
||||
explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr)
|
||||
: error_reporter_(error_reporter) {}
|
||||
|
||||
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
|
||||
@ -68,8 +65,15 @@ class MicroOpResolver : public OpResolver {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MicroOpResolver::BuiltinParseFunction GetOpDataParser(
|
||||
tflite::BuiltinOperator) const override {
|
||||
// TODO(b/149408647): Replace with the more selective builtin parser.
|
||||
return ParseOpData;
|
||||
}
|
||||
|
||||
TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
|
||||
TfLiteRegistration* registration, int version = 1) {
|
||||
TfLiteRegistration* registration,
|
||||
int version = 1) override {
|
||||
if (registrations_len_ >= tOpCount) {
|
||||
if (error_reporter_) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter_,
|
||||
@ -144,14 +148,6 @@ class MicroOpResolver : public OpResolver {
|
||||
TF_LITE_REMOVE_VIRTUAL_DELETE
|
||||
};
|
||||
|
||||
// TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to
|
||||
// MicroOpResolver.
|
||||
class MicroMutableOpResolver
|
||||
: public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> {
|
||||
private:
|
||||
TF_LITE_REMOVE_VIRTUAL_DELETE
|
||||
};
|
||||
|
||||
}; // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
|
||||
|
||||
#include "tensorflow/lite/micro/micro_op_resolver.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -58,7 +60,7 @@ TF_LITE_MICRO_TESTS_BEGIN
|
||||
TF_LITE_MICRO_TEST(TestOperations) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
using tflite::OpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
@ -66,7 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) {
|
||||
|
||||
// We need space for 7 operators because of 2 ops, one with 3 versions, one
|
||||
// with 4 versions.
|
||||
MicroOpResolver<7> micro_op_resolver;
|
||||
MicroMutableOpResolver<7> micro_op_resolver;
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
|
||||
BuiltinOperator_CONV_2D, &r, 1, 3));
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
@ -104,32 +106,30 @@ TF_LITE_MICRO_TEST(TestOperations) {
|
||||
TF_LITE_MICRO_TEST(TestOpRegistrationOverflow) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::OpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
MicroOpResolver<4> micro_op_resolver;
|
||||
MicroMutableOpResolver<4> micro_op_resolver;
|
||||
// Register 7 ops, but only 4 is expected because the class is created with
|
||||
// that limit..
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
|
||||
BuiltinOperator_CONV_2D, &r, 0, 2));
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
|
||||
micro_op_resolver.AddCustom("mock_custom", &r, 0, 3));
|
||||
OpResolver* resolver = µ_op_resolver;
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(4, micro_op_resolver.GetRegistrationLength());
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestZeroVersionRegistration) {
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
using tflite::OpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
MicroOpResolver<1> micro_op_resolver;
|
||||
MicroMutableOpResolver<1> micro_op_resolver;
|
||||
micro_op_resolver.AddCustom("mock_custom", &r,
|
||||
tflite::MicroOpResolverAnyVersion());
|
||||
|
||||
@ -157,13 +157,13 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) {
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestZeroModelVersion) {
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
using tflite::OpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
MicroOpResolver<2> micro_op_resolver;
|
||||
MicroMutableOpResolver<2> micro_op_resolver;
|
||||
micro_op_resolver.AddCustom("mock_custom", &r, 1, 2);
|
||||
TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength());
|
||||
OpResolver* resolver = µ_op_resolver;
|
||||
@ -196,13 +196,13 @@ TF_LITE_MICRO_TEST(TestZeroModelVersion) {
|
||||
TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
tflite::MockErrorReporter mock_reporter;
|
||||
MicroOpResolver<1> micro_op_resolver(&mock_reporter);
|
||||
MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
|
||||
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
|
||||
TF_LITE_MICRO_EXPECT_EQ(
|
||||
kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r));
|
||||
@ -215,13 +215,13 @@ TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) {
|
||||
TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
tflite::MockErrorReporter mock_reporter;
|
||||
MicroOpResolver<1> micro_op_resolver(&mock_reporter);
|
||||
MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter);
|
||||
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
|
||||
micro_op_resolver.AddCustom("mock_custom_0", &r));
|
||||
@ -234,13 +234,13 @@ TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) {
|
||||
TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
tflite::MockErrorReporter mock_reporter;
|
||||
MicroOpResolver<2> micro_op_resolver(&mock_reporter);
|
||||
MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter);
|
||||
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin(
|
||||
BuiltinOperator_CONV_2D, &r, 1, 2));
|
||||
@ -253,13 +253,13 @@ TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) {
|
||||
TF_LITE_MICRO_TEST(TestCustomVersionRegistrationErrorReporting) {
|
||||
using tflite::BuiltinOperator_CONV_2D;
|
||||
using tflite::BuiltinOperator_RELU;
|
||||
using tflite::MicroOpResolver;
|
||||
using tflite::MicroMutableOpResolver;
|
||||
|
||||
static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
|
||||
tflite::MockPrepare, tflite::MockInvoke};
|
||||
|
||||
tflite::MockErrorReporter mock_reporter;
|
||||
MicroOpResolver<2> micro_op_resolver(&mock_reporter);
|
||||
MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter);
|
||||
TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled());
|
||||
TF_LITE_MICRO_EXPECT_EQ(
|
||||
kTfLiteOk, micro_op_resolver.AddCustom("mock_custom_0", &r, 1, 2));
|
||||
|
58
tensorflow/lite/micro/micro_op_resolver.h
Normal file
58
tensorflow/lite/micro/micro_op_resolver.h
Normal file
@ -0,0 +1,58 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
|
||||
#define TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// This is an interface for the OpResolver for TFLiteMicro. The differences from
|
||||
// the TFLite OpResolver base class are to allow for finer grained registration
|
||||
// of the Builtin Ops to reduce code size for TFLiteMicro. We need an interface
|
||||
// class instead of directly using MicroMutableOpResolver because
|
||||
// MicroMutableOpResolver is a class template with the number of registered Ops
|
||||
// as the template parameter.
|
||||
class MicroOpResolver : public OpResolver {
|
||||
public:
|
||||
// TODO(b/149408647): The op_type parameter enables a gradual transfer to
|
||||
// selective registration of the parse function. It should be removed once we
|
||||
// no longer need to use ParseOpData (from flatbuffer_conversions.h) as part
|
||||
// of the MicroMutableOpResolver.
|
||||
typedef TfLiteStatus (*BuiltinParseFunction)(const Operator* op,
|
||||
BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
// Returns the operator specific parsing function for the OpData for a
|
||||
// BuiltinOperator (if registered), else nullptr.
|
||||
virtual BuiltinParseFunction GetOpDataParser(
|
||||
tflite::BuiltinOperator op) const = 0;
|
||||
|
||||
virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op,
|
||||
TfLiteRegistration* registration,
|
||||
int version) = 0;
|
||||
|
||||
~MicroOpResolver() override {}
|
||||
};
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
|
Loading…
Reference in New Issue
Block a user