From 50acd58a69289b64bd7a4c9f25e2a8adbd19b67d Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Tue, 22 Sep 2020 13:59:01 -0700 Subject: [PATCH] Add TfLiteInterpreterOptionsSetOpResolver function to c_api_experimental.h. This provides adds a new callback-based op registration C API. This new function would be required in order to implement the TF Tasks Library, or any other C++ API that uses OpResolver parameters, using the TF Lite C API. PiperOrigin-RevId: 333154110 Change-Id: I6e53edc91ff1271f41c355bb9ad1a48d32bf3e41 --- tensorflow/lite/BUILD | 6 +++ tensorflow/lite/c/BUILD | 3 ++ tensorflow/lite/c/c_api.cc | 57 +++++++++++++++++++- tensorflow/lite/c/c_api_experimental.cc | 22 ++++++-- tensorflow/lite/c/c_api_experimental.h | 30 ++++++++++- tensorflow/lite/c/c_api_experimental_test.cc | 49 +++++++++++++++++ tensorflow/lite/c/c_api_internal.h | 36 +++++++++++-- 7 files changed, 192 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index d08a2e2fbfa..754a86e5916 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -181,6 +181,12 @@ cc_library( deps = ["//tensorflow/lite/c:common"], ) +cc_library( + name = "builtin_ops", + hdrs = ["builtin_ops.h"], + compatible_with = get_compatible_with_portable(), +) + exports_files(["builtin_ops.h"]) cc_library( diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 3f4ff9130a0..e8db0dcf440 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -47,6 +47,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":common", + "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite/core/api", ], @@ -60,6 +61,7 @@ cc_library( deps = [ ":c_api_internal", ":common", + "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:version", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", @@ -125,6 +127,7 @@ cc_library( "common.h", ], compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/lite:builtin_ops"], alwayslink = 1, ) diff --git a/tensorflow/lite/c/c_api.cc b/tensorflow/lite/c/c_api.cc index 895b6798c94..205c665d08b 100644 --- a/tensorflow/lite/c/c_api.cc +++ b/tensorflow/lite/c/c_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/error_reporter.h" @@ -43,6 +44,44 @@ class CallbackErrorReporter : public tflite::ErrorReporter { private: TfLiteErrorReporterCallback callback_; }; + +/// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the +/// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks` +/// struct. +/// +/// The SetCallbacks method must be called before calling any of the FindOp +/// methods. +class CallbackOpResolver : public ::tflite::OpResolver { + public: + CallbackOpResolver() {} + void SetCallbacks( + const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) { + op_resolver_callbacks_ = op_resolver_callbacks; + } + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override { + if (op_resolver_callbacks_.find_builtin_op == nullptr) { + return nullptr; + } + return op_resolver_callbacks_.find_builtin_op( + op_resolver_callbacks_.user_data, + static_cast(op), version); + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (op_resolver_callbacks_.find_custom_op == nullptr) { + return nullptr; + } + return op_resolver_callbacks_.find_custom_op( + op_resolver_callbacks_.user_data, op, version); + } + + private: + CallbackOpResolver(const CallbackOpResolver&) = delete; + CallbackOpResolver& operator=(const CallbackOpResolver&) = delete; + + struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {}; +}; + } // namespace // LINT.IfChange @@ -210,14 +249,28 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver( new CallbackErrorReporter(optional_options->error_reporter_callback)); } + // By default, we use the provided mutable_op_resolver, adding any builtin or + // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or + // `TfLiteInterpreterOptionsAddCustomOp`. + tflite::OpResolver* op_resolver = mutable_resolver; if (optional_options) { - mutable_resolver->AddAll(optional_options->op_resolver); + mutable_resolver->AddAll(optional_options->mutable_op_resolver); + } + // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with + // a non-null callback parameter, then we instead use a + // `CallbackOpResolver` that will forward to the callbacks provided there. + CallbackOpResolver callback_op_resolver; + if (optional_options && + (optional_options->op_resolver_callbacks.find_builtin_op != nullptr || + optional_options->op_resolver_callbacks.find_custom_op != nullptr)) { + callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks); + op_resolver = &callback_op_resolver; } tflite::ErrorReporter* error_reporter = optional_error_reporter ? optional_error_reporter.get() : tflite::DefaultErrorReporter(); - tflite::InterpreterBuilder builder(model->impl->GetModel(), *mutable_resolver, + tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver, error_reporter); std::unique_ptr interpreter; diff --git a/tensorflow/lite/c/c_api_experimental.cc b/tensorflow/lite/c/c_api_experimental.cc index 1d84d24eb14..23a5ca7a275 100644 --- a/tensorflow/lite/c/c_api_experimental.cc +++ b/tensorflow/lite/c/c_api_experimental.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/mutable_op_resolver.h" #ifdef __cplusplus extern "C" { @@ -38,8 +37,9 @@ void TfLiteInterpreterOptionsAddBuiltinOp( TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version) { - options->op_resolver.AddBuiltin(static_cast(op), - registration, min_version, max_version); + options->mutable_op_resolver.AddBuiltin( + static_cast(op), registration, min_version, + max_version); } TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps( @@ -55,7 +55,21 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version) { - options->op_resolver.AddCustom(name, registration, min_version, max_version); + options->mutable_op_resolver.AddCustom(name, registration, min_version, + max_version); +} + +void TfLiteInterpreterOptionsSetOpResolver( + TfLiteInterpreterOptions* options, + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version), + void* op_resolver_user_data) { + options->op_resolver_callbacks.find_builtin_op = find_builtin_op; + options->op_resolver_callbacks.find_custom_op = find_custom_op; + options->op_resolver_callbacks.user_data = op_resolver_user_data; } void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h index 5971b5f6a4a..8d635f32a3a 100644 --- a/tensorflow/lite/c/c_api_experimental.h +++ b/tensorflow/lite/c/c_api_experimental.h @@ -36,6 +36,9 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors( /// so the caller should ensure that its contents (function pointers, etc...) /// remain valid for the duration of the interpreter's lifetime. A common /// practice is making the provided `TfLiteRegistration` instance static. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, const TfLiteRegistration* registration, int32_t min_version, @@ -50,16 +53,41 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( /// so the caller should ensure that its contents (function pointers, etc...) /// remain valid for the duration of any created interpreter's lifetime. A /// common practice is making the provided `TfLiteRegistration` instance static. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( TfLiteInterpreterOptions* options, const char* name, const TfLiteRegistration* registration, int32_t min_version, int32_t max_version); +/// Registers callbacks for resolving builtin or custom operators. +/// +/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative +/// method for registering builtin ops and/or custom ops, by providing operator +/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp` +/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all +/// the operators in a single call. +/// +/// Code that uses this function should NOT call +/// `TfLiteInterpreterOptionsAddBuiltin' or +/// `TfLiteInterpreterOptionsAddCustomOp' on the same options object. +void TfLiteInterpreterOptionsSetOpResolver( + TfLiteInterpreterOptions* options, + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version), + const TfLiteRegistration* (*find_custom_op)(void* user_data, + const char* custom_op, + int version), + void* op_resolver_user_data); + /// Returns a new interpreter using the provided model and options, or null on /// failure, where the model uses only the operators explicitly added to the /// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`, /// except that the only operators that are supported are the ones registered -/// in `options` via calls to `TfLiteInterpreterOptionsAddBuiltinOp` and/or +/// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`, +/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or /// `TfLiteInterpreterOptionsAddCustomOp`. /// /// * `model` must be a valid model instance. The caller retains ownership of diff --git a/tensorflow/lite/c/c_api_experimental_test.cc b/tensorflow/lite/c/c_api_experimental_test.cc index ec79e4d898e..4de137ec0e6 100644 --- a/tensorflow/lite/c/c_api_experimental_test.cc +++ b/tensorflow/lite/c/c_api_experimental_test.cc @@ -110,6 +110,55 @@ TEST(CApiExperimentalTest, MissingBuiltin) { TfLiteModelDelete(model); } +struct OpResolverData { + bool called_for_add = false; +}; + +const TfLiteRegistration* MyFindBuiltinOp(void* user_data, + TfLiteBuiltinOperator op, + int version) { + OpResolverData* my_data = static_cast(user_data); + if (op == kTfLiteBuiltinAdd && version == 1) { + my_data->called_for_add = true; + return GetDummyRegistration(); + } + return nullptr; +} + +const TfLiteRegistration* MyFindCustomOp(void*, const char* custom_op, + int version) { + if (absl::string_view(custom_op) == "foo" && version == 1) { + return GetDummyRegistration(); + } + return nullptr; +} + +// Test using TfLiteInterpreterCreateWithSelectedOps. +TEST(CApiExperimentalTest, SetOpResolver) { + TfLiteModel* model = + TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + + OpResolverData my_data; + TfLiteInterpreterOptionsSetOpResolver(options, MyFindBuiltinOp, + MyFindCustomOp, &my_data); + EXPECT_FALSE(my_data.called_for_add); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreateWithSelectedOps(model, options); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk); + EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk); + EXPECT_TRUE(my_data.called_for_add); + + TfLiteInterpreterDelete(interpreter); + TfLiteInterpreterOptionsDelete(options); + TfLiteModelDelete(model); +} + } // namespace int main(int argc, char** argv) { diff --git a/tensorflow/lite/c/c_api_internal.h b/tensorflow/lite/c/c_api_internal.h index cc31807bd3f..ee07e3e06a5 100644 --- a/tensorflow/lite/c/c_api_internal.h +++ b/tensorflow/lite/c/c_api_internal.h @@ -20,7 +20,9 @@ limitations under the License. #include #include +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/mutable_op_resolver.h" @@ -36,6 +38,30 @@ struct TfLiteModel { std::shared_ptr impl; }; +// The `TfLiteOpResolver` struct is an abstract callback interface that +// contains function pointers for callbacks that return a +// `TfLiteRegistration` given an op code or custom op name. This mechanism is +// used to map ops referenced in the flatbuffer model to executable function +// pointers (`TfLiteRegistration`s). +// This struct mirrors the tflite::OpResolver C++ abstract base class. +struct TfLiteOpResolverCallbacks { + // Opaque data that gets passed down to the callback functions. + void* user_data = nullptr; + + // Callback that finds the op registration for a builtin operator by enum + // code. The `user_data` parameter will be set to the + // `op_resolver_user_data` value that was passed to + // `TfLiteInterpreterOptionsSetOpResolver`. + const TfLiteRegistration* (*find_builtin_op)(void* user_data, + TfLiteBuiltinOperator op, + int version); + // Callback that finds the op registration of a custom operator by op name. + // The `user_data` parameter will be set to the `op_resolver_user_data` value + // that was passed to `TfLiteInterpreterOptionsSetOpResolver`. + const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, + int version); +}; + // This struct mirrors the tflite::ErrorResolver C++ abstract base class. struct TfLiteErrorReporterCallback { // Opaque data that gets passed down to the callback function. @@ -52,7 +78,9 @@ struct TfLiteInterpreterOptions { }; int num_threads = kDefaultNumThreads; - tflite::MutableOpResolver op_resolver; + tflite::MutableOpResolver mutable_op_resolver; + + TfLiteOpResolverCallbacks op_resolver_callbacks = {}; std::vector delegates; @@ -79,9 +107,9 @@ namespace internal { // This adds the builtin and/or custom operators specified in options in // `optional_options` (if any) to `mutable_resolver`, and then returns a newly -// created TfLiteInterpreter using `mutable_op_resolver` as the OpResolver, and -// using any other options in `optional_options`, and using the provided -// `model`. +// created TfLiteInterpreter using `mutable_op_resolver` as the default +// OpResolver, and using any other options in `optional_options`, and using +// the provided `model`. // // * `model` must be a valid model instance. The caller retains ownership of the // object, and can destroy it immediately after creating the interpreter; the