diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index a6cc8ab48d9..ac304352a57 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -61,37 +61,45 @@ class MicroOpResolver : public OpResolver { } void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1) { - for (int version = min_version; version <= max_version; ++version) { - if (registrations_len_ >= tOpCount) { - // TODO(b/147748244) - Add error reporting hooks so we can report this! - return; - } - TfLiteRegistration* new_registration = - ®istrations_[registrations_len_]; - registrations_len_ += 1; + int version = 1) { + if (registrations_len_ >= tOpCount) { + // TODO(b/147748244) - Add error reporting hooks so we can report this! + return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; - *new_registration = *registration; - new_registration->builtin_code = op; - new_registration->version = version; + *new_registration = *registration; + new_registration->builtin_code = op; + new_registration->version = version; + } + + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + AddBuiltin(op, registration, version); } } void AddCustom(const char* name, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1) { - for (int version = min_version; version <= max_version; ++version) { - if (registrations_len_ >= tOpCount) { - // TODO(b/147748244) - Add error reporting hooks so we can report this! - return; - } - TfLiteRegistration* new_registration = - ®istrations_[registrations_len_]; - registrations_len_ += 1; + int version = 1) { + if (registrations_len_ >= tOpCount) { + // TODO(b/147748244) - Add error reporting hooks so we can report this! + return; + } + TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; + registrations_len_ += 1; - *new_registration = *registration; - new_registration->builtin_code = BuiltinOperator_CUSTOM; - new_registration->custom_name = name; - new_registration->version = version; + *new_registration = *registration; + new_registration->builtin_code = BuiltinOperator_CUSTOM; + new_registration->custom_name = name; + new_registration->version = version; + } + + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version, int max_version) { + for (int version = min_version; version <= max_version; ++version) { + AddCustom(name, registration, version); } } diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc index 8f22a6c8a42..0619591523a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc +++ b/tensorflow/lite/micro/micro_mutable_op_resolver_test.cc @@ -113,6 +113,9 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) { MicroOpResolver<1> micro_op_resolver; micro_op_resolver.AddCustom("mock_custom", &r, tflite::MicroOpResolverAnyVersion()); + + TF_LITE_MICRO_EXPECT_EQ(1, micro_op_resolver.GetRegistrationLength()); + OpResolver* resolver = µ_op_resolver; const TfLiteRegistration* registration = resolver->FindOp("mock_custom", 0); diff --git a/tensorflow/lite/mutable_op_resolver.cc b/tensorflow/lite/mutable_op_resolver.cc index 36c512dcaac..5cb6ed169e7 100644 --- a/tensorflow/lite/mutable_op_resolver.cc +++ b/tensorflow/lite/mutable_op_resolver.cc @@ -29,29 +29,41 @@ const TfLiteRegistration* MutableOpResolver::FindOp(const char* op, return it != custom_ops_.end() ? &it->second : nullptr; } +void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, + const TfLiteRegistration* registration, + int version) { + TfLiteRegistration new_registration = *registration; + new_registration.custom_name = nullptr; + new_registration.builtin_code = op; + new_registration.version = version; + auto op_key = std::make_pair(op, version); + builtins_[op_key] = new_registration; +} + void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, const TfLiteRegistration* registration, int min_version, int max_version) { for (int version = min_version; version <= max_version; ++version) { - TfLiteRegistration new_registration = *registration; - new_registration.custom_name = nullptr; - new_registration.builtin_code = op; - new_registration.version = version; - auto op_key = std::make_pair(op, version); - builtins_[op_key] = new_registration; + AddBuiltin(op, registration, version); } } +void MutableOpResolver::AddCustom(const char* name, + const TfLiteRegistration* registration, + int version) { + TfLiteRegistration new_registration = *registration; + new_registration.builtin_code = BuiltinOperator_CUSTOM; + new_registration.custom_name = name; + new_registration.version = version; + auto op_key = std::make_pair(name, version); + custom_ops_[op_key] = new_registration; +} + void MutableOpResolver::AddCustom(const char* name, const TfLiteRegistration* registration, int min_version, int max_version) { for (int version = min_version; version <= max_version; ++version) { - TfLiteRegistration new_registration = *registration; - new_registration.builtin_code = BuiltinOperator_CUSTOM; - new_registration.custom_name = name; - new_registration.version = version; - auto op_key = std::make_pair(name, version); - custom_ops_[op_key] = new_registration; + AddCustom(name, registration, version); } } diff --git a/tensorflow/lite/mutable_op_resolver.h b/tensorflow/lite/mutable_op_resolver.h index 9e41ee86423..fe5e121424c 100644 --- a/tensorflow/lite/mutable_op_resolver.h +++ b/tensorflow/lite/mutable_op_resolver.h @@ -60,10 +60,14 @@ class MutableOpResolver : public OpResolver { int version) const override; const TfLiteRegistration* FindOp(const char* op, int version) const override; void AddBuiltin(tflite::BuiltinOperator op, - const TfLiteRegistration* registration, int min_version = 1, - int max_version = 1); + const TfLiteRegistration* registration, int version = 1); + void AddBuiltin(tflite::BuiltinOperator op, + const TfLiteRegistration* registration, int min_version, + int max_version); void AddCustom(const char* name, const TfLiteRegistration* registration, - int min_version = 1, int max_version = 1); + int version = 1); + void AddCustom(const char* name, const TfLiteRegistration* registration, + int min_version, int max_version); void AddAll(const MutableOpResolver& other); private: diff --git a/tensorflow/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc index 22641ebd539..71a30d95b16 100644 --- a/tensorflow/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -81,6 +81,25 @@ TEST(MutableOpResolverTest, FindMissingOp) { EXPECT_EQ(found_registration, nullptr); } +TEST(MutableOpResolverTest, RegisterOpWithSingleVersion) { + MutableOpResolver resolver; + // The kernel supports version 2 only + resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2); + + const TfLiteRegistration* found_registration; + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 1); + ASSERT_EQ(found_registration, nullptr); + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 2); + ASSERT_NE(found_registration, nullptr); + EXPECT_TRUE(found_registration->invoke == DummyInvoke); + EXPECT_EQ(found_registration->version, 2); + + found_registration = resolver.FindOp(BuiltinOperator_ADD, 3); + ASSERT_EQ(found_registration, nullptr); +} + TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) { MutableOpResolver resolver; // The kernel supports version 2 and 3