Allow passing custom op registerers by function

PiperOrigin-RevId: 327132676
Change-Id: Iad8fdaa69f9a5afcf635a472703466f4bc3d1d73
This commit is contained in:
A. Unique TensorFlower 2020-08-17 17:24:07 -07:00 committed by TensorFlower Gardener
parent 94ca496b8a
commit f767a8fa41
5 changed files with 112 additions and 22 deletions

View File

@ -185,8 +185,8 @@ class Interpreter(object):
objects returned by lite.load_delegate(). objects returned by lite.load_delegate().
num_threads: Sets the number of threads used by the interpreter and num_threads: Sets the number of threads used by the interpreter and
available to CPU kernels. If not set, the interpreter will use an available to CPU kernels. If not set, the interpreter will use an
implementation-dependent default number of threads. Currently, implementation-dependent default number of threads. Currently, only a
only a subset of kernels, such as conv, support multi-threading. subset of kernels, such as conv, support multi-threading.
Raises: Raises:
ValueError: If the interpreter was unable to create. ValueError: If the interpreter was unable to create.
@ -194,19 +194,33 @@ class Interpreter(object):
if not hasattr(self, '_custom_op_registerers'): if not hasattr(self, '_custom_op_registerers'):
self._custom_op_registerers = [] self._custom_op_registerers = []
if model_path and not model_content: if model_path and not model_content:
custom_op_registerers_by_name = [
x for x in self._custom_op_registerers if isinstance(x, str)
]
custom_op_registerers_by_func = [
x for x in self._custom_op_registerers if not isinstance(x, str)
]
self._interpreter = ( self._interpreter = (
_interpreter_wrapper.CreateWrapperFromFile( _interpreter_wrapper.CreateWrapperFromFile(
model_path, self._custom_op_registerers)) model_path, custom_op_registerers_by_name,
custom_op_registerers_by_func))
if not self._interpreter: if not self._interpreter:
raise ValueError('Failed to open {}'.format(model_path)) raise ValueError('Failed to open {}'.format(model_path))
elif model_content and not model_path: elif model_content and not model_path:
custom_op_registerers_by_name = [
x for x in self._custom_op_registerers if isinstance(x, str)
]
custom_op_registerers_by_func = [
x for x in self._custom_op_registerers if not isinstance(x, str)
]
# Take a reference, so the pointer remains valid. # Take a reference, so the pointer remains valid.
# Since python strings are immutable then PyString_XX functions # Since python strings are immutable then PyString_XX functions
# will always return the same pointer. # will always return the same pointer.
self._model_content = model_content self._model_content = model_content
self._interpreter = ( self._interpreter = (
_interpreter_wrapper.CreateWrapperFromBuffer( _interpreter_wrapper.CreateWrapperFromBuffer(
model_content, self._custom_op_registerers)) model_content, custom_op_registerers_by_name,
custom_op_registerers_by_func))
elif not model_content and not model_path: elif not model_content and not model_path:
raise ValueError('`model_path` or `model_content` must be specified.') raise ValueError('`model_path` or `model_content` must be specified.')
else: else:
@ -573,8 +587,10 @@ class InterpreterWithCustomOps(Interpreter):
experimental_delegates: Experimental. Subject to change. List of experimental_delegates: Experimental. Subject to change. List of
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
objects returned by lite.load_delegate(). objects returned by lite.load_delegate().
custom_op_registerers: List of str, symbol names of functions that take a custom_op_registerers: List of str (symbol names) or functions that take a
pointer to a MutableOpResolver and register a custom op. pointer to a MutableOpResolver and register a custom op. When passing
functions, use a pybind function that takes a uintptr_t that can be
recast as a pointer to a MutableOpResolver.
Raises: Raises:
ValueError: If the interpreter was unable to create. ValueError: If the interpreter was unable to create.

View File

@ -42,7 +42,7 @@ from tensorflow.python.platform import test
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase): class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
def testRegisterer(self): def testRegistererByName(self):
interpreter = interpreter_wrapper.InterpreterWithCustomOps( interpreter = interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile( model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), 'testdata/permute_float.tflite'),
@ -50,6 +50,14 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
self.assertTrue(interpreter._safe_to_run()) self.assertTrue(interpreter._safe_to_run())
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1) self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
def testRegistererByFunc(self):
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'),
custom_op_registerers=[test_registerer.TF_TestRegisterer])
self.assertTrue(interpreter._safe_to_run())
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
def testRegistererFailure(self): def testRegistererFailure(self):
bogus_name = 'CompletelyBogusRegistererName' bogus_name = 'CompletelyBogusRegistererName'
with self.assertRaisesRegex( with self.assertRaisesRegex(
@ -72,14 +80,16 @@ class InterpreterTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'): with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
interpreter_wrapper.Interpreter( interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile( model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=-1) 'testdata/permute_float.tflite'),
num_threads=-1)
def testThreads_WrongType(self): def testThreads_WrongType(self):
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
'type of num_threads should be int'): 'type of num_threads should be int'):
interpreter_wrapper.Interpreter( interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile( model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=4.2) 'testdata/permute_float.tflite'),
num_threads=4.2)
def testFloat(self): def testFloat(self):
interpreter = interpreter_wrapper.Interpreter( interpreter = interpreter_wrapper.Interpreter(
@ -116,7 +126,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
def testFloatWithTwoThreads(self): def testFloatWithTwoThreads(self):
interpreter = interpreter_wrapper.Interpreter( interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile( model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=2) 'testdata/permute_float.tflite'),
num_threads=2)
interpreter.allocate_tensors() interpreter.allocate_tensors()
input_details = interpreter.get_input_details() input_details = interpreter.get_input_details()
@ -158,8 +169,7 @@ class InterpreterTest(test_util.TensorFlowTestCase):
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
interpreter.resize_tensor_input(input_details[0]['index'], interpreter.resize_tensor_input(input_details[0]['index'], test_input.shape)
test_input.shape)
interpreter.allocate_tensors() interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], test_input) interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke() interpreter.invoke()
@ -267,8 +277,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
def testInvalidModelFile(self): def testInvalidModelFile(self):
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(ValueError,
'Could not open \'totally_invalid_file_name\''): 'Could not open \'totally_invalid_file_name\''):
interpreter_wrapper.Interpreter( interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name')
model_path='totally_invalid_file_name')
def testInvokeBeforeReady(self): def testInvokeBeforeReady(self):
interpreter = interpreter_wrapper.Interpreter( interpreter = interpreter_wrapper.Interpreter(
@ -423,16 +432,19 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
self.skipTest('TODO(b/142136355): fix flakiness and re-enable') self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
# Track which order destructions were doned in # Track which order destructions were doned in
destructions = [] destructions = []
def register_destruction(x): def register_destruction(x):
destructions.append( destructions.append(
x if isinstance(x, str) else six.ensure_text(x, 'utf-8')) x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
return 0 return 0
# Make a wrapper for the callback so we can send this to ctypes # Make a wrapper for the callback so we can send this to ctypes
delegate = interpreter_wrapper.load_delegate(self._delegate_file) delegate = interpreter_wrapper.load_delegate(self._delegate_file)
# Make an interpreter with the delegate # Make an interpreter with the delegate
interpreter = interpreter_wrapper.Interpreter( interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile( model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), experimental_delegates=[delegate]) 'testdata/permute_float.tflite'),
experimental_delegates=[delegate])
class InterpreterDestroyCallback(object): class InterpreterDestroyCallback(object):

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <stdarg.h> #include <stdarg.h>
#include <functional>
#include <sstream> #include <sstream>
#include <string> #include <string>
@ -168,17 +169,22 @@ bool RegisterCustomOpByName(const char* registerer_name,
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model, std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter, std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers, std::string* error_msg) { const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
if (!model) { if (!model) {
*error_msg = error_reporter->message(); *error_msg = error_reporter->message();
return nullptr; return nullptr;
} }
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
for (const auto& registerer : registerers) { for (const auto& registerer : registerers_by_name) {
if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg)) if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
return nullptr; return nullptr;
} }
for (const auto& registerer : registerers_by_func) {
registerer(reinterpret_cast<uintptr_t>(resolver.get()));
}
auto interpreter = CreateInterpreter(model.get(), *resolver); auto interpreter = CreateInterpreter(model.get(), *resolver);
if (!interpreter) { if (!interpreter) {
*error_msg = error_reporter->message(); *error_msg = error_reporter->message();
@ -655,18 +661,27 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
} }
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers, const char* model_path, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) { std::string* error_msg) {
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model = std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model =
tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path, tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path,
error_reporter.get()); error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
registerers, error_msg); registerers_by_name, registerers_by_func,
error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg) {
return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg);
} }
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers, PyObject* data, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) { std::string* error_msg) {
char* buf = nullptr; char* buf = nullptr;
Py_ssize_t length; Py_ssize_t length;
@ -679,7 +694,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length, tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length,
error_reporter.get()); error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
registerers, error_msg); registerers_by_name, registerers_by_func,
error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg) {
return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg);
} }
PyObject* InterpreterWrapper::ResetVariableTensors() { PyObject* InterpreterWrapper::ResetVariableTensors() {

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -51,11 +52,20 @@ class InterpreterWrapper {
static InterpreterWrapper* CreateWrapperCPPFromFile( static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers, const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg); std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
// SWIG caller takes ownership of pointer. // SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromBuffer( static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers, PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg); std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
~InterpreterWrapper(); ~InterpreterWrapper();
PyObject* AllocateTensors(); PyObject* AllocateTensors();
@ -106,7 +116,9 @@ class InterpreterWrapper {
static InterpreterWrapper* CreateInterpreterWrapper( static InterpreterWrapper* CreateInterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model, std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter, std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers, std::string* error_msg); const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
InterpreterWrapper( InterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model, std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "pybind11/functional.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/pytypes.h" #include "pybind11/pytypes.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
@ -42,6 +43,20 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
} }
return wrapper; return wrapper;
}); });
m.def("CreateWrapperFromFile",
[](const std::string& model_path,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>&
registerers_by_func) {
std::string error;
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
model_path.c_str(), registerers_by_name, registerers_by_func,
&error);
if (!wrapper) {
throw std::invalid_argument(error);
}
return wrapper;
});
m.def("CreateWrapperFromBuffer", m.def("CreateWrapperFromBuffer",
[](const py::bytes& data, const std::vector<std::string>& registerers) { [](const py::bytes& data, const std::vector<std::string>& registerers) {
std::string error; std::string error;
@ -52,6 +67,19 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
} }
return wrapper; return wrapper;
}); });
m.def("CreateWrapperFromBuffer",
[](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>&
registerers_by_func) {
std::string error;
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
data.ptr(), registerers_by_name, registerers_by_func, &error);
if (!wrapper) {
throw std::invalid_argument(error);
}
return wrapper;
});
py::class_<InterpreterWrapper>(m, "InterpreterWrapper") py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
.def("AllocateTensors", .def("AllocateTensors",
[](InterpreterWrapper& self) { [](InterpreterWrapper& self) {