From f767a8fa41d5dd907365bcefba84ce7f69bc0a5d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 17 Aug 2020 17:24:07 -0700 Subject: [PATCH] Allow passing custom op registerers by function PiperOrigin-RevId: 327132676 Change-Id: Iad8fdaa69f9a5afcf635a472703466f4bc3d1d73 --- tensorflow/lite/python/interpreter.py | 28 +++++++++++---- tensorflow/lite/python/interpreter_test.py | 30 +++++++++++----- .../interpreter_wrapper.cc | 34 +++++++++++++++---- .../interpreter_wrapper/interpreter_wrapper.h | 14 +++++++- .../interpreter_wrapper_pybind11.cc | 28 +++++++++++++++ 5 files changed, 112 insertions(+), 22 deletions(-) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 12ee41d6dee..d0ee2dbc700 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -185,8 +185,8 @@ class Interpreter(object): objects returned by lite.load_delegate(). num_threads: Sets the number of threads used by the interpreter and available to CPU kernels. If not set, the interpreter will use an - implementation-dependent default number of threads. Currently, - only a subset of kernels, such as conv, support multi-threading. + implementation-dependent default number of threads. Currently, only a + subset of kernels, such as conv, support multi-threading. Raises: ValueError: If the interpreter was unable to create. @@ -194,19 +194,33 @@ class Interpreter(object): if not hasattr(self, '_custom_op_registerers'): self._custom_op_registerers = [] if model_path and not model_content: + custom_op_registerers_by_name = [ + x for x in self._custom_op_registerers if isinstance(x, str) + ] + custom_op_registerers_by_func = [ + x for x in self._custom_op_registerers if not isinstance(x, str) + ] self._interpreter = ( _interpreter_wrapper.CreateWrapperFromFile( - model_path, self._custom_op_registerers)) + model_path, custom_op_registerers_by_name, + custom_op_registerers_by_func)) if not self._interpreter: raise ValueError('Failed to open {}'.format(model_path)) elif model_content and not model_path: + custom_op_registerers_by_name = [ + x for x in self._custom_op_registerers if isinstance(x, str) + ] + custom_op_registerers_by_func = [ + x for x in self._custom_op_registerers if not isinstance(x, str) + ] # Take a reference, so the pointer remains valid. # Since python strings are immutable then PyString_XX functions # will always return the same pointer. self._model_content = model_content self._interpreter = ( _interpreter_wrapper.CreateWrapperFromBuffer( - model_content, self._custom_op_registerers)) + model_content, custom_op_registerers_by_name, + custom_op_registerers_by_func)) elif not model_content and not model_path: raise ValueError('`model_path` or `model_content` must be specified.') else: @@ -573,8 +587,10 @@ class InterpreterWithCustomOps(Interpreter): experimental_delegates: Experimental. Subject to change. List of [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) objects returned by lite.load_delegate(). - custom_op_registerers: List of str, symbol names of functions that take a - pointer to a MutableOpResolver and register a custom op. + custom_op_registerers: List of str (symbol names) or functions that take a + pointer to a MutableOpResolver and register a custom op. When passing + functions, use a pybind function that takes a uintptr_t that can be + recast as a pointer to a MutableOpResolver. Raises: ValueError: If the interpreter was unable to create. diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index cc74f4d8fbc..bcb338b84cf 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -42,7 +42,7 @@ from tensorflow.python.platform import test class InterpreterCustomOpsTest(test_util.TensorFlowTestCase): - def testRegisterer(self): + def testRegistererByName(self): interpreter = interpreter_wrapper.InterpreterWithCustomOps( model_path=resource_loader.get_path_to_datafile( 'testdata/permute_float.tflite'), @@ -50,6 +50,14 @@ class InterpreterCustomOpsTest(test_util.TensorFlowTestCase): self.assertTrue(interpreter._safe_to_run()) self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1) + def testRegistererByFunc(self): + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite'), + custom_op_registerers=[test_registerer.TF_TestRegisterer]) + self.assertTrue(interpreter._safe_to_run()) + self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1) + def testRegistererFailure(self): bogus_name = 'CompletelyBogusRegistererName' with self.assertRaisesRegex( @@ -72,14 +80,16 @@ class InterpreterTest(test_util.TensorFlowTestCase): with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'): interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile( - 'testdata/permute_float.tflite'), num_threads=-1) + 'testdata/permute_float.tflite'), + num_threads=-1) def testThreads_WrongType(self): with self.assertRaisesRegex(ValueError, 'type of num_threads should be int'): interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile( - 'testdata/permute_float.tflite'), num_threads=4.2) + 'testdata/permute_float.tflite'), + num_threads=4.2) def testFloat(self): interpreter = interpreter_wrapper.Interpreter( @@ -116,7 +126,8 @@ class InterpreterTest(test_util.TensorFlowTestCase): def testFloatWithTwoThreads(self): interpreter = interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile( - 'testdata/permute_float.tflite'), num_threads=2) + 'testdata/permute_float.tflite'), + num_threads=2) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -158,8 +169,7 @@ class InterpreterTest(test_util.TensorFlowTestCase): test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) - interpreter.resize_tensor_input(input_details[0]['index'], - test_input.shape) + interpreter.resize_tensor_input(input_details[0]['index'], test_input.shape) interpreter.allocate_tensors() interpreter.set_tensor(input_details[0]['index'], test_input) interpreter.invoke() @@ -267,8 +277,7 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): def testInvalidModelFile(self): with self.assertRaisesRegex(ValueError, 'Could not open \'totally_invalid_file_name\''): - interpreter_wrapper.Interpreter( - model_path='totally_invalid_file_name') + interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name') def testInvokeBeforeReady(self): interpreter = interpreter_wrapper.Interpreter( @@ -423,16 +432,19 @@ class InterpreterDelegateTest(test_util.TensorFlowTestCase): self.skipTest('TODO(b/142136355): fix flakiness and re-enable') # Track which order destructions were doned in destructions = [] + def register_destruction(x): destructions.append( x if isinstance(x, str) else six.ensure_text(x, 'utf-8')) return 0 + # Make a wrapper for the callback so we can send this to ctypes delegate = interpreter_wrapper.load_delegate(self._delegate_file) # Make an interpreter with the delegate interpreter = interpreter_wrapper.Interpreter( model_path=resource_loader.get_path_to_datafile( - 'testdata/permute_float.tflite'), experimental_delegates=[delegate]) + 'testdata/permute_float.tflite'), + experimental_delegates=[delegate]) class InterpreterDestroyCallback(object): diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 7295a46193e..adfa760f147 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include #include #include @@ -168,17 +169,22 @@ bool RegisterCustomOpByName(const char* registerer_name, InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( std::unique_ptr model, std::unique_ptr error_reporter, - const std::vector& registerers, std::string* error_msg) { + const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, + std::string* error_msg) { if (!model) { *error_msg = error_reporter->message(); return nullptr; } auto resolver = absl::make_unique(); - for (const auto& registerer : registerers) { + for (const auto& registerer : registerers_by_name) { if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg)) return nullptr; } + for (const auto& registerer : registerers_by_func) { + registerer(reinterpret_cast(resolver.get())); + } auto interpreter = CreateInterpreter(model.get(), *resolver); if (!interpreter) { *error_msg = error_reporter->message(); @@ -655,18 +661,27 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path, const std::vector& registerers, + const char* model_path, const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, std::string* error_msg) { std::unique_ptr error_reporter(new PythonErrorReporter); std::unique_ptr model = tflite_api_dispatcher::TfLiteModel::BuildFromFile(model_path, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), - registerers, error_msg); + registerers_by_name, registerers_by_func, + error_msg); +} + +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( + const char* model_path, const std::vector& registerers, + std::string* error_msg) { + return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data, const std::vector& registerers, + PyObject* data, const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, std::string* error_msg) { char* buf = nullptr; Py_ssize_t length; @@ -679,7 +694,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( tflite_api_dispatcher::TfLiteModel::BuildFromBuffer(buf, length, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), - registerers, error_msg); + registerers_by_name, registerers_by_func, + error_msg); +} + +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( + PyObject* data, const std::vector& registerers, + std::string* error_msg) { + return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg); } PyObject* InterpreterWrapper::ResetVariableTensors() { diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index 5580eaa0f4b..6b83d2d06db 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ #define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_INTERPRETER_WRAPPER_H_ +#include #include #include #include @@ -51,11 +52,20 @@ class InterpreterWrapper { static InterpreterWrapper* CreateWrapperCPPFromFile( const char* model_path, const std::vector& registerers, std::string* error_msg); + static InterpreterWrapper* CreateWrapperCPPFromFile( + const char* model_path, + const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, + std::string* error_msg); // SWIG caller takes ownership of pointer. static InterpreterWrapper* CreateWrapperCPPFromBuffer( PyObject* data, const std::vector& registerers, std::string* error_msg); + static InterpreterWrapper* CreateWrapperCPPFromBuffer( + PyObject* data, const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, + std::string* error_msg); ~InterpreterWrapper(); PyObject* AllocateTensors(); @@ -106,7 +116,9 @@ class InterpreterWrapper { static InterpreterWrapper* CreateInterpreterWrapper( std::unique_ptr model, std::unique_ptr error_reporter, - const std::vector& registerers, std::string* error_msg); + const std::vector& registerers_by_name, + const std::vector>& registerers_by_func, + std::string* error_msg); InterpreterWrapper( std::unique_ptr model, diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc index a85bdc8baf4..f30912c44b4 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "pybind11/functional.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" @@ -42,6 +43,20 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) { } return wrapper; }); + m.def("CreateWrapperFromFile", + [](const std::string& model_path, + const std::vector& registerers_by_name, + const std::vector>& + registerers_by_func) { + std::string error; + auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromFile( + model_path.c_str(), registerers_by_name, registerers_by_func, + &error); + if (!wrapper) { + throw std::invalid_argument(error); + } + return wrapper; + }); m.def("CreateWrapperFromBuffer", [](const py::bytes& data, const std::vector& registerers) { std::string error; @@ -52,6 +67,19 @@ PYBIND11_MODULE(_pywrap_tensorflow_interpreter_wrapper, m) { } return wrapper; }); + m.def("CreateWrapperFromBuffer", + [](const py::bytes& data, + const std::vector& registerers_by_name, + const std::vector>& + registerers_by_func) { + std::string error; + auto* wrapper = ::InterpreterWrapper::CreateWrapperCPPFromBuffer( + data.ptr(), registerers_by_name, registerers_by_func, &error); + if (!wrapper) { + throw std::invalid_argument(error); + } + return wrapper; + }); py::class_(m, "InterpreterWrapper") .def("AllocateTensors", [](InterpreterWrapper& self) {