diff --git a/tensorflow/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc index 64fc68a16ca..22641ebd539 100644 --- a/tensorflow/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -40,11 +40,21 @@ TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus Dummy2Prepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +void* Dummy2Init(TfLiteContext* context, const char* buffer, size_t length) { + return nullptr; +} + +void Dummy2free(TfLiteContext* context, void* buffer) {} + TfLiteRegistration* GetDummy2Registration() { static TfLiteRegistration registration = { - .init = nullptr, - .free = nullptr, - .prepare = nullptr, + .init = Dummy2Init, + .free = Dummy2free, + .prepare = Dummy2Prepare, .invoke = Dummy2Invoke, }; return ®istration; @@ -112,8 +122,41 @@ TEST(MutableOpResolverTest, FindCustomOp) { EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM); EXPECT_TRUE(found_registration->invoke == DummyInvoke); EXPECT_EQ(found_registration->version, 1); - // TODO(ycling): The `custom_name` in TfLiteRegistration isn't properly - // filled yet. Fix this and add tests. +} + +TEST(MutableOpResolverTest, FindCustomName) { + MutableOpResolver resolver; + TfLiteRegistration* reg = GetDummyRegistration(); + + reg->custom_name = "UPDATED"; + resolver.AddCustom(reg->custom_name, reg); + const TfLiteRegistration* found_registration = + resolver.FindOp(reg->custom_name, 1); + + ASSERT_NE(found_registration, nullptr); + EXPECT_EQ(found_registration->builtin_code, BuiltinOperator_CUSTOM); + EXPECT_EQ(found_registration->invoke, GetDummyRegistration()->invoke); + EXPECT_EQ(found_registration->version, 1); + EXPECT_EQ(found_registration->custom_name, "UPDATED"); +} + +TEST(MutableOpResolverTest, FindBuiltinName) { + MutableOpResolver resolver1; + TfLiteRegistration* reg = GetDummy2Registration(); + + reg->custom_name = "UPDATED"; + resolver1.AddBuiltin(BuiltinOperator_ADD, reg); + + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke, + GetDummy2Registration()->invoke); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->prepare, + GetDummy2Registration()->prepare); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->init, + GetDummy2Registration()->init); + ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->free, + GetDummy2Registration()->free); + // custom_name for builtin ops will be nullptr + EXPECT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->custom_name, nullptr); } TEST(MutableOpResolverTest, FindMissingCustomOp) {