Add TfLiteInterpreterOptionsSetOpResolver function to c_api_experimental.h.

This provides adds a new callback-based op registration C API.
This new function would be required in order to implement the TF Tasks Library, or any other C++ API that uses OpResolver parameters, using the TF Lite C API.

PiperOrigin-RevId: 333154110
Change-Id: I6e53edc91ff1271f41c355bb9ad1a48d32bf3e41
This commit is contained in:
Fergus Henderson 2020-09-22 13:59:01 -07:00 committed by TensorFlower Gardener
parent 604bb7ed52
commit 50acd58a69
7 changed files with 192 additions and 11 deletions

View File

@ -181,6 +181,12 @@ cc_library(
deps = ["//tensorflow/lite/c:common"],
)
cc_library(
name = "builtin_ops",
hdrs = ["builtin_ops.h"],
compatible_with = get_compatible_with_portable(),
)
exports_files(["builtin_ops.h"])
cc_library(

View File

@ -47,6 +47,7 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
":common",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
],
@ -60,6 +61,7 @@ cc_library(
deps = [
":c_api_internal",
":common",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite:version",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
@ -125,6 +127,7 @@ cc_library(
"common.h",
],
compatible_with = get_compatible_with_portable(),
deps = ["//tensorflow/lite:builtin_ops"],
alwayslink = 1,
)

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/error_reporter.h"
@ -43,6 +44,44 @@ class CallbackErrorReporter : public tflite::ErrorReporter {
private:
TfLiteErrorReporterCallback callback_;
};
/// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the
/// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks`
/// struct.
///
/// The SetCallbacks method must be called before calling any of the FindOp
/// methods.
class CallbackOpResolver : public ::tflite::OpResolver {
public:
CallbackOpResolver() {}
void SetCallbacks(
const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) {
op_resolver_callbacks_ = op_resolver_callbacks;
}
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override {
if (op_resolver_callbacks_.find_builtin_op == nullptr) {
return nullptr;
}
return op_resolver_callbacks_.find_builtin_op(
op_resolver_callbacks_.user_data,
static_cast<TfLiteBuiltinOperator>(op), version);
}
const TfLiteRegistration* FindOp(const char* op, int version) const override {
if (op_resolver_callbacks_.find_custom_op == nullptr) {
return nullptr;
}
return op_resolver_callbacks_.find_custom_op(
op_resolver_callbacks_.user_data, op, version);
}
private:
CallbackOpResolver(const CallbackOpResolver&) = delete;
CallbackOpResolver& operator=(const CallbackOpResolver&) = delete;
struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {};
};
} // namespace
// LINT.IfChange
@ -210,14 +249,28 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver(
new CallbackErrorReporter(optional_options->error_reporter_callback));
}
// By default, we use the provided mutable_op_resolver, adding any builtin or
// custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
// `TfLiteInterpreterOptionsAddCustomOp`.
tflite::OpResolver* op_resolver = mutable_resolver;
if (optional_options) {
mutable_resolver->AddAll(optional_options->op_resolver);
mutable_resolver->AddAll(optional_options->mutable_op_resolver);
}
// However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
// a non-null callback parameter, then we instead use a
// `CallbackOpResolver` that will forward to the callbacks provided there.
CallbackOpResolver callback_op_resolver;
if (optional_options &&
(optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
optional_options->op_resolver_callbacks.find_custom_op != nullptr)) {
callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
op_resolver = &callback_op_resolver;
}
tflite::ErrorReporter* error_reporter = optional_error_reporter
? optional_error_reporter.get()
: tflite::DefaultErrorReporter();
tflite::InterpreterBuilder builder(model->impl->GetModel(), *mutable_resolver,
tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
error_reporter);
std::unique_ptr<tflite::Interpreter> interpreter;

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#ifdef __cplusplus
extern "C" {
@ -38,8 +37,9 @@ void TfLiteInterpreterOptionsAddBuiltinOp(
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
const TfLiteRegistration* registration, int32_t min_version,
int32_t max_version) {
options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
registration, min_version, max_version);
options->mutable_op_resolver.AddBuiltin(
static_cast<tflite::BuiltinOperator>(op), registration, min_version,
max_version);
}
TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
@ -55,7 +55,21 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options,
const TfLiteRegistration* registration,
int32_t min_version,
int32_t max_version) {
options->op_resolver.AddCustom(name, registration, min_version, max_version);
options->mutable_op_resolver.AddCustom(name, registration, min_version,
max_version);
}
void TfLiteInterpreterOptionsSetOpResolver(
TfLiteInterpreterOptions* options,
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version),
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
int version),
void* op_resolver_user_data) {
options->op_resolver_callbacks.find_builtin_op = find_builtin_op;
options->op_resolver_callbacks.find_custom_op = find_custom_op;
options->op_resolver_callbacks.user_data = op_resolver_user_data;
}
void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,

View File

@ -36,6 +36,9 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors(
/// so the caller should ensure that its contents (function pointers, etc...)
/// remain valid for the duration of the interpreter's lifetime. A common
/// practice is making the provided `TfLiteRegistration` instance static.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
const TfLiteRegistration* registration, int32_t min_version,
@ -50,16 +53,41 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
/// so the caller should ensure that its contents (function pointers, etc...)
/// remain valid for the duration of any created interpreter's lifetime. A
/// common practice is making the provided `TfLiteRegistration` instance static.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
TfLiteInterpreterOptions* options, const char* name,
const TfLiteRegistration* registration, int32_t min_version,
int32_t max_version);
/// Registers callbacks for resolving builtin or custom operators.
///
/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative
/// method for registering builtin ops and/or custom ops, by providing operator
/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp`
/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all
/// the operators in a single call.
///
/// Code that uses this function should NOT call
/// `TfLiteInterpreterOptionsAddBuiltin' or
/// `TfLiteInterpreterOptionsAddCustomOp' on the same options object.
void TfLiteInterpreterOptionsSetOpResolver(
TfLiteInterpreterOptions* options,
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version),
const TfLiteRegistration* (*find_custom_op)(void* user_data,
const char* custom_op,
int version),
void* op_resolver_user_data);
/// Returns a new interpreter using the provided model and options, or null on
/// failure, where the model uses only the operators explicitly added to the
/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`,
/// except that the only operators that are supported are the ones registered
/// in `options` via calls to `TfLiteInterpreterOptionsAddBuiltinOp` and/or
/// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`,
/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or
/// `TfLiteInterpreterOptionsAddCustomOp`.
///
/// * `model` must be a valid model instance. The caller retains ownership of

View File

@ -110,6 +110,55 @@ TEST(CApiExperimentalTest, MissingBuiltin) {
TfLiteModelDelete(model);
}
struct OpResolverData {
bool called_for_add = false;
};
const TfLiteRegistration* MyFindBuiltinOp(void* user_data,
TfLiteBuiltinOperator op,
int version) {
OpResolverData* my_data = static_cast<OpResolverData*>(user_data);
if (op == kTfLiteBuiltinAdd && version == 1) {
my_data->called_for_add = true;
return GetDummyRegistration();
}
return nullptr;
}
const TfLiteRegistration* MyFindCustomOp(void*, const char* custom_op,
int version) {
if (absl::string_view(custom_op) == "foo" && version == 1) {
return GetDummyRegistration();
}
return nullptr;
}
// Test using TfLiteInterpreterCreateWithSelectedOps.
TEST(CApiExperimentalTest, SetOpResolver) {
TfLiteModel* model =
TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
ASSERT_NE(model, nullptr);
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
OpResolverData my_data;
TfLiteInterpreterOptionsSetOpResolver(options, MyFindBuiltinOp,
MyFindCustomOp, &my_data);
EXPECT_FALSE(my_data.called_for_add);
TfLiteInterpreter* interpreter =
TfLiteInterpreterCreateWithSelectedOps(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
EXPECT_TRUE(my_data.called_for_add);
TfLiteInterpreterDelete(interpreter);
TfLiteInterpreterOptionsDelete(options);
TfLiteModelDelete(model);
}
} // namespace
int main(int argc, char** argv) {

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <memory>
#include <vector>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
@ -36,6 +38,30 @@ struct TfLiteModel {
std::shared_ptr<const tflite::FlatBufferModel> impl;
};
// The `TfLiteOpResolver` struct is an abstract callback interface that
// contains function pointers for callbacks that return a
// `TfLiteRegistration` given an op code or custom op name. This mechanism is
// used to map ops referenced in the flatbuffer model to executable function
// pointers (`TfLiteRegistration`s).
// This struct mirrors the tflite::OpResolver C++ abstract base class.
struct TfLiteOpResolverCallbacks {
// Opaque data that gets passed down to the callback functions.
void* user_data = nullptr;
// Callback that finds the op registration for a builtin operator by enum
// code. The `user_data` parameter will be set to the
// `op_resolver_user_data` value that was passed to
// `TfLiteInterpreterOptionsSetOpResolver`.
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
TfLiteBuiltinOperator op,
int version);
// Callback that finds the op registration of a custom operator by op name.
// The `user_data` parameter will be set to the `op_resolver_user_data` value
// that was passed to `TfLiteInterpreterOptionsSetOpResolver`.
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
int version);
};
// This struct mirrors the tflite::ErrorResolver C++ abstract base class.
struct TfLiteErrorReporterCallback {
// Opaque data that gets passed down to the callback function.
@ -52,7 +78,9 @@ struct TfLiteInterpreterOptions {
};
int num_threads = kDefaultNumThreads;
tflite::MutableOpResolver op_resolver;
tflite::MutableOpResolver mutable_op_resolver;
TfLiteOpResolverCallbacks op_resolver_callbacks = {};
std::vector<TfLiteDelegate*> delegates;
@ -79,9 +107,9 @@ namespace internal {
// This adds the builtin and/or custom operators specified in options in
// `optional_options` (if any) to `mutable_resolver`, and then returns a newly
// created TfLiteInterpreter using `mutable_op_resolver` as the OpResolver, and
// using any other options in `optional_options`, and using the provided
// `model`.
// created TfLiteInterpreter using `mutable_op_resolver` as the default
// OpResolver, and using any other options in `optional_options`, and using
// the provided `model`.
//
// * `model` must be a valid model instance. The caller retains ownership of the
// object, and can destroy it immediately after creating the interpreter; the