diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index ca005465212..331a4a89457 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -33,6 +33,7 @@ py_test( ], deps = [ ":interpreter", + "//tensorflow/lite/python/testdata:test_registerer_wrapper", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 43b90883c8a..b5d6ad543d1 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -200,10 +200,12 @@ class Interpreter(object): Raises: ValueError: If the interpreter was unable to create. """ + if not hasattr(self, '_custom_op_registerers'): + self._custom_op_registerers = [] if model_path and not model_content: self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( - model_path)) + model_path, self._custom_op_registerers)) if not self._interpreter: raise ValueError('Failed to open {}'.format(model_path)) elif model_content and not model_path: @@ -213,7 +215,7 @@ class Interpreter(object): self._model_content = model_content self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( - model_content)) + model_content, self._custom_op_registerers)) elif not model_path and not model_path: raise ValueError('`model_path` or `model_content` must be specified.') else: @@ -454,3 +456,40 @@ class Interpreter(object): def reset_all_variables(self): return self._interpreter.ResetVariableTensors() + + +class InterpreterWithCustomOps(Interpreter): + """Interpreter interface for TensorFlow Lite Models that accepts custom ops. + + The interface provided by this class is experimenal and therefore not exposed + as part of the public API. + + Wraps the tf.lite.Interpreter class and adds the ability to load custom ops + by providing the names of functions that take a pointer to a BuiltinOpResolver + and add a custom op. + """ + + def __init__(self, + model_path=None, + model_content=None, + experimental_delegates=None, + custom_op_registerers=None): + """Constructor. + + Args: + model_path: Path to TF-Lite Flatbuffer file. + model_content: Content of model. + experimental_delegates: Experimental. Subject to change. List of + [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates) + objects returned by lite.load_delegate(). + custom_op_registerers: List of str, symbol names of functions that take a + pointer to a MutableOpResolver and register a custom op. + + Raises: + ValueError: If the interpreter was unable to create. + """ + self._custom_op_registerers = custom_op_registerers + super(InterpreterWithCustomOps, self).__init__( + model_path=model_path, + model_content=model_content, + experimental_delegates=experimental_delegates) diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 27c4e5756ca..af0540c510a 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -23,10 +23,39 @@ import sys import numpy as np import six +# Force loaded shared object symbols to be globally visible. This is needed so +# that the interpreter_wrapper, in one .so file, can see the test_registerer, +# in a different .so file. Note that this may already be set by default. +# pylint: disable=g-import-not-at-top +if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'): + sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) + from tensorflow.lite.python import interpreter as interpreter_wrapper +from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test +# pylint: enable=g-import-not-at-top + + +class InterpreterCustomOpsTest(test_util.TensorFlowTestCase): + + def testRegisterer(self): + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite'), + custom_op_registerers=['TF_TestRegisterer']) + self.assertTrue(interpreter._safe_to_run()) + self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1) + + def testRegistererFailure(self): + bogus_name = 'CompletelyBogusRegistererName' + with self.assertRaisesRegexp( + ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'): + interpreter_wrapper.InterpreterWithCustomOps( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite'), + custom_op_registerers=[bogus_name]) class InterpreterTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 476f9390e57..6e8ba8e7de1 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -28,9 +28,11 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite:string_util", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:builtin_ops", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -60,6 +62,7 @@ tf_py_wrap_cc( srcs = [ "interpreter_wrapper.i", ], + copts = ["-fexceptions"], deps = [ ":interpreter_wrapper_lib", "//third_party/python_runtime:headers", diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index d0076e6a351..b4da1fd6d36 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,11 +14,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" +// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead. +#if defined(_WIN32) +#include <windows.h> +#else +#include <dlfcn.h> +#endif // defined(_WIN32) + +#include <stdarg.h> + #include <sstream> #include <string> #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" @@ -82,18 +93,60 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { return result; } +bool RegisterCustomOpByName(const char* registerer_name, + tflite::MutableOpResolver* resolver, + std::string* error_msg) { + // Registerer functions take a pointer to a BuiltinOpResolver as an input + // parameter and return void. + // TODO(b/137576229): We should implement this functionality in a more + // principled way. + typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*); + + // Look for the Registerer function by name. + RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>( + // We don't have dlsym on Windows, use GetProcAddress instead. +#if defined(_WIN32) + GetProcAddress(nullptr, registerer_name) +#else + dlsym(RTLD_DEFAULT, registerer_name) +#endif // defined(_WIN32) + ); + + // Fail in an informative way if the function was not found. + if (registerer == nullptr) { + // We don't have dlerror on Windows, use GetLastError instead. + *error_msg = +#if defined(_WIN32) + absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).", + registerer_name, GetLastError()); +#else + absl::StrFormat("Looking up symbol '%s' failed with error '%s'.", + registerer_name, dlerror()); +#endif // defined(_WIN32) + return false; + } + + // Call the registerer with the resolver. + registerer(resolver); + return true; +} + } // namespace InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( std::unique_ptr<tflite::FlatBufferModel> model, std::unique_ptr<PythonErrorReporter> error_reporter, - std::string* error_msg) { + const std::vector<std::string>& registerers, std::string* error_msg) { if (!model) { *error_msg = error_reporter->message(); return nullptr; } auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); + for (const auto registerer : registerers) { + if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg)) + return nullptr; + } auto interpreter = CreateInterpreter(model.get(), *resolver); if (!interpreter) { *error_msg = error_reporter->message(); @@ -417,16 +470,18 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path, std::string* error_msg) { + const char* model_path, const std::vector<std::string>& registerers, + std::string* error_msg) { std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), - error_msg); + registerers, error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data, std::string* error_msg) { + PyObject* data, const std::vector<std::string>& registerers, + std::string* error_msg) { char * buf = nullptr; Py_ssize_t length; std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); @@ -438,7 +493,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( tflite::FlatBufferModel::BuildFromBuffer(buf, length, error_reporter.get()); return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), - error_msg); + registerers, error_msg); } PyObject* InterpreterWrapper::ResetVariableTensors() { diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index da3e5516743..de57f732038 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -46,12 +46,14 @@ class PythonErrorReporter; class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path, - std::string* error_msg); + static InterpreterWrapper* CreateWrapperCPPFromFile( + const char* model_path, const std::vector<std::string>& registerers, + std::string* error_msg); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data, - std::string* error_msg); + static InterpreterWrapper* CreateWrapperCPPFromBuffer( + PyObject* data, const std::vector<std::string>& registerers, + std::string* error_msg); ~InterpreterWrapper(); PyObject* AllocateTensors(); @@ -84,7 +86,7 @@ class InterpreterWrapper { static InterpreterWrapper* CreateInterpreterWrapper( std::unique_ptr<tflite::FlatBufferModel> model, std::unique_ptr<PythonErrorReporter> error_reporter, - std::string* error_msg); + const std::vector<std::string>& registerers, std::string* error_msg); InterpreterWrapper( std::unique_ptr<tflite::FlatBufferModel> model, diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i index 5424c625508..cfa4d0ae87d 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -33,6 +33,25 @@ limitations under the License. $result = PyLong_FromVoidPtr($1) } +// Converts a Python list of str to a std::vector<std::string>, returns true +// if the conversion was successful. +%{ +static bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings) { + // Make sure the list is actually a list. + if (!PyList_Check(list)) return false; + + // Convert the Python list to a vector of strings. + const int list_size = PyList_Size(list); + strings->resize(list_size); + for (int k = 0; k < list_size; k++) { + PyObject *string_py = PyList_GetItem(list, k); + if (!PyString_Check(string_py)) return false; + (*strings)[k] = std::string(PyString_AsString(string_py)); + } + return true; +} +%} +bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings); %include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h" @@ -42,12 +61,19 @@ namespace interpreter_wrapper { // Version of the constructor that handles producing Python exceptions // that propagate strings. - static PyObject* CreateWrapperCPPFromFile(const char* model_path) { + static PyObject* CreateWrapperCPPFromFile( + const char* model_path, + PyObject* registerers_py) { std::string error; + std::vector<std::string> registerers; + if (!PyListToStdVectorString(registerers_py, ®isterers)) { + PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings."); + return nullptr; + } if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = tflite::interpreter_wrapper::InterpreterWrapper ::CreateWrapperCPPFromFile( - model_path, &error)) { + model_path, registerers, &error)) { return SWIG_NewPointerObj( ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); } else { @@ -59,12 +85,18 @@ namespace interpreter_wrapper { // Version of the constructor that handles producing Python exceptions // that propagate strings. static PyObject* CreateWrapperCPPFromBuffer( - PyObject* data) { + PyObject* data , + PyObject* registerers_py) { std::string error; + std::vector<std::string> registerers; + if (!PyListToStdVectorString(registerers_py, ®isterers)) { + PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings."); + return nullptr; + } if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = tflite::interpreter_wrapper::InterpreterWrapper ::CreateWrapperCPPFromBuffer( - data, &error)) { + data, registerers, &error)) { return SWIG_NewPointerObj( ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); } else { diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD index 7bda81358f9..0c12e19451c 100644 --- a/tensorflow/lite/python/testdata/BUILD +++ b/tensorflow/lite/python/testdata/BUILD @@ -1,8 +1,9 @@ load("//tensorflow/lite:build_def.bzl", "tf_to_tflite") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") package( default_visibility = ["//tensorflow:internal"], - licenses = ["notice"], # Apache 2.0 + licenses = ["notice"], # Apache 2.0, ) exports_files(glob(["*.pb"])) @@ -71,3 +72,26 @@ cc_binary( ":test_delegate", ], ) + +cc_library( + name = "test_registerer", + srcs = ["test_registerer.cc"], + hdrs = ["test_registerer.h"], + visibility = ["//tensorflow/lite:__subpackages__"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +tf_py_wrap_cc( + name = "test_registerer_wrapper", + srcs = [ + "test_registerer.i", + ], + deps = [ + ":test_registerer", + "//third_party/python_runtime:headers", + ], +) diff --git a/tensorflow/lite/python/testdata/test_registerer.cc b/tensorflow/lite/python/testdata/test_registerer.cc new file mode 100644 index 00000000000..6adde65a863 --- /dev/null +++ b/tensorflow/lite/python/testdata/test_registerer.cc @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/python/testdata/test_registerer.h" + +namespace tflite { + +namespace { +static int num_test_registerer_calls = 0; +} // namespace + +// Dummy registerer function with the correct signature. Ignores the resolver +// but increments the num_test_registerer_calls counter by one. The TF_ prefix +// is needed to get past the version script in the OSS build. +extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) { + num_test_registerer_calls++; +} + +// Returns the num_test_registerer_calls counter and re-sets it. +int get_num_test_registerer_calls() { + const int result = num_test_registerer_calls; + num_test_registerer_calls = 0; + return result; +} + +} // namespace tflite diff --git a/tensorflow/lite/python/testdata/test_registerer.h b/tensorflow/lite/python/testdata/test_registerer.h new file mode 100644 index 00000000000..8ee7e198358 --- /dev/null +++ b/tensorflow/lite/python/testdata/test_registerer.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_ +#define TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_ + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { + +// Dummy registerer function with the correct signature. Ignores the resolver +// but increments the num_test_registerer_calls counter by one. The TF_ prefix +// is needed to get past the version script in the OSS build. +extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver); + +// Returns the num_test_registerer_calls counter and re-sets it. +int get_num_test_registerer_calls(); + +} // namespace tflite + +#endif // TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_ diff --git a/tensorflow/lite/python/testdata/test_registerer.i b/tensorflow/lite/python/testdata/test_registerer.i new file mode 100644 index 00000000000..1cd41c9164d --- /dev/null +++ b/tensorflow/lite/python/testdata/test_registerer.i @@ -0,0 +1,20 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%{ +#include "tensorflow/lite/python/testdata/test_registerer.h" +%} + +%include "tensorflow/lite/python/testdata/test_registerer.h"