parent
553a3b826a
commit
752556174a
@ -33,6 +33,7 @@ py_test(
|
||||
],
|
||||
deps = [
|
||||
":interpreter",
|
||||
"//tensorflow/lite/python/testdata:test_registerer_wrapper",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
|
@ -200,10 +200,12 @@ class Interpreter(object):
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
"""
|
||||
if not hasattr(self, '_custom_op_registerers'):
|
||||
self._custom_op_registerers = []
|
||||
if model_path and not model_content:
|
||||
self._interpreter = (
|
||||
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
|
||||
model_path))
|
||||
model_path, self._custom_op_registerers))
|
||||
if not self._interpreter:
|
||||
raise ValueError('Failed to open {}'.format(model_path))
|
||||
elif model_content and not model_path:
|
||||
@ -213,7 +215,7 @@ class Interpreter(object):
|
||||
self._model_content = model_content
|
||||
self._interpreter = (
|
||||
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
|
||||
model_content))
|
||||
model_content, self._custom_op_registerers))
|
||||
elif not model_path and not model_path:
|
||||
raise ValueError('`model_path` or `model_content` must be specified.')
|
||||
else:
|
||||
@ -454,3 +456,40 @@ class Interpreter(object):
|
||||
|
||||
def reset_all_variables(self):
|
||||
return self._interpreter.ResetVariableTensors()
|
||||
|
||||
|
||||
class InterpreterWithCustomOps(Interpreter):
|
||||
"""Interpreter interface for TensorFlow Lite Models that accepts custom ops.
|
||||
|
||||
The interface provided by this class is experimenal and therefore not exposed
|
||||
as part of the public API.
|
||||
|
||||
Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
|
||||
by providing the names of functions that take a pointer to a BuiltinOpResolver
|
||||
and add a custom op.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_path=None,
|
||||
model_content=None,
|
||||
experimental_delegates=None,
|
||||
custom_op_registerers=None):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
model_path: Path to TF-Lite Flatbuffer file.
|
||||
model_content: Content of model.
|
||||
experimental_delegates: Experimental. Subject to change. List of
|
||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
||||
objects returned by lite.load_delegate().
|
||||
custom_op_registerers: List of str, symbol names of functions that take a
|
||||
pointer to a MutableOpResolver and register a custom op.
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
"""
|
||||
self._custom_op_registerers = custom_op_registerers
|
||||
super(InterpreterWithCustomOps, self).__init__(
|
||||
model_path=model_path,
|
||||
model_content=model_content,
|
||||
experimental_delegates=experimental_delegates)
|
||||
|
@ -23,10 +23,39 @@ import sys
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
# Force loaded shared object symbols to be globally visible. This is needed so
|
||||
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
|
||||
# in a different .so file. Note that this may already be set by default.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
|
||||
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
||||
|
||||
from tensorflow.lite.python import interpreter as interpreter_wrapper
|
||||
from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testRegisterer(self):
|
||||
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'),
|
||||
custom_op_registerers=['TF_TestRegisterer'])
|
||||
self.assertTrue(interpreter._safe_to_run())
|
||||
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
|
||||
|
||||
def testRegistererFailure(self):
|
||||
bogus_name = 'CompletelyBogusRegistererName'
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
|
||||
interpreter_wrapper.InterpreterWithCustomOps(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'),
|
||||
custom_op_registerers=[bogus_name])
|
||||
|
||||
|
||||
class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
|
@ -28,9 +28,11 @@ cc_library(
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/python_runtime:headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
)
|
||||
|
||||
@ -60,6 +62,7 @@ tf_py_wrap_cc(
|
||||
srcs = [
|
||||
"interpreter_wrapper.i",
|
||||
],
|
||||
copts = ["-fexceptions"],
|
||||
deps = [
|
||||
":interpreter_wrapper_lib",
|
||||
"//third_party/python_runtime:headers",
|
||||
|
@ -14,11 +14,22 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||
|
||||
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
|
||||
#if defined(_WIN32)
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#endif // defined(_WIN32)
|
||||
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
@ -82,18 +93,60 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool RegisterCustomOpByName(const char* registerer_name,
|
||||
tflite::MutableOpResolver* resolver,
|
||||
std::string* error_msg) {
|
||||
// Registerer functions take a pointer to a BuiltinOpResolver as an input
|
||||
// parameter and return void.
|
||||
// TODO(b/137576229): We should implement this functionality in a more
|
||||
// principled way.
|
||||
typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
|
||||
|
||||
// Look for the Registerer function by name.
|
||||
RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
|
||||
// We don't have dlsym on Windows, use GetProcAddress instead.
|
||||
#if defined(_WIN32)
|
||||
GetProcAddress(nullptr, registerer_name)
|
||||
#else
|
||||
dlsym(RTLD_DEFAULT, registerer_name)
|
||||
#endif // defined(_WIN32)
|
||||
);
|
||||
|
||||
// Fail in an informative way if the function was not found.
|
||||
if (registerer == nullptr) {
|
||||
// We don't have dlerror on Windows, use GetLastError instead.
|
||||
*error_msg =
|
||||
#if defined(_WIN32)
|
||||
absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).",
|
||||
registerer_name, GetLastError());
|
||||
#else
|
||||
absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
|
||||
registerer_name, dlerror());
|
||||
#endif // defined(_WIN32)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Call the registerer with the resolver.
|
||||
registerer(resolver);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
|
||||
std::unique_ptr<tflite::FlatBufferModel> model,
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter,
|
||||
std::string* error_msg) {
|
||||
const std::vector<std::string>& registerers, std::string* error_msg) {
|
||||
if (!model) {
|
||||
*error_msg = error_reporter->message();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
for (const auto registerer : registerers) {
|
||||
if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
|
||||
return nullptr;
|
||||
}
|
||||
auto interpreter = CreateInterpreter(model.get(), *resolver);
|
||||
if (!interpreter) {
|
||||
*error_msg = error_reporter->message();
|
||||
@ -417,16 +470,18 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
|
||||
const char* model_path, std::string* error_msg) {
|
||||
const char* model_path, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg) {
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||
std::unique_ptr<tflite::FlatBufferModel> model =
|
||||
tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
|
||||
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
|
||||
error_msg);
|
||||
registerers, error_msg);
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, std::string* error_msg) {
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg) {
|
||||
char * buf = nullptr;
|
||||
Py_ssize_t length;
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||
@ -438,7 +493,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
tflite::FlatBufferModel::BuildFromBuffer(buf, length,
|
||||
error_reporter.get());
|
||||
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
|
||||
error_msg);
|
||||
registerers, error_msg);
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ResetVariableTensors() {
|
||||
|
@ -46,12 +46,14 @@ class PythonErrorReporter;
|
||||
class InterpreterWrapper {
|
||||
public:
|
||||
// SWIG caller takes ownership of pointer.
|
||||
static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path,
|
||||
std::string* error_msg);
|
||||
static InterpreterWrapper* CreateWrapperCPPFromFile(
|
||||
const char* model_path, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg);
|
||||
|
||||
// SWIG caller takes ownership of pointer.
|
||||
static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data,
|
||||
std::string* error_msg);
|
||||
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg);
|
||||
|
||||
~InterpreterWrapper();
|
||||
PyObject* AllocateTensors();
|
||||
@ -84,7 +86,7 @@ class InterpreterWrapper {
|
||||
static InterpreterWrapper* CreateInterpreterWrapper(
|
||||
std::unique_ptr<tflite::FlatBufferModel> model,
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter,
|
||||
std::string* error_msg);
|
||||
const std::vector<std::string>& registerers, std::string* error_msg);
|
||||
|
||||
InterpreterWrapper(
|
||||
std::unique_ptr<tflite::FlatBufferModel> model,
|
||||
|
@ -33,6 +33,25 @@ limitations under the License.
|
||||
$result = PyLong_FromVoidPtr($1)
|
||||
}
|
||||
|
||||
// Converts a Python list of str to a std::vector<std::string>, returns true
|
||||
// if the conversion was successful.
|
||||
%{
|
||||
static bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings) {
|
||||
// Make sure the list is actually a list.
|
||||
if (!PyList_Check(list)) return false;
|
||||
|
||||
// Convert the Python list to a vector of strings.
|
||||
const int list_size = PyList_Size(list);
|
||||
strings->resize(list_size);
|
||||
for (int k = 0; k < list_size; k++) {
|
||||
PyObject *string_py = PyList_GetItem(list, k);
|
||||
if (!PyString_Check(string_py)) return false;
|
||||
(*strings)[k] = std::string(PyString_AsString(string_py));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
%}
|
||||
bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings);
|
||||
|
||||
%include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
|
||||
|
||||
@ -42,12 +61,19 @@ namespace interpreter_wrapper {
|
||||
|
||||
// Version of the constructor that handles producing Python exceptions
|
||||
// that propagate strings.
|
||||
static PyObject* CreateWrapperCPPFromFile(const char* model_path) {
|
||||
static PyObject* CreateWrapperCPPFromFile(
|
||||
const char* model_path,
|
||||
PyObject* registerers_py) {
|
||||
std::string error;
|
||||
std::vector<std::string> registerers;
|
||||
if (!PyListToStdVectorString(registerers_py, ®isterers)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings.");
|
||||
return nullptr;
|
||||
}
|
||||
if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
|
||||
tflite::interpreter_wrapper::InterpreterWrapper
|
||||
::CreateWrapperCPPFromFile(
|
||||
model_path, &error)) {
|
||||
model_path, registerers, &error)) {
|
||||
return SWIG_NewPointerObj(
|
||||
ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
|
||||
} else {
|
||||
@ -59,12 +85,18 @@ namespace interpreter_wrapper {
|
||||
// Version of the constructor that handles producing Python exceptions
|
||||
// that propagate strings.
|
||||
static PyObject* CreateWrapperCPPFromBuffer(
|
||||
PyObject* data) {
|
||||
PyObject* data ,
|
||||
PyObject* registerers_py) {
|
||||
std::string error;
|
||||
std::vector<std::string> registerers;
|
||||
if (!PyListToStdVectorString(registerers_py, ®isterers)) {
|
||||
PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings.");
|
||||
return nullptr;
|
||||
}
|
||||
if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
|
||||
tflite::interpreter_wrapper::InterpreterWrapper
|
||||
::CreateWrapperCPPFromBuffer(
|
||||
data, &error)) {
|
||||
data, registerers, &error)) {
|
||||
return SWIG_NewPointerObj(
|
||||
ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
|
||||
} else {
|
||||
|
26
tensorflow/lite/python/testdata/BUILD
vendored
26
tensorflow/lite/python/testdata/BUILD
vendored
@ -1,8 +1,9 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
licenses = ["notice"], # Apache 2.0,
|
||||
)
|
||||
|
||||
exports_files(glob(["*.pb"]))
|
||||
@ -71,3 +72,26 @@ cc_binary(
|
||||
":test_delegate",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_registerer",
|
||||
srcs = ["test_registerer.cc"],
|
||||
hdrs = ["test_registerer.h"],
|
||||
visibility = ["//tensorflow/lite:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_py_wrap_cc(
|
||||
name = "test_registerer_wrapper",
|
||||
srcs = [
|
||||
"test_registerer.i",
|
||||
],
|
||||
deps = [
|
||||
":test_registerer",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
37
tensorflow/lite/python/testdata/test_registerer.cc
vendored
Normal file
37
tensorflow/lite/python/testdata/test_registerer.cc
vendored
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace {
|
||||
static int num_test_registerer_calls = 0;
|
||||
} // namespace
|
||||
|
||||
// Dummy registerer function with the correct signature. Ignores the resolver
|
||||
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
|
||||
// is needed to get past the version script in the OSS build.
|
||||
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
|
||||
num_test_registerer_calls++;
|
||||
}
|
||||
|
||||
// Returns the num_test_registerer_calls counter and re-sets it.
|
||||
int get_num_test_registerer_calls() {
|
||||
const int result = num_test_registerer_calls;
|
||||
num_test_registerer_calls = 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
32
tensorflow/lite/python/testdata/test_registerer.h
vendored
Normal file
32
tensorflow/lite/python/testdata/test_registerer.h
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_
|
||||
#define TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_
|
||||
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Dummy registerer function with the correct signature. Ignores the resolver
|
||||
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
|
||||
// is needed to get past the version script in the OSS build.
|
||||
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver);
|
||||
|
||||
// Returns the num_test_registerer_calls counter and re-sets it.
|
||||
int get_num_test_registerer_calls();
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_
|
20
tensorflow/lite/python/testdata/test_registerer.i
vendored
Normal file
20
tensorflow/lite/python/testdata/test_registerer.i
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%{
|
||||
#include "tensorflow/lite/python/testdata/test_registerer.h"
|
||||
%}
|
||||
|
||||
%include "tensorflow/lite/python/testdata/test_registerer.h"
|
Loading…
Reference in New Issue
Block a user