Allow passing custom op registerers by function
PiperOrigin-RevId: 327132676 Change-Id: Iad8fdaa69f9a5afcf635a472703466f4bc3d1d73
This commit is contained in:
parent
94ca496b8a
commit
f767a8fa41
tensorflow/lite/python
@ -185,8 +185,8 @@ class Interpreter(object):
|
||||
objects returned by lite.load_delegate().
|
||||
num_threads: Sets the number of threads used by the interpreter and
|
||||
available to CPU kernels. If not set, the interpreter will use an
|
||||
implementation-dependent default number of threads. Currently,
|
||||
only a subset of kernels, such as conv, support multi-threading.
|
||||
implementation-dependent default number of threads. Currently, only a
|
||||
subset of kernels, such as conv, support multi-threading.
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
@ -194,19 +194,33 @@ class Interpreter(object):
|
||||
if not hasattr(self, '_custom_op_registerers'):
|
||||
self._custom_op_registerers = []
|
||||
if model_path and not model_content:
|
||||
custom_op_registerers_by_name = [
|
||||
x for x in self._custom_op_registerers if isinstance(x, str)
|
||||
]
|
||||
custom_op_registerers_by_func = [
|
||||
x for x in self._custom_op_registerers if not isinstance(x, str)
|
||||
]
|
||||
self._interpreter = (
|
||||
_interpreter_wrapper.CreateWrapperFromFile(
|
||||
model_path, self._custom_op_registerers))
|
||||
model_path, custom_op_registerers_by_name,
|
||||
custom_op_registerers_by_func))
|
||||
if not self._interpreter:
|
||||
raise ValueError('Failed to open {}'.format(model_path))
|
||||
elif model_content and not model_path:
|
||||
custom_op_registerers_by_name = [
|
||||
x for x in self._custom_op_registerers if isinstance(x, str)
|
||||
]
|
||||
custom_op_registerers_by_func = [
|
||||
x for x in self._custom_op_registerers if not isinstance(x, str)
|
||||
]
|
||||
# Take a reference, so the pointer remains valid.
|
||||
# Since python strings are immutable then PyString_XX functions
|
||||
# will always return the same pointer.
|
||||
self._model_content = model_content
|
||||
self._interpreter = (
|
||||
_interpreter_wrapper.CreateWrapperFromBuffer(
|
||||
model_content, self._custom_op_registerers))
|
||||
model_content, custom_op_registerers_by_name,
|
||||
custom_op_registerers_by_func))
|
||||
elif not model_content and not model_path:
|
||||
raise ValueError('`model_path` or `model_content` must be specified.')
|
||||
else:
|
||||
@ -573,8 +587,10 @@ class InterpreterWithCustomOps(Interpreter):
|
||||
experimental_delegates: Experimental. Subject to change. List of
|
||||
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
|
||||
objects returned by lite.load_delegate().
|
||||
custom_op_registerers: List of str, symbol names of functions that take a
|
||||
pointer to a MutableOpResolver and register a custom op.
|
||||
custom_op_registerers: List of str (symbol names) or functions that take a
|
||||
pointer to a MutableOpResolver and register a custom op. When passing
|
||||
functions, use a pybind function that takes a uintptr_t that can be
|
||||
recast as a pointer to a MutableOpResolver.
|
||||
|
||||
Raises:
|
||||
ValueError: If the interpreter was unable to create.
|
||||
|
@ -42,7 +42,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testRegisterer(self):
|
||||
def testRegistererByName(self):
|
||||
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'),
|
||||
@ -50,6 +50,14 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(interpreter._safe_to_run())
|
||||
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
|
||||
|
||||
def testRegistererByFunc(self):
|
||||
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'),
|
||||
custom_op_registerers=[test_registerer.TF_TestRegisterer])
|
||||
self.assertTrue(interpreter._safe_to_run())
|
||||
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
|
||||
|
||||
def testRegistererFailure(self):
|
||||
bogus_name = 'CompletelyBogusRegistererName'
|
||||
with self.assertRaisesRegex(
|
||||
@ -72,14 +80,16 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=-1)
|
||||
'testdata/permute_float.tflite'),
|
||||
num_threads=-1)
|
||||
|
||||
def testThreads_WrongType(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'type of num_threads should be int'):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=4.2)
|
||||
'testdata/permute_float.tflite'),
|
||||
num_threads=4.2)
|
||||
|
||||
def testFloat(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
@ -116,7 +126,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
def testFloatWithTwoThreads(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), num_threads=2)
|
||||
'testdata/permute_float.tflite'),
|
||||
num_threads=2)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
@ -158,8 +169,7 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
|
||||
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
|
||||
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
|
||||
interpreter.resize_tensor_input(input_details[0]['index'],
|
||||
test_input.shape)
|
||||
interpreter.resize_tensor_input(input_details[0]['index'], test_input.shape)
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
@ -267,8 +277,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
def testInvalidModelFile(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Could not open \'totally_invalid_file_name\''):
|
||||
interpreter_wrapper.Interpreter(
|
||||
model_path='totally_invalid_file_name')
|
||||
interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name')
|
||||
|
||||
def testInvokeBeforeReady(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
@ -423,16 +432,19 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
|
||||
self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
|
||||
# Track which order destructions were doned in
|
||||
destructions = []
|
||||
|
||||
def register_destruction(x):
|
||||
destructions.append(
|
||||
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
|
||||
return 0
|
||||
|
||||
# Make a wrapper for the callback so we can send this to ctypes
|
||||
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
|
||||
# Make an interpreter with the delegate
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/permute_float.tflite'), experimental_delegates=[delegate])
|
||||
'testdata/permute_float.tflite'),
|
||||
experimental_delegates=[delegate])
|
||||
|
||||
class InterpreterDestroyCallback(object):
|
||||
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include <stdarg.h>
|
||||
|
||||
#include <functional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
@ -168,17 +169,22 @@ bool RegisterCustomOpByName(const char* registerer_name,
|
||||
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
|
||||
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter,
|
||||
const std::vector<std::string>& registerers, std::string* error_msg) {
|
||||
const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg) {
|
||||
if (!model) {
|
||||
*error_msg = error_reporter->message();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
for (const auto& registerer : registerers) {
|
||||
for (const auto& registerer : registerers_by_name) {
|
||||
if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
|
||||
return nullptr;
|
||||
}
|
||||
for (const auto& registerer : registerers_by_func) {
|
||||
registerer(reinterpret_cast<uintptr_t>(resolver.get()));
|
||||
}
|
||||
auto interpreter = CreateInterpreter(model.get(), *resolver);
|
||||
if (!interpreter) {
|
||||
*error_msg = error_reporter->message();
|
||||
@ -655,18 +661,27 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
|
||||
const char* model_path, const std::vector<std::string>& registerers,
|
||||
const char* model_path, const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg) {
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model =
|
||||
tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path,
|
||||
error_reporter.get());
|
||||
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
|
||||
registerers, error_msg);
|
||||
registerers_by_name, registerers_by_func,
|
||||
error_msg);
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
|
||||
const char* model_path, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg) {
|
||||
return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg);
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
PyObject* data, const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg) {
|
||||
char* buf = nullptr;
|
||||
Py_ssize_t length;
|
||||
@ -679,7 +694,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length,
|
||||
error_reporter.get());
|
||||
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
|
||||
registerers, error_msg);
|
||||
registerers_by_name, registerers_by_func,
|
||||
error_msg);
|
||||
}
|
||||
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg) {
|
||||
return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg);
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::ResetVariableTensors() {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
|
||||
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -51,11 +52,20 @@ class InterpreterWrapper {
|
||||
static InterpreterWrapper* CreateWrapperCPPFromFile(
|
||||
const char* model_path, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg);
|
||||
static InterpreterWrapper* CreateWrapperCPPFromFile(
|
||||
const char* model_path,
|
||||
const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg);
|
||||
|
||||
// SWIG caller takes ownership of pointer.
|
||||
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg);
|
||||
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg);
|
||||
|
||||
~InterpreterWrapper();
|
||||
PyObject* AllocateTensors();
|
||||
@ -106,7 +116,9 @@ class InterpreterWrapper {
|
||||
static InterpreterWrapper* CreateInterpreterWrapper(
|
||||
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter,
|
||||
const std::vector<std::string>& registerers, std::string* error_msg);
|
||||
const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
|
||||
std::string* error_msg);
|
||||
|
||||
InterpreterWrapper(
|
||||
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "pybind11/functional.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "pybind11/stl.h"
|
||||
@ -42,6 +43,20 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||
}
|
||||
return wrapper;
|
||||
});
|
||||
m.def("CreateWrapperFromFile",
|
||||
[](const std::string& model_path,
|
||||
const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>&
|
||||
registerers_by_func) {
|
||||
std::string error;
|
||||
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
|
||||
model_path.c_str(), registerers_by_name, registerers_by_func,
|
||||
&error);
|
||||
if (!wrapper) {
|
||||
throw std::invalid_argument(error);
|
||||
}
|
||||
return wrapper;
|
||||
});
|
||||
m.def("CreateWrapperFromBuffer",
|
||||
[](const py::bytes& data, const std::vector<std::string>& registerers) {
|
||||
std::string error;
|
||||
@ -52,6 +67,19 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
|
||||
}
|
||||
return wrapper;
|
||||
});
|
||||
m.def("CreateWrapperFromBuffer",
|
||||
[](const py::bytes& data,
|
||||
const std::vector<std::string>& registerers_by_name,
|
||||
const std::vector<std::function<void(uintptr_t)>>&
|
||||
registerers_by_func) {
|
||||
std::string error;
|
||||
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
data.ptr(), registerers_by_name, registerers_by_func, &error);
|
||||
if (!wrapper) {
|
||||
throw std::invalid_argument(error);
|
||||
}
|
||||
return wrapper;
|
||||
});
|
||||
py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
|
||||
.def("AllocateTensors",
|
||||
[](InterpreterWrapper& self) {
|
||||
|
Loading…
Reference in New Issue
Block a user