Add TfLiteInterpreterOptionsSetOpResolver function to c_api_experimental.h.

This provides adds a new callback-based op registration C API.
This new function would be required in order to implement the TF Tasks Library, or any other C++ API that uses OpResolver parameters, using the TF Lite C API.

PiperOrigin-RevId: 333154110
Change-Id: I6e53edc91ff1271f41c355bb9ad1a48d32bf3e41
This commit is contained in:
Fergus Henderson 2020-09-22 13:59:01 -07:00 committed by TensorFlower Gardener
parent 604bb7ed52
commit 50acd58a69
7 changed files with 192 additions and 11 deletions

View File

@ -181,6 +181,12 @@ cc_library(
deps = ["//tensorflow/lite/c:common"], deps = ["//tensorflow/lite/c:common"],
) )
cc_library(
name = "builtin_ops",
hdrs = ["builtin_ops.h"],
compatible_with = get_compatible_with_portable(),
)
exports_files(["builtin_ops.h"]) exports_files(["builtin_ops.h"])
cc_library( cc_library(

View File

@ -47,6 +47,7 @@ cc_library(
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
":common", ":common",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/core/api", "//tensorflow/lite/core/api",
], ],
@ -60,6 +61,7 @@ cc_library(
deps = [ deps = [
":c_api_internal", ":c_api_internal",
":common", ":common",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite:version", "//tensorflow/lite:version",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
@ -125,6 +127,7 @@ cc_library(
"common.h", "common.h",
], ],
compatible_with = get_compatible_with_portable(), compatible_with = get_compatible_with_portable(),
deps = ["//tensorflow/lite:builtin_ops"],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
@ -43,6 +44,44 @@ class CallbackErrorReporter : public tflite::ErrorReporter {
private: private:
TfLiteErrorReporterCallback callback_; TfLiteErrorReporterCallback callback_;
}; };
/// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the
/// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks`
/// struct.
///
/// The SetCallbacks method must be called before calling any of the FindOp
/// methods.
class CallbackOpResolver : public ::tflite::OpResolver {
public:
CallbackOpResolver() {}
void SetCallbacks(
const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) {
op_resolver_callbacks_ = op_resolver_callbacks;
}
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override {
if (op_resolver_callbacks_.find_builtin_op == nullptr) {
return nullptr;
}
return op_resolver_callbacks_.find_builtin_op(
op_resolver_callbacks_.user_data,
static_cast<TfLiteBuiltinOperator>(op), version);
}
const TfLiteRegistration* FindOp(const char* op, int version) const override {
if (op_resolver_callbacks_.find_custom_op == nullptr) {
return nullptr;
}
return op_resolver_callbacks_.find_custom_op(
op_resolver_callbacks_.user_data, op, version);
}
private:
CallbackOpResolver(const CallbackOpResolver&) = delete;
CallbackOpResolver& operator=(const CallbackOpResolver&) = delete;
struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {};
};
} // namespace } // namespace
// LINT.IfChange // LINT.IfChange
@ -210,14 +249,28 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver(
new CallbackErrorReporter(optional_options->error_reporter_callback)); new CallbackErrorReporter(optional_options->error_reporter_callback));
} }
// By default, we use the provided mutable_op_resolver, adding any builtin or
// custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
// `TfLiteInterpreterOptionsAddCustomOp`.
tflite::OpResolver* op_resolver = mutable_resolver;
if (optional_options) { if (optional_options) {
mutable_resolver->AddAll(optional_options->op_resolver); mutable_resolver->AddAll(optional_options->mutable_op_resolver);
}
// However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
// a non-null callback parameter, then we instead use a
// `CallbackOpResolver` that will forward to the callbacks provided there.
CallbackOpResolver callback_op_resolver;
if (optional_options &&
(optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
optional_options->op_resolver_callbacks.find_custom_op != nullptr)) {
callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
op_resolver = &callback_op_resolver;
} }
tflite::ErrorReporter* error_reporter = optional_error_reporter tflite::ErrorReporter* error_reporter = optional_error_reporter
? optional_error_reporter.get() ? optional_error_reporter.get()
: tflite::DefaultErrorReporter(); : tflite::DefaultErrorReporter();
tflite::InterpreterBuilder builder(model->impl->GetModel(), *mutable_resolver, tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
error_reporter); error_reporter);
std::unique_ptr<tflite::Interpreter> interpreter; std::unique_ptr<tflite::Interpreter> interpreter;

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -38,8 +37,9 @@ void TfLiteInterpreterOptionsAddBuiltinOp(
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
const TfLiteRegistration* registration, int32_t min_version, const TfLiteRegistration* registration, int32_t min_version,
int32_t max_version) { int32_t max_version) {
options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op), options->mutable_op_resolver.AddBuiltin(
registration, min_version, max_version); static_cast<tflite::BuiltinOperator>(op), registration, min_version,
max_version);
} }
TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps( TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
@ -55,7 +55,21 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options,
const TfLiteRegistration* registration, const TfLiteRegistration* registration,
int32_t min_version, int32_t min_version,
int32_t max_version) { int32_t max_version) {
options->op_resolver.AddCustom(name, registration, min_version, max_version); options->mutable_op_resolver.AddCustom(name, registration, min_version,
max_version);
}
void TfLiteInterpreterOptionsSetOpResolver(
TfLiteInterpreterOptions* options,
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version),
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
int version),
void* op_resolver_user_data) {
options->op_resolver_callbacks.find_builtin_op = find_builtin_op;
options->op_resolver_callbacks.find_custom_op = find_custom_op;
options->op_resolver_callbacks.user_data = op_resolver_user_data;
} }
void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options, void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,

View File

@ -36,6 +36,9 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors(
/// so the caller should ensure that its contents (function pointers, etc...) /// so the caller should ensure that its contents (function pointers, etc...)
/// remain valid for the duration of the interpreter's lifetime. A common /// remain valid for the duration of the interpreter's lifetime. A common
/// practice is making the provided `TfLiteRegistration` instance static. /// practice is making the provided `TfLiteRegistration` instance static.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
const TfLiteRegistration* registration, int32_t min_version, const TfLiteRegistration* registration, int32_t min_version,
@ -50,16 +53,41 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
/// so the caller should ensure that its contents (function pointers, etc...) /// so the caller should ensure that its contents (function pointers, etc...)
/// remain valid for the duration of any created interpreter's lifetime. A /// remain valid for the duration of any created interpreter's lifetime. A
/// common practice is making the provided `TfLiteRegistration` instance static. /// common practice is making the provided `TfLiteRegistration` instance static.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
TfLiteInterpreterOptions* options, const char* name, TfLiteInterpreterOptions* options, const char* name,
const TfLiteRegistration* registration, int32_t min_version, const TfLiteRegistration* registration, int32_t min_version,
int32_t max_version); int32_t max_version);
/// Registers callbacks for resolving builtin or custom operators.
///
/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative
/// method for registering builtin ops and/or custom ops, by providing operator
/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp`
/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all
/// the operators in a single call.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsAddBuiltin' or
/// `TfLiteInterpreterOptionsAddCustomOp' on the same options object.
void TfLiteInterpreterOptionsSetOpResolver(
TfLiteInterpreterOptions* options,
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version),
const TfLiteRegistration* (*find_custom_op)(void* user_data,
const char* custom_op,
int version),
void* op_resolver_user_data);
/// Returns a new interpreter using the provided model and options, or null on /// Returns a new interpreter using the provided model and options, or null on
/// failure, where the model uses only the operators explicitly added to the /// failure, where the model uses only the operators explicitly added to the
/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`, /// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`,
/// except that the only operators that are supported are the ones registered /// except that the only operators that are supported are the ones registered
/// in `options` via calls to `TfLiteInterpreterOptionsAddBuiltinOp` and/or /// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`,
/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or
/// `TfLiteInterpreterOptionsAddCustomOp`. /// `TfLiteInterpreterOptionsAddCustomOp`.
/// ///
/// * `model` must be a valid model instance. The caller retains ownership of /// * `model` must be a valid model instance. The caller retains ownership of

View File

@ -110,6 +110,55 @@ TEST(CApiExperimentalTest, MissingBuiltin) {
TfLiteModelDelete(model); TfLiteModelDelete(model);
} }
struct OpResolverData {
bool called_for_add = false;
};
const TfLiteRegistration* MyFindBuiltinOp(void* user_data,
TfLiteBuiltinOperator op,
int version) {
OpResolverData* my_data = static_cast<OpResolverData*>(user_data);
if (op == kTfLiteBuiltinAdd && version == 1) {
my_data->called_for_add = true;
return GetDummyRegistration();
}
return nullptr;
}
const TfLiteRegistration* MyFindCustomOp(void*, const char* custom_op,
int version) {
if (absl::string_view(custom_op) == "foo" && version == 1) {
return GetDummyRegistration();
}
return nullptr;
}
// Test using TfLiteInterpreterCreateWithSelectedOps.
TEST(CApiExperimentalTest, SetOpResolver) {
TfLiteModel* model =
TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
ASSERT_NE(model, nullptr);
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
OpResolverData my_data;
TfLiteInterpreterOptionsSetOpResolver(options, MyFindBuiltinOp,
MyFindCustomOp, &my_data);
EXPECT_FALSE(my_data.called_for_add);
TfLiteInterpreter* interpreter =
TfLiteInterpreterCreateWithSelectedOps(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
EXPECT_TRUE(my_data.called_for_add);
TfLiteInterpreterDelete(interpreter);
TfLiteInterpreterOptionsDelete(options);
TfLiteModelDelete(model);
}
} // namespace } // namespace
int main(int argc, char** argv) { int main(int argc, char** argv) {

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/mutable_op_resolver.h"
@ -36,6 +38,30 @@ struct TfLiteModel {
std::shared_ptr<const tflite::FlatBufferModel> impl; std::shared_ptr<const tflite::FlatBufferModel> impl;
}; };
// The `TfLiteOpResolver` struct is an abstract callback interface that
// contains function pointers for callbacks that return a
// `TfLiteRegistration` given an op code or custom op name. This mechanism is
// used to map ops referenced in the flatbuffer model to executable function
// pointers (`TfLiteRegistration`s).
// This struct mirrors the tflite::OpResolver C++ abstract base class.
struct TfLiteOpResolverCallbacks {
// Opaque data that gets passed down to the callback functions.
void* user_data = nullptr;
// Callback that finds the op registration for a builtin operator by enum
// code. The `user_data` parameter will be set to the
// `op_resolver_user_data` value that was passed to
// `TfLiteInterpreterOptionsSetOpResolver`.
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version);
// Callback that finds the op registration of a custom operator by op name.
// The `user_data` parameter will be set to the `op_resolver_user_data` value
// that was passed to `TfLiteInterpreterOptionsSetOpResolver`.
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
int version);
};
// This struct mirrors the tflite::ErrorResolver C++ abstract base class. // This struct mirrors the tflite::ErrorResolver C++ abstract base class.
struct TfLiteErrorReporterCallback { struct TfLiteErrorReporterCallback {
// Opaque data that gets passed down to the callback function. // Opaque data that gets passed down to the callback function.
@ -52,7 +78,9 @@ struct TfLiteInterpreterOptions {
}; };
int num_threads = kDefaultNumThreads; int num_threads = kDefaultNumThreads;
tflite::MutableOpResolver op_resolver; tflite::MutableOpResolver mutable_op_resolver;
TfLiteOpResolverCallbacks op_resolver_callbacks = {};
std::vector<TfLiteDelegate*> delegates; std::vector<TfLiteDelegate*> delegates;
@ -79,9 +107,9 @@ namespace internal {
// This adds the builtin and/or custom operators specified in options in // This adds the builtin and/or custom operators specified in options in
// `optional_options` (if any) to `mutable_resolver`, and then returns a newly // `optional_options` (if any) to `mutable_resolver`, and then returns a newly
// created TfLiteInterpreter using `mutable_op_resolver` as the OpResolver, and // created TfLiteInterpreter using `mutable_op_resolver` as the default
// using any other options in `optional_options`, and using the provided // OpResolver, and using any other options in `optional_options`, and using
// `model`. // the provided `model`.
// //
// * `model` must be a valid model instance. The caller retains ownership of the // * `model` must be a valid model instance. The caller retains ownership of the
// object, and can destroy it immediately after creating the interpreter; the // object, and can destroy it immediately after creating the interpreter; the