Add MicroOpResolver interface class.
This will allow us to implement selective registration of the builtin parse functions without changing the OpResolver base class in TFLite. * MicroOpResolver is now an interface (matching the OpResolver name in TFLite). * MicroMutableOpResolver is the implementation of the MicroOpResolver interface that should be used by applications that do not want to use AllOpsResolver. PiperOrigin-RevId: 313691276 Change-Id: I0a9f51f6584326a3b3dd645cde083ba42116083d
This commit is contained in:
		
							parent
							
								
									f9fb66cdb7
								
							
						
					
					
						commit
						33689c48ad
					
				| @ -37,6 +37,7 @@ cc_library( | ||||
|         "micro_allocator.h", | ||||
|         "micro_interpreter.h", | ||||
|         "micro_mutable_op_resolver.h", | ||||
|         "micro_op_resolver.h", | ||||
|         "micro_optional_debug_tools.h", | ||||
|         "simple_memory_allocator.h", | ||||
|         "test_helpers.h", | ||||
| @ -159,6 +160,7 @@ tflite_micro_cc_test( | ||||
|     deps = [ | ||||
|         ":micro_framework", | ||||
|         ":micro_utils", | ||||
|         "//tensorflow/lite/core/api", | ||||
|         "//tensorflow/lite/kernels:kernel_util", | ||||
|         "//tensorflow/lite/micro/testing:micro_test", | ||||
|     ], | ||||
|  | ||||
| @ -84,7 +84,7 @@ class KeywordRunner { | ||||
|   const tflite::Model* keyword_spotting_model_; | ||||
|   tflite::MicroErrorReporter micro_reporter_; | ||||
|   tflite::ErrorReporter* reporter_; | ||||
|   tflite::MicroOpResolver<6> resolver_; | ||||
|   tflite::MicroMutableOpResolver<6> resolver_; | ||||
|   tflite::MicroInterpreter interpreter_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -98,7 +98,7 @@ class PersonDetectionRunner { | ||||
|   const tflite::Model* person_detection_model_; | ||||
|   tflite::MicroErrorReporter micro_reporter_; | ||||
|   tflite::ErrorReporter* reporter_; | ||||
|   tflite::MicroOpResolver<6> resolver_; | ||||
|   tflite::MicroMutableOpResolver<6> resolver_; | ||||
|   tflite::MicroInterpreter interpreter_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -42,7 +42,7 @@ TF_LITE_MICRO_TEST(TestImageRecognitionInvoke) { | ||||
|                          model->version(), TFLITE_SCHEMA_VERSION); | ||||
|   } | ||||
| 
 | ||||
|   tflite::MicroOpResolver<4> micro_op_resolver; | ||||
|   tflite::MicroMutableOpResolver<4> micro_op_resolver; | ||||
| 
 | ||||
|   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, | ||||
|                                tflite::ops::micro::Register_CONV_2D()); | ||||
|  | ||||
| @ -56,7 +56,7 @@ int main(int argc, char** argv) { | ||||
|     return 1; | ||||
|   } | ||||
| 
 | ||||
|   tflite::MicroOpResolver<4> micro_op_resolver; | ||||
|   tflite::MicroMutableOpResolver<4> micro_op_resolver; | ||||
| 
 | ||||
|   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D, | ||||
|                                tflite::ops::micro::Register_CONV_2D()); | ||||
|  | ||||
| @ -46,7 +46,7 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) { | ||||
|   // An easier approach is to just use the AllOpsResolver, but this will
 | ||||
|   // incur some penalty in code space for op implementations that are not
 | ||||
|   // needed by this graph.
 | ||||
|   static tflite::MicroOpResolver<5> micro_op_resolver;  // NOLINT
 | ||||
|   static tflite::MicroMutableOpResolver<5> micro_op_resolver;  // NOLINT
 | ||||
|   micro_op_resolver.AddBuiltin( | ||||
|       tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|       tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); | ||||
|  | ||||
| @ -65,7 +65,7 @@ void setup() { | ||||
|   // An easier approach is to just use the AllOpsResolver, but this will
 | ||||
|   // incur some penalty in code space for op implementations that are not
 | ||||
|   // needed by this graph.
 | ||||
|   static tflite::MicroOpResolver<5> micro_op_resolver;  // NOLINT
 | ||||
|   static tflite::MicroMutableOpResolver<5> micro_op_resolver;  // NOLINT
 | ||||
|   micro_op_resolver.AddBuiltin( | ||||
|       tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|       tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); | ||||
|  | ||||
| @ -74,7 +74,7 @@ void setup() { | ||||
|   //
 | ||||
|   // tflite::ops::micro::AllOpsResolver resolver;
 | ||||
|   // NOLINTNEXTLINE(runtime-global-variables)
 | ||||
|   static tflite::MicroOpResolver<4> micro_op_resolver(error_reporter); | ||||
|   static tflite::MicroMutableOpResolver<4> micro_op_resolver(error_reporter); | ||||
|   if (micro_op_resolver.AddBuiltin( | ||||
|           tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|           tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), | ||||
|  | ||||
| @ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { | ||||
|   // needed by this graph.
 | ||||
|   //
 | ||||
|   // tflite::ops::micro::AllOpsResolver resolver;
 | ||||
|   tflite::MicroOpResolver<4> micro_op_resolver; | ||||
|   tflite::MicroMutableOpResolver<4> micro_op_resolver; | ||||
|   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|                                tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), | ||||
|                                tflite::MicroOpResolverAnyVersion()); | ||||
|  | ||||
| @ -65,7 +65,7 @@ void setup() { | ||||
|   //
 | ||||
|   // tflite::ops::micro::AllOpsResolver resolver;
 | ||||
|   // NOLINTNEXTLINE(runtime-global-variables)
 | ||||
|   static tflite::MicroOpResolver<3> micro_op_resolver; | ||||
|   static tflite::MicroMutableOpResolver<3> micro_op_resolver; | ||||
|   micro_op_resolver.AddBuiltin( | ||||
|       tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|       tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); | ||||
|  | ||||
| @ -54,7 +54,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { | ||||
|   // needed by this graph.
 | ||||
|   //
 | ||||
|   // tflite::ops::micro::AllOpsResolver resolver;
 | ||||
|   tflite::MicroOpResolver<3> micro_op_resolver; | ||||
|   tflite::MicroMutableOpResolver<3> micro_op_resolver; | ||||
|   micro_op_resolver.AddBuiltin( | ||||
|       tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|       tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); | ||||
|  | ||||
| @ -72,7 +72,7 @@ void setup() { | ||||
|   //
 | ||||
|   // tflite::ops::micro::AllOpsResolver resolver;
 | ||||
|   // NOLINTNEXTLINE(runtime-global-variables)
 | ||||
|   static tflite::MicroOpResolver<12> micro_op_resolver; | ||||
|   static tflite::MicroMutableOpResolver<12> micro_op_resolver; | ||||
|   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|                                tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), | ||||
|                                1, 3); | ||||
|  | ||||
| @ -52,7 +52,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { | ||||
|   // An easier approach is to just use the AllOpsResolver, but this will
 | ||||
|   // incur some penalty in code space for op implementations that are not
 | ||||
|   // needed by this graph.
 | ||||
|   tflite::MicroOpResolver<11> micro_op_resolver; | ||||
|   tflite::MicroMutableOpResolver<11> micro_op_resolver; | ||||
|   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, | ||||
|                                tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), | ||||
|                                1, 3); | ||||
|  | ||||
| @ -19,7 +19,12 @@ namespace tflite { | ||||
| namespace ops { | ||||
| namespace micro { | ||||
| 
 | ||||
| class AllOpsResolver : public MicroMutableOpResolver { | ||||
| // The magic number in the template parameter is the maximum number of ops that
 | ||||
| // can be added to AllOpsResolver. It can be increased if needed. And most
 | ||||
| // applications that care about the memory footprint will want to directly use
 | ||||
| // MicroMutableOpResolver and have an application specific template parameter.
 | ||||
| // The examples directory has sample code for this.
 | ||||
| class AllOpsResolver : public MicroMutableOpResolver<128> { | ||||
|  public: | ||||
|   AllOpsResolver(); | ||||
| 
 | ||||
|  | ||||
| @ -22,6 +22,7 @@ limitations under the License. | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| 
 | ||||
|  | ||||
| @ -24,11 +24,14 @@ limitations under the License. | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/api/tensor_utils.h" | ||||
| #include "tensorflow/lite/kernels/internal/compatibility.h" | ||||
| #include "tensorflow/lite/micro/compatibility.h" | ||||
| #include "tensorflow/lite/micro/memory_helpers.h" | ||||
| #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h" | ||||
| #include "tensorflow/lite/micro/memory_planner/memory_planner.h" | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/micro/simple_memory_allocator.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| 
 | ||||
| @ -431,7 +434,7 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus MicroAllocator::InitializeFromFlatbuffer( | ||||
|     const OpResolver& op_resolver, | ||||
|     const MicroOpResolver& op_resolver, | ||||
|     NodeAndRegistration** node_and_registrations) { | ||||
|   if (!active_) { | ||||
|     return kTfLiteError; | ||||
| @ -649,7 +652,7 @@ TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations( | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( | ||||
|     const OpResolver& op_resolver, | ||||
|     const MicroOpResolver& op_resolver, | ||||
|     NodeAndRegistration* node_and_registrations) { | ||||
|   TfLiteStatus status = kTfLiteOk; | ||||
|   auto* opcodes = model_->operator_codes(); | ||||
| @ -697,9 +700,12 @@ TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( | ||||
|       custom_data = reinterpret_cast<const char*>(op->custom_options()->data()); | ||||
|       custom_data_size = op->custom_options()->size(); | ||||
|     } else { | ||||
|       TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, | ||||
|                                         &builtin_data_allocator, | ||||
|                                         (void**)(&builtin_data))); | ||||
|       MicroOpResolver::BuiltinParseFunction parser = | ||||
|           op_resolver.GetOpDataParser(op_type); | ||||
|       TFLITE_DCHECK(parser != nullptr); | ||||
|       TF_LITE_ENSURE_STATUS(parser(op, op_type, error_reporter_, | ||||
|                                    &builtin_data_allocator, | ||||
|                                    (void**)(&builtin_data))); | ||||
|     } | ||||
| 
 | ||||
|     // Disregard const qualifier to workaround with existing API.
 | ||||
|  | ||||
| @ -21,7 +21,7 @@ limitations under the License. | ||||
| #include "flatbuffers/flatbuffers.h"  // from @flatbuffers | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/micro/simple_memory_allocator.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| @ -92,7 +92,7 @@ class MicroAllocator { | ||||
|   // needs to be called before FinishTensorAllocation method. This method also
 | ||||
|   // allocates any internal Op data that is required from the flatbuffer.
 | ||||
|   TfLiteStatus InitializeFromFlatbuffer( | ||||
|       const OpResolver& op_resolver, | ||||
|       const MicroOpResolver& op_resolver, | ||||
|       NodeAndRegistration** node_and_registrations); | ||||
| 
 | ||||
|   // Runs through the model and allocates all necessary input, output and
 | ||||
| @ -145,7 +145,7 @@ class MicroAllocator { | ||||
|   // instance). Persistent data (e.g. operator data) is allocated from the
 | ||||
|   // arena.
 | ||||
|   TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer( | ||||
|       const OpResolver& op_resolver, | ||||
|       const MicroOpResolver& op_resolver, | ||||
|       NodeAndRegistration* node_and_registrations); | ||||
| 
 | ||||
|  private: | ||||
|  | ||||
| @ -21,9 +21,10 @@ limitations under the License. | ||||
| #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/api/tensor_utils.h" | ||||
| #include "tensorflow/lite/micro/micro_allocator.h" | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace { | ||||
| @ -71,7 +72,7 @@ void ContextHelper::ReportOpError(struct TfLiteContext* context, | ||||
| }  // namespace internal
 | ||||
| 
 | ||||
| MicroInterpreter::MicroInterpreter(const Model* model, | ||||
|                                    const OpResolver& op_resolver, | ||||
|                                    const MicroOpResolver& op_resolver, | ||||
|                                    uint8_t* tensor_arena, | ||||
|                                    size_t tensor_arena_size, | ||||
|                                    ErrorReporter* error_reporter) | ||||
|  | ||||
| @ -21,9 +21,9 @@ limitations under the License. | ||||
| #include "flatbuffers/flatbuffers.h"  // from @flatbuffers | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" | ||||
| #include "tensorflow/lite/micro/micro_allocator.h" | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| #include "tensorflow/lite/type_to_tflitetype.h" | ||||
| 
 | ||||
| @ -72,7 +72,7 @@ class MicroInterpreter { | ||||
|   // function.
 | ||||
|   // The interpreter doesn't do any deallocation of any of the pointed-to
 | ||||
|   // objects, ownership remains with the caller.
 | ||||
|   MicroInterpreter(const Model* model, const OpResolver& op_resolver, | ||||
|   MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver, | ||||
|                    uint8_t* tensor_arena, size_t tensor_arena_size, | ||||
|                    ErrorReporter* error_reporter); | ||||
| 
 | ||||
| @ -160,7 +160,7 @@ class MicroInterpreter { | ||||
|   NodeAndRegistration* node_and_registrations_ = nullptr; | ||||
| 
 | ||||
|   const Model* model_; | ||||
|   const OpResolver& op_resolver_; | ||||
|   const MicroOpResolver& op_resolver_; | ||||
|   ErrorReporter* error_reporter_; | ||||
|   TfLiteContext context_ = {}; | ||||
|   MicroAllocator allocator_; | ||||
|  | ||||
| @ -17,7 +17,9 @@ limitations under the License. | ||||
| 
 | ||||
| #include <cstdint> | ||||
| 
 | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||||
| #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/micro/micro_optional_debug_tools.h" | ||||
| #include "tensorflow/lite/micro/micro_utils.h" | ||||
| #include "tensorflow/lite/micro/test_helpers.h" | ||||
| @ -147,7 +149,7 @@ class MockCustom { | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class MockOpResolver : public OpResolver { | ||||
| class MockOpResolver : public MicroOpResolver { | ||||
|  public: | ||||
|   const TfLiteRegistration* FindOp(BuiltinOperator op, | ||||
|                                    int version) const override { | ||||
| @ -162,6 +164,22 @@ class MockOpResolver : public OpResolver { | ||||
|       return nullptr; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   MicroOpResolver::BuiltinParseFunction GetOpDataParser( | ||||
|       tflite::BuiltinOperator) const override { | ||||
|     // TODO(b/149408647): Figure out an alternative so that we do not have any
 | ||||
|     // references to ParseOpData in the micro code and the signature for
 | ||||
|     // MicroOpResolver::BuiltinParseFunction can be changed to be different from
 | ||||
|     // ParseOpData.
 | ||||
|     return ParseOpData; | ||||
|   } | ||||
| 
 | ||||
|   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, | ||||
|                           TfLiteRegistration* registration, | ||||
|                           int version) override { | ||||
|     // This function is currently not used in the tests.
 | ||||
|     return kTfLiteError; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| @ -19,14 +19,11 @@ limitations under the License. | ||||
| 
 | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/micro/compatibility.h" | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| #ifndef TFLITE_REGISTRATIONS_MAX | ||||
| #define TFLITE_REGISTRATIONS_MAX (128) | ||||
| #endif | ||||
| 
 | ||||
| namespace tflite { | ||||
| 
 | ||||
| // Op versions discussed in this file are enumerated here:
 | ||||
| @ -34,10 +31,10 @@ namespace tflite { | ||||
| 
 | ||||
| inline int MicroOpResolverAnyVersion() { return 0; } | ||||
| 
 | ||||
| template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX> | ||||
| class MicroOpResolver : public OpResolver { | ||||
| template <unsigned int tOpCount> | ||||
| class MicroMutableOpResolver : public MicroOpResolver { | ||||
|  public: | ||||
|   explicit MicroOpResolver(ErrorReporter* error_reporter = nullptr) | ||||
|   explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr) | ||||
|       : error_reporter_(error_reporter) {} | ||||
| 
 | ||||
|   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, | ||||
| @ -68,8 +65,15 @@ class MicroOpResolver : public OpResolver { | ||||
|     return nullptr; | ||||
|   } | ||||
| 
 | ||||
|   MicroOpResolver::BuiltinParseFunction GetOpDataParser( | ||||
|       tflite::BuiltinOperator) const override { | ||||
|     // TODO(b/149408647): Replace with the more selective builtin parser.
 | ||||
|     return ParseOpData; | ||||
|   } | ||||
| 
 | ||||
|   TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, | ||||
|                           TfLiteRegistration* registration, int version = 1) { | ||||
|                           TfLiteRegistration* registration, | ||||
|                           int version = 1) override { | ||||
|     if (registrations_len_ >= tOpCount) { | ||||
|       if (error_reporter_) { | ||||
|         TF_LITE_REPORT_ERROR(error_reporter_, | ||||
| @ -144,14 +148,6 @@ class MicroOpResolver : public OpResolver { | ||||
|   TF_LITE_REMOVE_VIRTUAL_DELETE | ||||
| }; | ||||
| 
 | ||||
| // TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to
 | ||||
| // MicroOpResolver.
 | ||||
| class MicroMutableOpResolver | ||||
|     : public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> { | ||||
|  private: | ||||
|   TF_LITE_REMOVE_VIRTUAL_DELETE | ||||
| }; | ||||
| 
 | ||||
| };  // namespace tflite
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
 | ||||
|  | ||||
| @ -14,6 +14,8 @@ limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" | ||||
| 
 | ||||
| #include "tensorflow/lite/micro/micro_op_resolver.h" | ||||
| #include "tensorflow/lite/micro/testing/micro_test.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| @ -58,7 +60,7 @@ TF_LITE_MICRO_TESTS_BEGIN | ||||
| TF_LITE_MICRO_TEST(TestOperations) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
|   using tflite::OpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
| @ -66,7 +68,7 @@ TF_LITE_MICRO_TEST(TestOperations) { | ||||
| 
 | ||||
|   // We need space for 7 operators because of 2 ops, one with 3 versions, one
 | ||||
|   // with 4 versions.
 | ||||
|   MicroOpResolver<7> micro_op_resolver; | ||||
|   MicroMutableOpResolver<7> micro_op_resolver; | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( | ||||
|                                          BuiltinOperator_CONV_2D, &r, 1, 3)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
| @ -104,32 +106,30 @@ TF_LITE_MICRO_TEST(TestOperations) { | ||||
| TF_LITE_MICRO_TEST(TestOpRegistrationOverflow) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::OpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   MicroOpResolver<4> micro_op_resolver; | ||||
|   MicroMutableOpResolver<4> micro_op_resolver; | ||||
|   // Register 7 ops, but only 4 is expected because the class is created with
 | ||||
|   // that limit..
 | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( | ||||
|                                          BuiltinOperator_CONV_2D, &r, 0, 2)); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, | ||||
|                           micro_op_resolver.AddCustom("mock_custom", &r, 0, 3)); | ||||
|   OpResolver* resolver = µ_op_resolver; | ||||
| 
 | ||||
|   TF_LITE_MICRO_EXPECT_EQ(4, micro_op_resolver.GetRegistrationLength()); | ||||
| } | ||||
| 
 | ||||
| TF_LITE_MICRO_TEST(TestZeroVersionRegistration) { | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
|   using tflite::OpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   MicroOpResolver<1> micro_op_resolver; | ||||
|   MicroMutableOpResolver<1> micro_op_resolver; | ||||
|   micro_op_resolver.AddCustom("mock_custom", &r, | ||||
|                               tflite::MicroOpResolverAnyVersion()); | ||||
| 
 | ||||
| @ -157,13 +157,13 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) { | ||||
| } | ||||
| 
 | ||||
| TF_LITE_MICRO_TEST(TestZeroModelVersion) { | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
|   using tflite::OpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   MicroOpResolver<2> micro_op_resolver; | ||||
|   MicroMutableOpResolver<2> micro_op_resolver; | ||||
|   micro_op_resolver.AddCustom("mock_custom", &r, 1, 2); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(2, micro_op_resolver.GetRegistrationLength()); | ||||
|   OpResolver* resolver = µ_op_resolver; | ||||
| @ -196,13 +196,13 @@ TF_LITE_MICRO_TEST(TestZeroModelVersion) { | ||||
| TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   tflite::MockErrorReporter mock_reporter; | ||||
|   MicroOpResolver<1> micro_op_resolver(&mock_reporter); | ||||
|   MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); | ||||
|   TF_LITE_MICRO_EXPECT_EQ( | ||||
|       kTfLiteOk, micro_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r)); | ||||
| @ -215,13 +215,13 @@ TF_LITE_MICRO_TEST(TestBuiltinRegistrationErrorReporting) { | ||||
| TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   tflite::MockErrorReporter mock_reporter; | ||||
|   MicroOpResolver<1> micro_op_resolver(&mock_reporter); | ||||
|   MicroMutableOpResolver<1> micro_op_resolver(&mock_reporter); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, | ||||
|                           micro_op_resolver.AddCustom("mock_custom_0", &r)); | ||||
| @ -234,13 +234,13 @@ TF_LITE_MICRO_TEST(TestCustomRegistrationErrorReporting) { | ||||
| TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   tflite::MockErrorReporter mock_reporter; | ||||
|   MicroOpResolver<2> micro_op_resolver(&mock_reporter); | ||||
|   MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddBuiltin( | ||||
|                                          BuiltinOperator_CONV_2D, &r, 1, 2)); | ||||
| @ -253,13 +253,13 @@ TF_LITE_MICRO_TEST(TestBuiltinVersionRegistrationErrorReporting) { | ||||
| TF_LITE_MICRO_TEST(TestCustomVersionRegistrationErrorReporting) { | ||||
|   using tflite::BuiltinOperator_CONV_2D; | ||||
|   using tflite::BuiltinOperator_RELU; | ||||
|   using tflite::MicroOpResolver; | ||||
|   using tflite::MicroMutableOpResolver; | ||||
| 
 | ||||
|   static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree, | ||||
|                                  tflite::MockPrepare, tflite::MockInvoke}; | ||||
| 
 | ||||
|   tflite::MockErrorReporter mock_reporter; | ||||
|   MicroOpResolver<2> micro_op_resolver(&mock_reporter); | ||||
|   MicroMutableOpResolver<2> micro_op_resolver(&mock_reporter); | ||||
|   TF_LITE_MICRO_EXPECT_EQ(false, mock_reporter.HasBeenCalled()); | ||||
|   TF_LITE_MICRO_EXPECT_EQ( | ||||
|       kTfLiteOk, micro_op_resolver.AddCustom("mock_custom_0", &r, 1, 2)); | ||||
|  | ||||
							
								
								
									
										58
									
								
								tensorflow/lite/micro/micro_op_resolver.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								tensorflow/lite/micro/micro_op_resolver.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| #ifndef TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_ | ||||
| #define TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_ | ||||
| 
 | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/flatbuffer_conversions.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/schema/schema_generated.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| 
 | ||||
| // This is an interface for the OpResolver for TFLiteMicro. The differences from
 | ||||
| // the TFLite OpResolver base class are to allow for finer grained registration
 | ||||
| // of the Builtin Ops to reduce code size for TFLiteMicro.  We need an interface
 | ||||
| // class instead of directly using MicroMutableOpResolver because
 | ||||
| // MicroMutableOpResolver is a class template with the number of registered Ops
 | ||||
| // as the template parameter.
 | ||||
| class MicroOpResolver : public OpResolver { | ||||
|  public: | ||||
|   // TODO(b/149408647): The op_type parameter enables a gradual transfer to
 | ||||
|   // selective registration of the parse function. It should be removed once we
 | ||||
|   // no longer need to use ParseOpData (from flatbuffer_conversions.h) as part
 | ||||
|   // of the MicroMutableOpResolver.
 | ||||
|   typedef TfLiteStatus (*BuiltinParseFunction)(const Operator* op, | ||||
|                                                BuiltinOperator op_type, | ||||
|                                                ErrorReporter* error_reporter, | ||||
|                                                BuiltinDataAllocator* allocator, | ||||
|                                                void** builtin_data); | ||||
| 
 | ||||
|   // Returns the operator specific parsing function for the OpData for a
 | ||||
|   //   BuiltinOperator (if registered), else nullptr.
 | ||||
|   virtual BuiltinParseFunction GetOpDataParser( | ||||
|       tflite::BuiltinOperator op) const = 0; | ||||
| 
 | ||||
|   virtual TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, | ||||
|                                   TfLiteRegistration* registration, | ||||
|                                   int version) = 0; | ||||
| 
 | ||||
|   ~MicroOpResolver() override {} | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tflite
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_LITE_MICRO_MICRO_OP_RESOLVER_H_
 | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user