Automated rollback of commit 9e90593b24

PiperOrigin-RevId: 265347836
This commit is contained in:
A. Unique TensorFlower 2019-08-25 13:12:10 -07:00 committed by TensorFlower Gardener
parent 553a3b826a
commit 752556174a
11 changed files with 291 additions and 17 deletions

View File

@ -33,6 +33,7 @@ py_test(
],
deps = [
":interpreter",
"//tensorflow/lite/python/testdata:test_registerer_wrapper",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",

View File

@ -200,10 +200,12 @@ class Interpreter(object):
Raises:
ValueError: If the interpreter was unable to create.
"""
if not hasattr(self, '_custom_op_registerers'):
self._custom_op_registerers = []
if model_path and not model_content:
self._interpreter = (
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
model_path))
model_path, self._custom_op_registerers))
if not self._interpreter:
raise ValueError('Failed to open {}'.format(model_path))
elif model_content and not model_path:
@ -213,7 +215,7 @@ class Interpreter(object):
self._model_content = model_content
self._interpreter = (
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
model_content))
model_content, self._custom_op_registerers))
elif not model_path and not model_path:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
@ -454,3 +456,40 @@ class Interpreter(object):
def reset_all_variables(self):
return self._interpreter.ResetVariableTensors()
class InterpreterWithCustomOps(Interpreter):
"""Interpreter interface for TensorFlow Lite Models that accepts custom ops.
The interface provided by this class is experimenal and therefore not exposed
as part of the public API.
Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
by providing the names of functions that take a pointer to a BuiltinOpResolver
and add a custom op.
"""
def __init__(self,
model_path=None,
model_content=None,
experimental_delegates=None,
custom_op_registerers=None):
"""Constructor.
Args:
model_path: Path to TF-Lite Flatbuffer file.
model_content: Content of model.
experimental_delegates: Experimental. Subject to change. List of
[TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
objects returned by lite.load_delegate().
custom_op_registerers: List of str, symbol names of functions that take a
pointer to a MutableOpResolver and register a custom op.
Raises:
ValueError: If the interpreter was unable to create.
"""
self._custom_op_registerers = custom_op_registerers
super(InterpreterWithCustomOps, self).__init__(
model_path=model_path,
model_content=model_content,
experimental_delegates=experimental_delegates)

View File

@ -23,10 +23,39 @@ import sys
import numpy as np
import six
# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
from tensorflow.lite.python import interpreter as interpreter_wrapper
from tensorflow.lite.python.testdata import test_registerer_wrapper as test_registerer
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
def testRegisterer(self):
interpreter = interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'),
custom_op_registerers=['TF_TestRegisterer'])
self.assertTrue(interpreter._safe_to_run())
self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)
def testRegistererFailure(self):
bogus_name = 'CompletelyBogusRegistererName'
with self.assertRaisesRegexp(
ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
interpreter_wrapper.InterpreterWithCustomOps(
model_path=resource_loader.get_path_to_datafile(
'testdata/permute_float.tflite'),
custom_op_registerers=[bogus_name])
class InterpreterTest(test_util.TensorFlowTestCase):

View File

@ -28,9 +28,11 @@ cc_library(
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
],
)
@ -60,6 +62,7 @@ tf_py_wrap_cc(
srcs = [
"interpreter_wrapper.i",
],
copts = ["-fexceptions"],
deps = [
":interpreter_wrapper_lib",
"//third_party/python_runtime:headers",

View File

@ -14,11 +14,22 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
// Windows does not have dlfcn.h/dlsym, use GetProcAddress() instead.
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif // defined(_WIN32)
#include <stdarg.h>
#include <sstream>
#include <string>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
@ -82,18 +93,60 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
return result;
}
bool RegisterCustomOpByName(const char* registerer_name,
tflite::MutableOpResolver* resolver,
std::string* error_msg) {
// Registerer functions take a pointer to a BuiltinOpResolver as an input
// parameter and return void.
// TODO(b/137576229): We should implement this functionality in a more
// principled way.
typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
// Look for the Registerer function by name.
RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
// We don't have dlsym on Windows, use GetProcAddress instead.
#if defined(_WIN32)
GetProcAddress(nullptr, registerer_name)
#else
dlsym(RTLD_DEFAULT, registerer_name)
#endif // defined(_WIN32)
);
// Fail in an informative way if the function was not found.
if (registerer == nullptr) {
// We don't have dlerror on Windows, use GetLastError instead.
*error_msg =
#if defined(_WIN32)
absl::StrFormat("Looking up symbol '%s' failed with error (0x%x).",
registerer_name, GetLastError());
#else
absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
registerer_name, dlerror());
#endif // defined(_WIN32)
return false;
}
// Call the registerer with the resolver.
registerer(resolver);
return true;
}
} // namespace
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
std::unique_ptr<tflite::FlatBufferModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter,
std::string* error_msg) {
const std::vector<std::string>& registerers, std::string* error_msg) {
if (!model) {
*error_msg = error_reporter->message();
return nullptr;
}
auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
for (const auto registerer : registerers) {
if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
return nullptr;
}
auto interpreter = CreateInterpreter(model.get(), *resolver);
if (!interpreter) {
*error_msg = error_reporter->message();
@ -417,16 +470,18 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path, std::string* error_msg) {
const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg) {
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
error_msg);
registerers, error_msg);
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, std::string* error_msg) {
PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg) {
char * buf = nullptr;
Py_ssize_t length;
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
@ -438,7 +493,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
tflite::FlatBufferModel::BuildFromBuffer(buf, length,
error_reporter.get());
return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
error_msg);
registerers, error_msg);
}
PyObject* InterpreterWrapper::ResetVariableTensors() {

View File

@ -46,12 +46,14 @@ class PythonErrorReporter;
class InterpreterWrapper {
public:
// SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path,
std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromFile(
const char* model_path, const std::vector<std::string>& registerers,
std::string* error_msg);
// SWIG caller takes ownership of pointer.
static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data,
std::string* error_msg);
static InterpreterWrapper* CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg);
~InterpreterWrapper();
PyObject* AllocateTensors();
@ -84,7 +86,7 @@ class InterpreterWrapper {
static InterpreterWrapper* CreateInterpreterWrapper(
std::unique_ptr<tflite::FlatBufferModel> model,
std::unique_ptr<PythonErrorReporter> error_reporter,
std::string* error_msg);
const std::vector<std::string>& registerers, std::string* error_msg);
InterpreterWrapper(
std::unique_ptr<tflite::FlatBufferModel> model,

View File

@ -33,6 +33,25 @@ limitations under the License.
$result = PyLong_FromVoidPtr($1)
}
// Converts a Python list of str to a std::vector<std::string>, returns true
// if the conversion was successful.
%{
static bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings) {
// Make sure the list is actually a list.
if (!PyList_Check(list)) return false;
// Convert the Python list to a vector of strings.
const int list_size = PyList_Size(list);
strings->resize(list_size);
for (int k = 0; k < list_size; k++) {
PyObject *string_py = PyList_GetItem(list, k);
if (!PyString_Check(string_py)) return false;
(*strings)[k] = std::string(PyString_AsString(string_py));
}
return true;
}
%}
bool PyListToStdVectorString(PyObject *list, std::vector<std::string> *strings);
%include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
@ -42,12 +61,19 @@ namespace interpreter_wrapper {
// Version of the constructor that handles producing Python exceptions
// that propagate strings.
static PyObject* CreateWrapperCPPFromFile(const char* model_path) {
static PyObject* CreateWrapperCPPFromFile(
const char* model_path,
PyObject* registerers_py) {
std::string error;
std::vector<std::string> registerers;
if (!PyListToStdVectorString(registerers_py, &registerers)) {
PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings.");
return nullptr;
}
if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
tflite::interpreter_wrapper::InterpreterWrapper
::CreateWrapperCPPFromFile(
model_path, &error)) {
model_path, registerers, &error)) {
return SWIG_NewPointerObj(
ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
} else {
@ -59,12 +85,18 @@ namespace interpreter_wrapper {
// Version of the constructor that handles producing Python exceptions
// that propagate strings.
static PyObject* CreateWrapperCPPFromBuffer(
PyObject* data) {
PyObject* data ,
PyObject* registerers_py) {
std::string error;
std::vector<std::string> registerers;
if (!PyListToStdVectorString(registerers_py, &registerers)) {
PyErr_SetString(PyExc_ValueError, "Second argument is expected to be a list of strings.");
return nullptr;
}
if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
tflite::interpreter_wrapper::InterpreterWrapper
::CreateWrapperCPPFromBuffer(
data, &error)) {
data, registerers, &error)) {
return SWIG_NewPointerObj(
ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
} else {

View File

@ -1,8 +1,9 @@
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
licenses = ["notice"], # Apache 2.0,
)
exports_files(glob(["*.pb"]))
@ -71,3 +72,26 @@ cc_binary(
":test_delegate",
],
)
cc_library(
name = "test_registerer",
srcs = ["test_registerer.cc"],
hdrs = ["test_registerer.h"],
visibility = ["//tensorflow/lite:__subpackages__"],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)
tf_py_wrap_cc(
name = "test_registerer_wrapper",
srcs = [
"test_registerer.i",
],
deps = [
":test_registerer",
"//third_party/python_runtime:headers",
],
)

View File

@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/testdata/test_registerer.h"
namespace tflite {
namespace {
static int num_test_registerer_calls = 0;
} // namespace
// Dummy registerer function with the correct signature. Ignores the resolver
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
// is needed to get past the version script in the OSS build.
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver) {
num_test_registerer_calls++;
}
// Returns the num_test_registerer_calls counter and re-sets it.
int get_num_test_registerer_calls() {
const int result = num_test_registerer_calls;
num_test_registerer_calls = 0;
return result;
}
} // namespace tflite

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_
#define TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_
#include "tensorflow/lite/mutable_op_resolver.h"
namespace tflite {
// Dummy registerer function with the correct signature. Ignores the resolver
// but increments the num_test_registerer_calls counter by one. The TF_ prefix
// is needed to get past the version script in the OSS build.
extern "C" void TF_TestRegisterer(tflite::MutableOpResolver *resolver);
// Returns the num_test_registerer_calls counter and re-sets it.
int get_num_test_registerer_calls();
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_TESTDATA_TEST_REGISTERER_H_

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
%{
#include "tensorflow/lite/python/testdata/test_registerer.h"
%}
%include "tensorflow/lite/python/testdata/test_registerer.h"