Add TfLiteInterpreterOptionsSetOpResolver function to c_api_experimental.h.
This provides adds a new callback-based op registration C API. This new function would be required in order to implement the TF Tasks Library, or any other C++ API that uses OpResolver parameters, using the TF Lite C API. PiperOrigin-RevId: 333154110 Change-Id: I6e53edc91ff1271f41c355bb9ad1a48d32bf3e41
This commit is contained in:
parent
604bb7ed52
commit
50acd58a69
@ -181,6 +181,12 @@ cc_library(
|
|||||||
deps = ["//tensorflow/lite/c:common"],
|
deps = ["//tensorflow/lite/c:common"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "builtin_ops",
|
||||||
|
hdrs = ["builtin_ops.h"],
|
||||||
|
compatible_with = get_compatible_with_portable(),
|
||||||
|
)
|
||||||
|
|
||||||
exports_files(["builtin_ops.h"])
|
exports_files(["builtin_ops.h"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -47,6 +47,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
|
"//tensorflow/lite:builtin_ops",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite/core/api",
|
"//tensorflow/lite/core/api",
|
||||||
],
|
],
|
||||||
@ -60,6 +61,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":c_api_internal",
|
":c_api_internal",
|
||||||
":common",
|
":common",
|
||||||
|
"//tensorflow/lite:builtin_ops",
|
||||||
"//tensorflow/lite:framework",
|
"//tensorflow/lite:framework",
|
||||||
"//tensorflow/lite:version",
|
"//tensorflow/lite:version",
|
||||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||||
@ -125,6 +127,7 @@ cc_library(
|
|||||||
"common.h",
|
"common.h",
|
||||||
],
|
],
|
||||||
compatible_with = get_compatible_with_portable(),
|
compatible_with = get_compatible_with_portable(),
|
||||||
|
deps = ["//tensorflow/lite:builtin_ops"],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#include "tensorflow/lite/error_reporter.h"
|
#include "tensorflow/lite/error_reporter.h"
|
||||||
@ -43,6 +44,44 @@ class CallbackErrorReporter : public tflite::ErrorReporter {
|
|||||||
private:
|
private:
|
||||||
TfLiteErrorReporterCallback callback_;
|
TfLiteErrorReporterCallback callback_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the
|
||||||
|
/// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks`
|
||||||
|
/// struct.
|
||||||
|
///
|
||||||
|
/// The SetCallbacks method must be called before calling any of the FindOp
|
||||||
|
/// methods.
|
||||||
|
class CallbackOpResolver : public ::tflite::OpResolver {
|
||||||
|
public:
|
||||||
|
CallbackOpResolver() {}
|
||||||
|
void SetCallbacks(
|
||||||
|
const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) {
|
||||||
|
op_resolver_callbacks_ = op_resolver_callbacks;
|
||||||
|
}
|
||||||
|
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
|
||||||
|
int version) const override {
|
||||||
|
if (op_resolver_callbacks_.find_builtin_op == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op_resolver_callbacks_.find_builtin_op(
|
||||||
|
op_resolver_callbacks_.user_data,
|
||||||
|
static_cast<TfLiteBuiltinOperator>(op), version);
|
||||||
|
}
|
||||||
|
const TfLiteRegistration* FindOp(const char* op, int version) const override {
|
||||||
|
if (op_resolver_callbacks_.find_custom_op == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op_resolver_callbacks_.find_custom_op(
|
||||||
|
op_resolver_callbacks_.user_data, op, version);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CallbackOpResolver(const CallbackOpResolver&) = delete;
|
||||||
|
CallbackOpResolver& operator=(const CallbackOpResolver&) = delete;
|
||||||
|
|
||||||
|
struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {};
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// LINT.IfChange
|
// LINT.IfChange
|
||||||
@ -210,14 +249,28 @@ TfLiteInterpreter* InterpreterCreateWithOpResolver(
|
|||||||
new CallbackErrorReporter(optional_options->error_reporter_callback));
|
new CallbackErrorReporter(optional_options->error_reporter_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// By default, we use the provided mutable_op_resolver, adding any builtin or
|
||||||
|
// custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
|
||||||
|
// `TfLiteInterpreterOptionsAddCustomOp`.
|
||||||
|
tflite::OpResolver* op_resolver = mutable_resolver;
|
||||||
if (optional_options) {
|
if (optional_options) {
|
||||||
mutable_resolver->AddAll(optional_options->op_resolver);
|
mutable_resolver->AddAll(optional_options->mutable_op_resolver);
|
||||||
|
}
|
||||||
|
// However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
|
||||||
|
// a non-null callback parameter, then we instead use a
|
||||||
|
// `CallbackOpResolver` that will forward to the callbacks provided there.
|
||||||
|
CallbackOpResolver callback_op_resolver;
|
||||||
|
if (optional_options &&
|
||||||
|
(optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
|
||||||
|
optional_options->op_resolver_callbacks.find_custom_op != nullptr)) {
|
||||||
|
callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
|
||||||
|
op_resolver = &callback_op_resolver;
|
||||||
}
|
}
|
||||||
|
|
||||||
tflite::ErrorReporter* error_reporter = optional_error_reporter
|
tflite::ErrorReporter* error_reporter = optional_error_reporter
|
||||||
? optional_error_reporter.get()
|
? optional_error_reporter.get()
|
||||||
: tflite::DefaultErrorReporter();
|
: tflite::DefaultErrorReporter();
|
||||||
tflite::InterpreterBuilder builder(model->impl->GetModel(), *mutable_resolver,
|
tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
|
||||||
error_reporter);
|
error_reporter);
|
||||||
|
|
||||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/c_api.h"
|
#include "tensorflow/lite/c/c_api.h"
|
||||||
#include "tensorflow/lite/c/c_api_internal.h"
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -38,8 +37,9 @@ void TfLiteInterpreterOptionsAddBuiltinOp(
|
|||||||
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
|
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
|
||||||
const TfLiteRegistration* registration, int32_t min_version,
|
const TfLiteRegistration* registration, int32_t min_version,
|
||||||
int32_t max_version) {
|
int32_t max_version) {
|
||||||
options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
|
options->mutable_op_resolver.AddBuiltin(
|
||||||
registration, min_version, max_version);
|
static_cast<tflite::BuiltinOperator>(op), registration, min_version,
|
||||||
|
max_version);
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
|
TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
|
||||||
@ -55,7 +55,21 @@ void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options,
|
|||||||
const TfLiteRegistration* registration,
|
const TfLiteRegistration* registration,
|
||||||
int32_t min_version,
|
int32_t min_version,
|
||||||
int32_t max_version) {
|
int32_t max_version) {
|
||||||
options->op_resolver.AddCustom(name, registration, min_version, max_version);
|
options->mutable_op_resolver.AddCustom(name, registration, min_version,
|
||||||
|
max_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TfLiteInterpreterOptionsSetOpResolver(
|
||||||
|
TfLiteInterpreterOptions* options,
|
||||||
|
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
|
||||||
|
TfLiteBuiltinOperator op,
|
||||||
|
int version),
|
||||||
|
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
|
||||||
|
int version),
|
||||||
|
void* op_resolver_user_data) {
|
||||||
|
options->op_resolver_callbacks.find_builtin_op = find_builtin_op;
|
||||||
|
options->op_resolver_callbacks.find_custom_op = find_custom_op;
|
||||||
|
options->op_resolver_callbacks.user_data = op_resolver_user_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
|
void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
|
||||||
|
@ -36,6 +36,9 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors(
|
|||||||
/// so the caller should ensure that its contents (function pointers, etc...)
|
/// so the caller should ensure that its contents (function pointers, etc...)
|
||||||
/// remain valid for the duration of the interpreter's lifetime. A common
|
/// remain valid for the duration of the interpreter's lifetime. A common
|
||||||
/// practice is making the provided `TfLiteRegistration` instance static.
|
/// practice is making the provided `TfLiteRegistration` instance static.
|
||||||
|
///
|
||||||
|
/// Code that uses this function should NOT call
|
||||||
|
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
|
||||||
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
|
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
|
||||||
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
|
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
|
||||||
const TfLiteRegistration* registration, int32_t min_version,
|
const TfLiteRegistration* registration, int32_t min_version,
|
||||||
@ -50,16 +53,41 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
|
|||||||
/// so the caller should ensure that its contents (function pointers, etc...)
|
/// so the caller should ensure that its contents (function pointers, etc...)
|
||||||
/// remain valid for the duration of any created interpreter's lifetime. A
|
/// remain valid for the duration of any created interpreter's lifetime. A
|
||||||
/// common practice is making the provided `TfLiteRegistration` instance static.
|
/// common practice is making the provided `TfLiteRegistration` instance static.
|
||||||
|
///
|
||||||
|
/// Code that uses this function should NOT call
|
||||||
|
/// `TfLiteInterpreterOptionsSetOpResolver' on the same options object.
|
||||||
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
|
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
|
||||||
TfLiteInterpreterOptions* options, const char* name,
|
TfLiteInterpreterOptions* options, const char* name,
|
||||||
const TfLiteRegistration* registration, int32_t min_version,
|
const TfLiteRegistration* registration, int32_t min_version,
|
||||||
int32_t max_version);
|
int32_t max_version);
|
||||||
|
|
||||||
|
/// Registers callbacks for resolving builtin or custom operators.
|
||||||
|
///
|
||||||
|
/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative
|
||||||
|
/// method for registering builtin ops and/or custom ops, by providing operator
|
||||||
|
/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp`
|
||||||
|
/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all
|
||||||
|
/// the operators in a single call.
|
||||||
|
///
|
||||||
|
/// Code that uses this function should NOT call
|
||||||
|
/// `TfLiteInterpreterOptionsAddBuiltin' or
|
||||||
|
/// `TfLiteInterpreterOptionsAddCustomOp' on the same options object.
|
||||||
|
void TfLiteInterpreterOptionsSetOpResolver(
|
||||||
|
TfLiteInterpreterOptions* options,
|
||||||
|
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
|
||||||
|
TfLiteBuiltinOperator op,
|
||||||
|
int version),
|
||||||
|
const TfLiteRegistration* (*find_custom_op)(void* user_data,
|
||||||
|
const char* custom_op,
|
||||||
|
int version),
|
||||||
|
void* op_resolver_user_data);
|
||||||
|
|
||||||
/// Returns a new interpreter using the provided model and options, or null on
|
/// Returns a new interpreter using the provided model and options, or null on
|
||||||
/// failure, where the model uses only the operators explicitly added to the
|
/// failure, where the model uses only the operators explicitly added to the
|
||||||
/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`,
|
/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`,
|
||||||
/// except that the only operators that are supported are the ones registered
|
/// except that the only operators that are supported are the ones registered
|
||||||
/// in `options` via calls to `TfLiteInterpreterOptionsAddBuiltinOp` and/or
|
/// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`,
|
||||||
|
/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or
|
||||||
/// `TfLiteInterpreterOptionsAddCustomOp`.
|
/// `TfLiteInterpreterOptionsAddCustomOp`.
|
||||||
///
|
///
|
||||||
/// * `model` must be a valid model instance. The caller retains ownership of
|
/// * `model` must be a valid model instance. The caller retains ownership of
|
||||||
|
@ -110,6 +110,55 @@ TEST(CApiExperimentalTest, MissingBuiltin) {
|
|||||||
TfLiteModelDelete(model);
|
TfLiteModelDelete(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct OpResolverData {
|
||||||
|
bool called_for_add = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
const TfLiteRegistration* MyFindBuiltinOp(void* user_data,
|
||||||
|
TfLiteBuiltinOperator op,
|
||||||
|
int version) {
|
||||||
|
OpResolverData* my_data = static_cast<OpResolverData*>(user_data);
|
||||||
|
if (op == kTfLiteBuiltinAdd && version == 1) {
|
||||||
|
my_data->called_for_add = true;
|
||||||
|
return GetDummyRegistration();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteRegistration* MyFindCustomOp(void*, const char* custom_op,
|
||||||
|
int version) {
|
||||||
|
if (absl::string_view(custom_op) == "foo" && version == 1) {
|
||||||
|
return GetDummyRegistration();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test using TfLiteInterpreterCreateWithSelectedOps.
|
||||||
|
TEST(CApiExperimentalTest, SetOpResolver) {
|
||||||
|
TfLiteModel* model =
|
||||||
|
TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
|
||||||
|
ASSERT_NE(model, nullptr);
|
||||||
|
|
||||||
|
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
|
||||||
|
|
||||||
|
OpResolverData my_data;
|
||||||
|
TfLiteInterpreterOptionsSetOpResolver(options, MyFindBuiltinOp,
|
||||||
|
MyFindCustomOp, &my_data);
|
||||||
|
EXPECT_FALSE(my_data.called_for_add);
|
||||||
|
|
||||||
|
TfLiteInterpreter* interpreter =
|
||||||
|
TfLiteInterpreterCreateWithSelectedOps(model, options);
|
||||||
|
ASSERT_NE(interpreter, nullptr);
|
||||||
|
ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
|
||||||
|
EXPECT_EQ(TfLiteInterpreterResetVariableTensors(interpreter), kTfLiteOk);
|
||||||
|
EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
|
||||||
|
EXPECT_TRUE(my_data.called_for_add);
|
||||||
|
|
||||||
|
TfLiteInterpreterDelete(interpreter);
|
||||||
|
TfLiteInterpreterOptionsDelete(options);
|
||||||
|
TfLiteModelDelete(model);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/builtin_ops.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
@ -36,6 +38,30 @@ struct TfLiteModel {
|
|||||||
std::shared_ptr<const tflite::FlatBufferModel> impl;
|
std::shared_ptr<const tflite::FlatBufferModel> impl;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The `TfLiteOpResolver` struct is an abstract callback interface that
|
||||||
|
// contains function pointers for callbacks that return a
|
||||||
|
// `TfLiteRegistration` given an op code or custom op name. This mechanism is
|
||||||
|
// used to map ops referenced in the flatbuffer model to executable function
|
||||||
|
// pointers (`TfLiteRegistration`s).
|
||||||
|
// This struct mirrors the tflite::OpResolver C++ abstract base class.
|
||||||
|
struct TfLiteOpResolverCallbacks {
|
||||||
|
// Opaque data that gets passed down to the callback functions.
|
||||||
|
void* user_data = nullptr;
|
||||||
|
|
||||||
|
// Callback that finds the op registration for a builtin operator by enum
|
||||||
|
// code. The `user_data` parameter will be set to the
|
||||||
|
// `op_resolver_user_data` value that was passed to
|
||||||
|
// `TfLiteInterpreterOptionsSetOpResolver`.
|
||||||
|
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
|
||||||
|
TfLiteBuiltinOperator op,
|
||||||
|
int version);
|
||||||
|
// Callback that finds the op registration of a custom operator by op name.
|
||||||
|
// The `user_data` parameter will be set to the `op_resolver_user_data` value
|
||||||
|
// that was passed to `TfLiteInterpreterOptionsSetOpResolver`.
|
||||||
|
const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
|
||||||
|
int version);
|
||||||
|
};
|
||||||
|
|
||||||
// This struct mirrors the tflite::ErrorResolver C++ abstract base class.
|
// This struct mirrors the tflite::ErrorResolver C++ abstract base class.
|
||||||
struct TfLiteErrorReporterCallback {
|
struct TfLiteErrorReporterCallback {
|
||||||
// Opaque data that gets passed down to the callback function.
|
// Opaque data that gets passed down to the callback function.
|
||||||
@ -52,7 +78,9 @@ struct TfLiteInterpreterOptions {
|
|||||||
};
|
};
|
||||||
int num_threads = kDefaultNumThreads;
|
int num_threads = kDefaultNumThreads;
|
||||||
|
|
||||||
tflite::MutableOpResolver op_resolver;
|
tflite::MutableOpResolver mutable_op_resolver;
|
||||||
|
|
||||||
|
TfLiteOpResolverCallbacks op_resolver_callbacks = {};
|
||||||
|
|
||||||
std::vector<TfLiteDelegate*> delegates;
|
std::vector<TfLiteDelegate*> delegates;
|
||||||
|
|
||||||
@ -79,9 +107,9 @@ namespace internal {
|
|||||||
|
|
||||||
// This adds the builtin and/or custom operators specified in options in
|
// This adds the builtin and/or custom operators specified in options in
|
||||||
// `optional_options` (if any) to `mutable_resolver`, and then returns a newly
|
// `optional_options` (if any) to `mutable_resolver`, and then returns a newly
|
||||||
// created TfLiteInterpreter using `mutable_op_resolver` as the OpResolver, and
|
// created TfLiteInterpreter using `mutable_op_resolver` as the default
|
||||||
// using any other options in `optional_options`, and using the provided
|
// OpResolver, and using any other options in `optional_options`, and using
|
||||||
// `model`.
|
// the provided `model`.
|
||||||
//
|
//
|
||||||
// * `model` must be a valid model instance. The caller retains ownership of the
|
// * `model` must be a valid model instance. The caller retains ownership of the
|
||||||
// object, and can destroy it immediately after creating the interpreter; the
|
// object, and can destroy it immediately after creating the interpreter; the
|
||||||
|
Loading…
Reference in New Issue
Block a user