Add single-parameter versions of AddBuiltin and AddCustom functions.
This allows us to make the max_version the same as min_version by default; C++ does not allow the default value of a parameter to depend on the other parameters. PiperOrigin-RevId: 301269532 Change-Id: If2023ec1a2f7081e601bb95fd65f0f52c6eb83c6
This commit is contained in:
parent
eb7bdc25ca
commit
876a3c1708
@ -61,37 +61,45 @@ class MicroOpResolver : public OpResolver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
|
void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
|
||||||
int min_version = 1, int max_version = 1) {
|
int version = 1) {
|
||||||
for (int version = min_version; version <= max_version; ++version) {
|
if (registrations_len_ >= tOpCount) {
|
||||||
if (registrations_len_ >= tOpCount) {
|
// TODO(b/147748244) - Add error reporting hooks so we can report this!
|
||||||
// TODO(b/147748244) - Add error reporting hooks so we can report this!
|
return;
|
||||||
return;
|
}
|
||||||
}
|
TfLiteRegistration* new_registration = ®istrations_[registrations_len_];
|
||||||
TfLiteRegistration* new_registration =
|
registrations_len_ += 1;
|
||||||
®istrations_[registrations_len_];
|
|
||||||
registrations_len_ += 1;
|
|
||||||
|
|
||||||
*new_registration = *registration;
|
*new_registration = *registration;
|
||||||
new_registration->builtin_code = op;
|
new_registration->builtin_code = op;
|
||||||
new_registration->version = version;
|
new_registration->version = version;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
|
||||||
|
int min_version, int max_version) {
|
||||||
|
for (int version = min_version; version <= max_version; ++version) {
|
||||||
|
AddBuiltin(op, registration, version);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddCustom(const char* name, TfLiteRegistration* registration,
|
void AddCustom(const char* name, TfLiteRegistration* registration,
|
||||||
int min_version = 1, int max_version = 1) {
|
int version = 1) {
|
||||||
for (int version = min_version; version <= max_version; ++version) {
|
if (registrations_len_ >= tOpCount) {
|
||||||
if (registrations_len_ >= tOpCount) {
|
// TODO(b/147748244) - Add error reporting hooks so we can report this!
|
||||||
// TODO(b/147748244) - Add error reporting hooks so we can report this!
|
return;
|
||||||
return;
|
}
|
||||||
}
|
TfLiteRegistration* new_registration = ®istrations_[registrations_len_];
|
||||||
TfLiteRegistration* new_registration =
|
registrations_len_ += 1;
|
||||||
®istrations_[registrations_len_];
|
|
||||||
registrations_len_ += 1;
|
|
||||||
|
|
||||||
*new_registration = *registration;
|
*new_registration = *registration;
|
||||||
new_registration->builtin_code = BuiltinOperator_CUSTOM;
|
new_registration->builtin_code = BuiltinOperator_CUSTOM;
|
||||||
new_registration->custom_name = name;
|
new_registration->custom_name = name;
|
||||||
new_registration->version = version;
|
new_registration->version = version;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddCustom(const char* name, TfLiteRegistration* registration,
|
||||||
|
int min_version, int max_version) {
|
||||||
|
for (int version = min_version; version <= max_version; ++version) {
|
||||||
|
AddCustom(name, registration, version);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +113,9 @@ TF_LITE_MICRO_TEST(TestZeroVersionRegistration) {
|
|||||||
MicroOpResolver<1> micro_op_resolver;
|
MicroOpResolver<1> micro_op_resolver;
|
||||||
micro_op_resolver.AddCustom("mock_custom", &r,
|
micro_op_resolver.AddCustom("mock_custom", &r,
|
||||||
tflite::MicroOpResolverAnyVersion());
|
tflite::MicroOpResolverAnyVersion());
|
||||||
|
|
||||||
|
TF_LITE_MICRO_EXPECT_EQ(1, micro_op_resolver.GetRegistrationLength());
|
||||||
|
|
||||||
OpResolver* resolver = µ_op_resolver;
|
OpResolver* resolver = µ_op_resolver;
|
||||||
|
|
||||||
const TfLiteRegistration* registration = resolver->FindOp("mock_custom", 0);
|
const TfLiteRegistration* registration = resolver->FindOp("mock_custom", 0);
|
||||||
|
@ -29,29 +29,41 @@ const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
|
|||||||
return it != custom_ops_.end() ? &it->second : nullptr;
|
return it != custom_ops_.end() ? &it->second : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
|
||||||
|
const TfLiteRegistration* registration,
|
||||||
|
int version) {
|
||||||
|
TfLiteRegistration new_registration = *registration;
|
||||||
|
new_registration.custom_name = nullptr;
|
||||||
|
new_registration.builtin_code = op;
|
||||||
|
new_registration.version = version;
|
||||||
|
auto op_key = std::make_pair(op, version);
|
||||||
|
builtins_[op_key] = new_registration;
|
||||||
|
}
|
||||||
|
|
||||||
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
|
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
|
||||||
const TfLiteRegistration* registration,
|
const TfLiteRegistration* registration,
|
||||||
int min_version, int max_version) {
|
int min_version, int max_version) {
|
||||||
for (int version = min_version; version <= max_version; ++version) {
|
for (int version = min_version; version <= max_version; ++version) {
|
||||||
TfLiteRegistration new_registration = *registration;
|
AddBuiltin(op, registration, version);
|
||||||
new_registration.custom_name = nullptr;
|
|
||||||
new_registration.builtin_code = op;
|
|
||||||
new_registration.version = version;
|
|
||||||
auto op_key = std::make_pair(op, version);
|
|
||||||
builtins_[op_key] = new_registration;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MutableOpResolver::AddCustom(const char* name,
|
||||||
|
const TfLiteRegistration* registration,
|
||||||
|
int version) {
|
||||||
|
TfLiteRegistration new_registration = *registration;
|
||||||
|
new_registration.builtin_code = BuiltinOperator_CUSTOM;
|
||||||
|
new_registration.custom_name = name;
|
||||||
|
new_registration.version = version;
|
||||||
|
auto op_key = std::make_pair(name, version);
|
||||||
|
custom_ops_[op_key] = new_registration;
|
||||||
|
}
|
||||||
|
|
||||||
void MutableOpResolver::AddCustom(const char* name,
|
void MutableOpResolver::AddCustom(const char* name,
|
||||||
const TfLiteRegistration* registration,
|
const TfLiteRegistration* registration,
|
||||||
int min_version, int max_version) {
|
int min_version, int max_version) {
|
||||||
for (int version = min_version; version <= max_version; ++version) {
|
for (int version = min_version; version <= max_version; ++version) {
|
||||||
TfLiteRegistration new_registration = *registration;
|
AddCustom(name, registration, version);
|
||||||
new_registration.builtin_code = BuiltinOperator_CUSTOM;
|
|
||||||
new_registration.custom_name = name;
|
|
||||||
new_registration.version = version;
|
|
||||||
auto op_key = std::make_pair(name, version);
|
|
||||||
custom_ops_[op_key] = new_registration;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,10 +60,14 @@ class MutableOpResolver : public OpResolver {
|
|||||||
int version) const override;
|
int version) const override;
|
||||||
const TfLiteRegistration* FindOp(const char* op, int version) const override;
|
const TfLiteRegistration* FindOp(const char* op, int version) const override;
|
||||||
void AddBuiltin(tflite::BuiltinOperator op,
|
void AddBuiltin(tflite::BuiltinOperator op,
|
||||||
const TfLiteRegistration* registration, int min_version = 1,
|
const TfLiteRegistration* registration, int version = 1);
|
||||||
int max_version = 1);
|
void AddBuiltin(tflite::BuiltinOperator op,
|
||||||
|
const TfLiteRegistration* registration, int min_version,
|
||||||
|
int max_version);
|
||||||
void AddCustom(const char* name, const TfLiteRegistration* registration,
|
void AddCustom(const char* name, const TfLiteRegistration* registration,
|
||||||
int min_version = 1, int max_version = 1);
|
int version = 1);
|
||||||
|
void AddCustom(const char* name, const TfLiteRegistration* registration,
|
||||||
|
int min_version, int max_version);
|
||||||
void AddAll(const MutableOpResolver& other);
|
void AddAll(const MutableOpResolver& other);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -81,6 +81,25 @@ TEST(MutableOpResolverTest, FindMissingOp) {
|
|||||||
EXPECT_EQ(found_registration, nullptr);
|
EXPECT_EQ(found_registration, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MutableOpResolverTest, RegisterOpWithSingleVersion) {
|
||||||
|
MutableOpResolver resolver;
|
||||||
|
// The kernel supports version 2 only
|
||||||
|
resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration(), 2);
|
||||||
|
|
||||||
|
const TfLiteRegistration* found_registration;
|
||||||
|
|
||||||
|
found_registration = resolver.FindOp(BuiltinOperator_ADD, 1);
|
||||||
|
ASSERT_EQ(found_registration, nullptr);
|
||||||
|
|
||||||
|
found_registration = resolver.FindOp(BuiltinOperator_ADD, 2);
|
||||||
|
ASSERT_NE(found_registration, nullptr);
|
||||||
|
EXPECT_TRUE(found_registration->invoke == DummyInvoke);
|
||||||
|
EXPECT_EQ(found_registration->version, 2);
|
||||||
|
|
||||||
|
found_registration = resolver.FindOp(BuiltinOperator_ADD, 3);
|
||||||
|
ASSERT_EQ(found_registration, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
|
TEST(MutableOpResolverTest, RegisterOpWithMultipleVersions) {
|
||||||
MutableOpResolver resolver;
|
MutableOpResolver resolver;
|
||||||
// The kernel supports version 2 and 3
|
// The kernel supports version 2 and 3
|
||||||
|
Loading…
Reference in New Issue
Block a user