Allow passing custom op registerers by function

PiperOrigin-RevId: 327132676
Change-Id: Iad8fdaa69f9a5afcf635a472703466f4bc3d1d73
This commit is contained in:
A. Unique TensorFlower 2020-08-17 17:24:07 -07:00 committed by TensorFlower Gardener
parent 94ca496b8a
commit f767a8fa41
5 changed files with 112 additions and 22 deletions

View File

@ -185,8 +185,8 @@ class Interpreter(object):
objects returned by lite.load_delegate().
num_threads: Sets the number of threads used by the interpreter and
available to CPU kernels. If not set, the interpreter will use an
implementation-dependent default number of threads. Currently,
only a subset of kernels, such as conv, support multi-threading.
implementation-dependent default number of threads. Currently, only a
subset of kernels, such as conv, support multi-threading.
Raises:
ValueError: If the interpreter was unable to create.
@ -194,19 +194,33 @@ class Interpreter(object):
if not hasattr(self, '_custom_op_registerers'):
self._custom_op_registerers = []
if model_path and not model_content:
custom_op_registerers_by_name = [
x for x in self._custom_op_registerers if isinstance(x, str)
]
custom_op_registerers_by_func = [
x for x in self._custom_op_registerers if not isinstance(x, str)
]
self._interpreter = (
_interpreter_wrapper.CreateWrapperFromFile(
model_path, self._custom_op_registerers))
model_path, custom_op_registerers_by_name,
custom_op_registerers_by_func))
if not self._interpreter:
raise ValueError('Failed to open {}'.format(model_path))
elif model_content and not model_path:
custom_op_registerers_by_name = [
x for x in self._custom_op_registerers if isinstance(x, str)
]
custom_op_registerers_by_func = [
x for x in self._custom_op_registerers if not isinstance(x, str)
]
# Take a reference, so the pointer remains valid.
# Since python strings are immutable then PyString_XX functions
# will always return the same pointer.
self._model_content = model_content
self._interpreter = (
_interpreter_wrapper.CreateWrapperFromBuffer(
model_content, self._custom_op_registerers))
model_content, custom_op_registerers_by_name,
custom_op_registerers_by_func))
elif not model_content and not model_path:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
@ -573,8 +587,10 @@ class InterpreterWithCustomOps(Interpreter):
experimental_delegates: Experimental. Subject to change. List of
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
objects returned by lite.load_delegate().
custom_op_registerers: List of str, symbol names of functions that take a
pointer to a MutableOpResolver and register a custom op.
custom_op_registerers: List of str (symbol names) or functions that take a
pointer to a MutableOpResolver and register a custom op. When passing
functions, use a pybind function that takes a uintptr_t that can be
recast as a pointer to a MutableOpResolver.
Raises:
ValueError: If the interpreter was unable to create.

View File

@ -42,7 +42,7 @@ from tensorflow.python.platform import test
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
def testRegisterer(self):
def testRegistererByName(self):
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'),
@ -50,6 +50,14 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
self.assertTrue(interpreter._safe_to_run())
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
def testRegistererByFunc(self):
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'),
custom_op_registerers=[test_registerer.TF_TestRegisterer])
self.assertTrue(interpreter._safe_to_run())
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
def testRegistererFailure(self):
bogus_name = 'CompletelyBogusRegistererName'
with self.assertRaisesRegex(
@ -72,14 +80,16 @@ class InterpreterTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=-1)
'testdata/permute_float.tflite'),
num_threads=-1)
def testThreads_WrongType(self):
with self.assertRaisesRegex(ValueError,
'type of num_threads should be int'):
interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=4.2)
'testdata/permute_float.tflite'),
num_threads=4.2)
def testFloat(self):
interpreter = interpreter_wrapper.Interpreter(
@ -116,7 +126,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
def testFloatWithTwoThreads(self):
interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), num_threads=2)
'testdata/permute_float.tflite'),
num_threads=2)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
@ -158,8 +169,7 @@ class InterpreterTest(test_util.TensorFlowTestCase):
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
interpreter.resize_tensor_input(input_details[0]['index'],
test_input.shape)
interpreter.resize_tensor_input(input_details[0]['index'], test_input.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
@ -267,8 +277,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
def testInvalidModelFile(self):
with self.assertRaisesRegex(ValueError,
'Could not open \'totally_invalid_file_name\''):
interpreter_wrapper.Interpreter(
model_path='totally_invalid_file_name')
interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name')
def testInvokeBeforeReady(self):
interpreter = interpreter_wrapper.Interpreter(
@ -423,16 +432,19 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase):
self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
# Track which order destructions were doned in
destructions = []
def register_destruction(x):
destructions.append(
x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
return 0
# Make a wrapper for the callback so we can send this to ctypes
delegate = interpreter_wrapper.load_delegate(self._delegate_file)
# Make an interpreter with the delegate
interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'), experimental_delegates=[delegate])
'testdata/permute_float.tflite'),
experimental_delegates=[delegate])
class InterpreterDestroyCallback(object):

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <stdarg.h>
#include <functional>
#include <sstream>
#include <string>
@ -168,17 +169,22 @@ bool RegisterCustomOpByName(const char* registerer_name,
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers, std::string* error_msg) {
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
if (!model) {
*error_msg = error_reporter->message();
return nullptr;
}
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
for (const auto& registerer : registerers) {
for (const auto& registerer : registerers_by_name) {
if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
return nullptr;
}
for (const auto& registerer : registerers_by_func) {
registerer(reinterpret_cast<uintptr_t>(resolver.get()));
}
auto interpreter = CreateInterpreter(model.get(), *resolver);
if (!interpreter) {
*error_msg = error_reporter->message();
@ -655,18 +661,27 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers,
const char* model_path, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model =
tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path,
error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
registerers, error_msg);
registerers_by_name, registerers_by_func,
error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg) {
return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers,
PyObject* data, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg) {
char* buf = nullptr;
Py_ssize_t length;
@ -679,7 +694,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length,
error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
registerers, error_msg);
registerers_by_name, registerers_by_func,
error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg) {
return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg);
}
PyObject* InterpreterWrapper::ResetVariableTensors() {

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
@ -51,11 +52,20 @@ class InterpreterWrapper {
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
// SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
~InterpreterWrapper();
PyObject* AllocateTensors();
@ -106,7 +116,9 @@ class InterpreterWrapper {
static InterpreterWrapper* CreateInterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter,
const std::vector<std::string>& registerers, std::string* error_msg);
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
std::string* error_msg);
InterpreterWrapper(
std::unique_ptr<tflite_api_dispatcher::TfLiteModel> model,

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
@ -42,6 +43,20 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
}
return wrapper;
});
m.def("CreateWrapperFromFile",
[](const std::string& model_path,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>&
registerers_by_func) {
std::string error;
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile(
model_path.c_str(), registerers_by_name, registerers_by_func,
&error);
if (!wrapper) {
throw std::invalid_argument(error);
}
return wrapper;
});
m.def("CreateWrapperFromBuffer",
[](const py::bytes& data, const std::vector<std::string>& registerers) {
std::string error;
@ -52,6 +67,19 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) {
}
return wrapper;
});
m.def("CreateWrapperFromBuffer",
[](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
const std::vector<std::function<void(uintptr_t)>>&
registerers_by_func) {
std::string error;
auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer(
data.ptr(), registerers_by_name, registerers_by_func, &error);
if (!wrapper) {
throw std::invalid_argument(error);
}
return wrapper;
});
py::class_<InterpreterWrapper>(m, "InterpreterWrapper")
.def("AllocateTensors",
[](InterpreterWrapper& self) {