diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 3b05aee30f4..36cc14512b6 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -37,6 +37,7 @@ cc_library( "micro_allocator.h", "micro_interpreter.h", "micro_mutable_op_resolver.h", + "micro_op_resolver.h", "micro_optional_debug_tools.h", "simple_memory_allocator.h", "test_helpers.h", @@ -159,6 +160,7 @@ tflite_micro_cc_test( deps = [ ":micro_framework", ":micro_utils", + "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/micro/testing:micro_test", ], diff --git a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc index fd8556f752e..50401039265 100644 --- a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc @@ -84,7 +84,7 @@ class KeywordRunner { const tflite::Model* keyword_spotting_model_; tflite::MicroErrorReporter micro_reporter_; tflite::ErrorReporter* reporter_; - tflite::MicroOpResolver<6> resolver_; + tflite::MicroMutableOpResolver<6> resolver_; tflite::MicroInterpreter interpreter_; }; diff --git a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc index 5287a9c1e23..bec12ad8642 100644 --- a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc @@ -98,7 +98,7 @@ class PersonDetectionRunner { const tflite::Model* person_detection_model_; tflite::MicroErrorReporter micro_reporter_; tflite::ErrorReporter* reporter_; - tflite::MicroOpResolver<6> resolver_; + tflite::MicroMutableOpResolver<6> resolver_; tflite::MicroInterpreter interpreter_; }; diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc index fd547b433ef..ac4de118834 100644 --- a/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc +++ b/tensorflow/lite/micro/examples/image_recognition_experimental/image_recognition_test.cc @@ -42,7 +42,7 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) { model->version(), TFLITE_SCHEMA_VERSION); } - tflite::MicroOpResolver<4> micro_op_resolver; + tflite::MicroMutableOpResolver<4> micro_op_resolver; micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, tflite::ops::micro::Register_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc index 09c76df0379..becdbdf1bd7 100644 --- a/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc +++ b/tensorflow/lite/micro/examples/image_recognition_experimental/main.cc @@ -56,7 +56,7 @@ int main(int argc, char** argv) { return 1; } - tflite::MicroOpResolver<4> micro_op_resolver; + tflite::MicroMutableOpResolver<4> micro_op_resolver; micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, tflite::ops::micro::Register_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc b/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc index ab4f41680fd..88bfad860e2 100644 --- a/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc +++ b/tensorflow/lite/micro/examples/magic_wand/magic_wand_test.cc @@ -46,7 +46,7 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) { // An easier approach is to just use the AllOpsResolver, but this will // incur some penalty in code space for op implementations that are not // needed by this graph. - static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT + static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc index 51e6e593cd1..26c2eb44747 100644 --- a/tensorflow/lite/micro/examples/magic_wand/main_functions.cc +++ b/tensorflow/lite/micro/examples/magic_wand/main_functions.cc @@ -65,7 +65,7 @@ void setup() { // An easier approach is to just use the AllOpsResolver, but this will // incur some penalty in code space for op implementations that are not // needed by this graph. - static tflite::MicroOpResolver<5> micro_op_resolver; // NOLINT + static tflite::MicroMutableOpResolver<5> micro_op_resolver; // NOLINT micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc index e5e6aa7c1f7..5369008182b 100644 --- a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc +++ b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc @@ -74,7 +74,7 @@ void setup() { // // tflite::ops::micro::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) - static tflite::MicroOpResolver<4> micro_op_resolver(error_reporter); + static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter); if (micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc index a6e011b1224..b1b224c9391 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc @@ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { // needed by this graph. // // tflite::ops::micro::AllOpsResolver resolver; - tflite::MicroOpResolver<4> micro_op_resolver; + tflite::MicroMutableOpResolver<4> micro_op_resolver; micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), tflite::MicroOpResolverAnyVersion()); diff --git a/tensorflow/lite/micro/examples/person_detection/main_functions.cc b/tensorflow/lite/micro/examples/person_detection/main_functions.cc index 0e5c6394d56..6b07f6514d5 100644 --- a/tensorflow/lite/micro/examples/person_detection/main_functions.cc +++ b/tensorflow/lite/micro/examples/person_detection/main_functions.cc @@ -65,7 +65,7 @@ void setup() { // // tflite::ops::micro::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) - static tflite::MicroOpResolver<3> micro_op_resolver; + static tflite::MicroMutableOpResolver<3> micro_op_resolver; micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc index 8acb93ced17..dafed8089e3 100644 --- a/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc +++ b/tensorflow/lite/micro/examples/person_detection/person_detection_test.cc @@ -54,7 +54,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { // needed by this graph. // // tflite::ops::micro::AllOpsResolver resolver; - tflite::MicroOpResolver<3> micro_op_resolver; + tflite::MicroMutableOpResolver<3> micro_op_resolver; micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc b/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc index 92d2c091f55..6f10d5c3f27 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc +++ b/tensorflow/lite/micro/examples/person_detection_experimental/main_functions.cc @@ -72,7 +72,7 @@ void setup() { // // tflite::ops::micro::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) - static tflite::MicroOpResolver<12> micro_op_resolver; + static tflite::MicroMutableOpResolver<12> micro_op_resolver; micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), 1, 3); diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc b/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc index c3719e559ca..ea37faa15f2 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc +++ b/tensorflow/lite/micro/examples/person_detection_experimental/person_detection_test.cc @@ -52,7 +52,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { // An easier approach is to just use the AllOpsResolver, but this will // incur some penalty in code space for op implementations that are not // needed by this graph. - tflite::MicroOpResolver<11> micro_op_resolver; + tflite::MicroMutableOpResolver<11> micro_op_resolver; micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), 1, 3); diff --git a/tensorflow/lite/micro/kernels/all_ops_resolver.h b/tensorflow/lite/micro/kernels/all_ops_resolver.h index 26bb03230ed..5637316d0da 100644 --- a/tensorflow/lite/micro/kernels/all_ops_resolver.h +++ b/tensorflow/lite/micro/kernels/all_ops_resolver.h @@ -19,7 +19,12 @@ namespace tflite { namespace ops { namespace micro { -class AllOpsResolver : public MicroMutableOpResolver { +// The magic number in the template parameter is the maximum number of ops that +// can be added to AllOpsResolver. It can be increased if needed. And most +// applications that care about the memory footprint will want to directly use +// MicroMutableOpResolver and have an application specific template parameter. +// The examples directory has sample code for this. +class AllOpsResolver : public MicroMutableOpResolver<128> { public: AllOpsResolver(); diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index 05105f83ff3..37c78162b62 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index d43f0ec076f..35f4bdabd20 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -24,11 +24,14 @@ limitations under the License. #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/tensor_utils.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h" #include "tensorflow/lite/micro/memory_planner/memory_planner.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/micro/simple_memory_allocator.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -431,7 +434,7 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, } TfLiteStatus MicroAllocator::InitializeFromFlatbuffer( - const OpResolver& op_resolver, + const MicroOpResolver& op_resolver, NodeAndRegistration** node_and_registrations) { if (!active_) { return kTfLiteError; @@ -649,7 +652,7 @@ TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations( } TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( - const OpResolver& op_resolver, + const MicroOpResolver& op_resolver, NodeAndRegistration* node_and_registrations) { TfLiteStatus status = kTfLiteOk; auto* opcodes = model_->operator_codes(); @@ -697,9 +700,12 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( custom_data = reinterpret_cast<const char*>(op->custom_options()->data()); custom_data_size = op->custom_options()->size(); } else { - TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, - &builtin_data_allocator, - (void**)(&builtin_data))); + MicroOpResolver::BuiltinParseFunction parser = + op_resolver.GetOpDataParser(op_type); + TFLITE_DCHECK(parser != nullptr); + TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_, + &builtin_data_allocator, + (void**)(&builtin_data))); } // Disregard const qualifier to workaround with existing API. diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h index 1dd90c36a4d..d7f7e4c9d6c 100644 --- a/tensorflow/lite/micro/micro_allocator.h +++ b/tensorflow/lite/micro/micro_allocator.h @@ -21,7 +21,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -92,7 +92,7 @@ class MicroAllocator { // needs to be called before FinishTensorAllocation method. This method also // allocates any internal Op data that is required from the flatbuffer. TfLiteStatus InitializeFromFlatbuffer( - const OpResolver& op_resolver, + const MicroOpResolver& op_resolver, NodeAndRegistration** node_and_registrations); // Runs through the model and allocates all necessary input, output and @@ -145,7 +145,7 @@ class MicroAllocator { // instance). Persistent data (e.g. operator data) is allocated from the // arena. TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer( - const OpResolver& op_resolver, + const MicroOpResolver& op_resolver, NodeAndRegistration* node_and_registrations); private: diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 6b78966020e..7e2e56e417d 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -21,9 +21,10 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { @@ -71,7 +72,7 @@ void ContextHelper::ReportOpError(struct TfLiteContext* context, } // namespace internal MicroInterpreter::MicroInterpreter(const Model* model, - const OpResolver& op_resolver, + const MicroOpResolver& op_resolver, uint8_t* tensor_arena, size_t tensor_arena_size, ErrorReporter* error_reporter) diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 180a557668e..a0b70527905 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -21,9 +21,9 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/type_to_tflitetype.h" @@ -72,7 +72,7 @@ class MicroInterpreter { // function. // The interpreter doesn't do any deallocation of any of the pointed-to // objects, ownership remains with the caller. - MicroInterpreter(const Model* model, const OpResolver& op_resolver, + MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver, uint8_t* tensor_arena, size_t tensor_arena_size, ErrorReporter* error_reporter); @@ -160,7 +160,7 @@ class MicroInterpreter { NodeAndRegistration* node_and_registrations_ = nullptr; const Model* model_; - const OpResolver& op_resolver_; + const MicroOpResolver& op_resolver_; ErrorReporter* error_reporter_; TfLiteContext context_ = {}; MicroAllocator allocator_; diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index 36e8c009b96..2358f763bc0 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -17,7 +17,9 @@ limitations under the License. #include <cstdint> +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/micro/micro_optional_debug_tools.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -147,7 +149,7 @@ class MockCustom { } }; -class MockOpResolver : public OpResolver { +class MockOpResolver : public MicroOpResolver { public: const TfLiteRegistration* FindOp(BuiltinOperator op, int version) const override { @@ -162,6 +164,22 @@ class MockOpResolver : public OpResolver { return nullptr; } } + + MicroOpResolver::BuiltinParseFunction GetOpDataParser( + tflite::BuiltinOperator) const override { + // TODO(b/149408647): Figure out an alternative so that we do not have any + // references to ParseOpData in the micro code and the signature for + // MicroOpResolver::BuiltinParseFunction can be changed to be different from + // ParseOpData. + return ParseOpData; + } + + TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration, + int version) override { + // This function is currently not used in the tests. + return kTfLiteError; + } }; } // namespace diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 6c3e9a3331e..88ec1133c9f 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,14 +19,11 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" -#ifndef TFLITE_REGISTRATIONS_MAX -#define TFLITE_REGISTRATIONS_MAX (128) -#endif - namespace tflite { // Op versions discussed in this file are enumerated here: @@ -34,10 +31,10 @@ namespace tflite { inline int MicroOpResolverAnyVersion() { return 0; } -template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX> -class MicroOpResolver : public OpResolver { +template <unsigned int tOpCount> +class MicroMutableOpResolver : public MicroOpResolver { public: - explicit MicroOpResolver(ErrorReporter* error_reporter = nullptr) + explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr) : error_reporter_(error_reporter) {} const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, @@ -68,8 +65,15 @@ class MicroOpResolver : public OpResolver { return nullptr; } + MicroOpResolver::BuiltinParseFunction GetOpDataParser( + tflite::BuiltinOperator) const override { + // TODO(b/149408647): Replace with the more selective builtin parser. + return ParseOpData; + } + TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, - TfLiteRegistration* registration, int version = 1) { + TfLiteRegistration* registration, + int version = 1) override { if (registrations_len_ >= tOpCount) { if (error_reporter_) { TF_LITE_REPORT_ERROR(error_reporter_, @@ -144,14 +148,6 @@ class MicroOpResolver : public OpResolver { TF_LITE_REMOVE_VIRTUAL_DELETE }; -// TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to -// MicroOpResolver. -class MicroMutableOpResolver - : public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> { - private: - TF_LITE_REMOVE_VIRTUAL_DELETE -}; - }; // namespace tflite #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc index 61ab0e3bec9..6b0c9974874 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc +++ b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" + +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/micro/testing/micro_test.h" namespace tflite { @@ -58,7 +60,7 @@ TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(TestOperations) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; using tflite::OpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, @@ -66,7 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) { // We need space for 7 operators because of 2 ops, one with 3 versions, one // with 4 versions. - MicroOpResolver<7> micro_op_resolver; + MicroMutableOpResolver<7> micro_op_resolver; TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( BuiltinOperator_CONV_2D, &r, 1, 3)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, @@ -104,32 +106,30 @@ TF_LITE_MICRO_TEST(TestOperations) { TF_LITE_MICRO_TEST(TestOpRegistrationOverflow) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; - using tflite::OpResolver; + using tflite::MicroMutableOpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; - MicroOpResolver<4> micro_op_resolver; + MicroMutableOpResolver<4> micro_op_resolver; // Register 7 ops, but only 4 is expected because the class is created with // that limit.. TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( BuiltinOperator_CONV_2D, &r, 0, 2)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_op_resolver.AddCustom("mock_custom", &r, 0, 3)); - OpResolver* resolver = µ_op_resolver; TF_LITE_MICRO_EXPECT_EQ(4, micro_op_resolver.GetRegistrationLength()); } TF_LITE_MICRO_TEST(TestZeroVersionRegistration) { - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; using tflite::OpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; - MicroOpResolver<1> micro_op_resolver; + MicroMutableOpResolver<1> micro_op_resolver; micro_op_resolver.AddCustom("mock_custom", &r, tflite::MicroOpResolverAnyVersion()); @@ -157,13 +157,13 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) { } TF_LITE_MICRO_TEST(TestZeroModelVersion) { - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; using tflite::OpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; - MicroOpResolver<2> micro_op_resolver; + MicroMutableOpResolver<2> micro_op_resolver; micro_op_resolver.AddCustom("mock_custom", &r, 1, 2); TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength()); OpResolver* resolver = µ_op_resolver; @@ -196,13 +196,13 @@ TF_LITE_MICRO_TEST(TestZeroModelVersion) { TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; tflite::MockErrorReporter mock_reporter; - MicroOpResolver<1> micro_op_resolver(&mock_reporter); + MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r)); @@ -215,13 +215,13 @@ TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) { TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; tflite::MockErrorReporter mock_reporter; - MicroOpResolver<1> micro_op_resolver(&mock_reporter); + MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddCustom("mock_custom_0", &r)); @@ -234,13 +234,13 @@ TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) { TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; tflite::MockErrorReporter mock_reporter; - MicroOpResolver<2> micro_op_resolver(&mock_reporter); + MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( BuiltinOperator_CONV_2D, &r, 1, 2)); @@ -253,13 +253,13 @@ TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) { TF_LITE_MICRO_TEST(TestCustomVersionRegistrationErrorReporting) { using tflite::BuiltinOperator_CONV_2D; using tflite::BuiltinOperator_RELU; - using tflite::MicroOpResolver; + using tflite::MicroMutableOpResolver; static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, tflite::MockPrepare, tflite::MockInvoke}; tflite::MockErrorReporter mock_reporter; - MicroOpResolver<2> micro_op_resolver(&mock_reporter); + MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); TF_LITE_MICRO_EXPECT_EQ( kTfLiteOk, micro_op_resolver.AddCustom("mock_custom_0", &r, 1, 2)); diff --git a/tensorflow/lite/micro/micro_op_resolver.h b/tensorflow/lite/micro/micro_op_resolver.h new file mode 100644 index 00000000000..64a3c85cc78 --- /dev/null +++ b/tensorflow/lite/micro/micro_op_resolver.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_ + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +// This is an interface for the OpResolver for TFLiteMicro. The differences from +// the TFLite OpResolver base class are to allow for finer grained registration +// of the Builtin Ops to reduce code size for TFLiteMicro. We need an interface +// class instead of directly using MicroMutableOpResolver because +// MicroMutableOpResolver is a class template with the number of registered Ops +// as the template parameter. +class MicroOpResolver : public OpResolver { + public: + // TODO(b/149408647): The op_type parameter enables a gradual transfer to + // selective registration of the parse function. It should be removed once we + // no longer need to use ParseOpData (from flatbuffer_conversions.h) as part + // of the MicroMutableOpResolver. + typedef TfLiteStatus (*BuiltinParseFunction)(const Operator* op, + BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data); + + // Returns the operator specific parsing function for the OpData for a + // BuiltinOperator (if registered), else nullptr. + virtual BuiltinParseFunction GetOpDataParser( + tflite::BuiltinOperator op) const = 0; + + virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration, + int version) = 0; + + ~MicroOpResolver() override {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_