Python API support for string tensors.

PiperOrigin-RevId: 232726064
This commit is contained in:
A. Unique TensorFlower 2019-02-06 12:43:48 -08:00 committed by TensorFlower Gardener
parent ba9f4c017c
commit ca62689feb
13 changed files with 410 additions and 58 deletions

View File

@ -13,7 +13,6 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)

View File

@ -91,6 +91,41 @@ class InterpreterTest(test_util.TensorFlowTestCase):
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
# Exercises end-to-end string tensor support in the Python interpreter:
# loads a gather model whose params tensor holds strings, verifies the
# reported tensor details, feeds string and index inputs, invokes, and
# checks the gathered output bytes.
def testString(self):
interpreter = interpreter_wrapper.Interpreter(
model_path=resource_loader.get_path_to_datafile(
'testdata/gather_string.tflite'))
interpreter.allocate_tensors()
# Two inputs: a string 'input' tensor of shape [10] and an int64
# 'indices' tensor of shape [3]; neither is quantized.
input_details = interpreter.get_input_details()
self.assertEqual(2, len(input_details))
self.assertEqual('input', input_details[0]['name'])
self.assertEqual(np.string_, input_details[0]['dtype'])
self.assertTrue(([10] == input_details[0]['shape']).all())
self.assertEqual((0.0, 0), input_details[0]['quantization'])
self.assertEqual('indices', input_details[1]['name'])
self.assertEqual(np.int64, input_details[1]['dtype'])
self.assertTrue(([3] == input_details[1]['shape']).all())
self.assertEqual((0.0, 0), input_details[1]['quantization'])
# One string output of shape [3].
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertEqual(np.string_, output_details[0]['dtype'])
self.assertTrue(([3] == output_details[0]['shape']).all())
self.assertEqual((0.0, 0), output_details[0]['quantization'])
# Gather indices [1, 2, 3] out of ['a'..'j']; results come back as bytes.
test_input = np.array([1, 2, 3], dtype=np.int64)
interpreter.set_tensor(input_details[1]['index'], test_input)
test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
expected_output = np.array([b'b', b'c', b'd'])
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):

View File

@ -6,14 +6,26 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
cc_library(
name = "numpy",
srcs = ["numpy.cc"],
hdrs = ["numpy.h"],
deps = [
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
)
cc_library(
name = "interpreter_wrapper_lib",
srcs = ["interpreter_wrapper.cc"],
hdrs = ["interpreter_wrapper.h"],
deps = [
":numpy",
":python_error_reporter",
":python_utils",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
@ -36,7 +48,9 @@ cc_library(
srcs = ["python_utils.cc"],
hdrs = ["python_utils.h"],
deps = [
":numpy",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],

View File

@ -21,16 +21,10 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#include "tensorflow/lite/string_util.h"
#if PY_MAJOR_VERSION >= 3
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
@ -64,12 +58,6 @@ namespace interpreter_wrapper {
namespace {
// Calls PyArray's initialization to initialize all the API pointers. Note that
// this usage implies only this translation unit can use the pointers. See
// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend
// this further.
void ImportNumpy() { import_array1(); }
std::unique_ptr<tflite::Interpreter> CreateInterpreter(
const tflite::FlatBufferModel* model,
const tflite::ops::builtin::BuiltinOpResolver& resolver) {
@ -77,7 +65,7 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
return nullptr;
}
ImportNumpy();
::tflite::python::ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
@ -267,7 +255,7 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
TfLiteTensor* tensor = interpreter_->tensor(i);
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
PyErr_Format(PyExc_ValueError,
@ -279,26 +267,41 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
}
if (PyArray_NDIM(array) != tensor->dims->size) {
PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
PyErr_Format(PyExc_ValueError,
"Cannot set tensor: Dimension mismatch."
" Got %d"
" but expected %d for input %d.",
PyArray_NDIM(array), tensor->dims->size, i);
return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
PyErr_SetString(PyExc_ValueError,
"Cannot set tensor: Dimension mismatch");
PyErr_Format(PyExc_ValueError,
"Cannot set tensor: Dimension mismatch."
" Got %ld"
" but expected %d for dimension %d of input %d.",
PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
return nullptr;
}
}
size_t size = PyArray_NBYTES(array);
if (size != tensor->bytes) {
PyErr_Format(PyExc_ValueError,
"numpy array had %zu bytes but expected %zu bytes.", size,
tensor->bytes);
return nullptr;
if (tensor->type != kTfLiteString) {
size_t size = PyArray_NBYTES(array);
if (size != tensor->bytes) {
PyErr_Format(PyExc_ValueError,
"numpy array had %zu bytes but expected %zu bytes.", size,
tensor->bytes);
return nullptr;
}
memcpy(tensor->data.raw, PyArray_DATA(array), size);
} else {
DynamicBuffer dynamic_buffer;
if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
return nullptr;
}
dynamic_buffer.WriteToTensor(tensor, nullptr);
}
memcpy(tensor->data.raw, PyArray_DATA(array), size);
Py_RETURN_NONE;
}
@ -345,19 +348,51 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
// Make a buffer copy but we must tell Numpy It owns that data or else
// it will leak.
void* data = malloc(tensor->bytes);
if (!data) {
PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
return nullptr;
if (tensor->type != kTfLiteString) {
// Make a buffer copy but we must tell Numpy It owns that data or else
// it will leak.
void* data = malloc(tensor->bytes);
if (!data) {
PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
return nullptr;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
NPY_ARRAY_OWNDATA);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
} else {
// Create a C-order array so the data is contiguous in memory.
const int32_t kCOrder = 0;
PyObject* py_object =
PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
if (py_object == nullptr) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
return nullptr;
}
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
auto num_strings = GetStringCount(tensor->data.raw);
for (int j = 0; j < num_strings; ++j) {
auto ref = GetString(tensor->data.raw, j);
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
if (bytes == nullptr) {
Py_DECREF(py_object);
PyErr_Format(PyExc_ValueError,
"Could not create PyBytes from string %d of input %d.", j,
i);
return nullptr;
}
// PyArray_EMPTY produces an array full of Py_None, which we must decref.
Py_DECREF(data[j]);
data[j] = bytes;
}
return py_object;
}
memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
NPY_ARRAY_OWNDATA);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {

View File

@ -0,0 +1,25 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define TFLITE_IMPORT_NUMPY // See numpy.h for explanation.
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
namespace tflite {
namespace python {
// Populates numpy's PyArray_API function-pointer table for this binary by
// running import_array1(). Per the numpy C-API contract, this must happen
// before any PyArray_* call is made.
void ImportNumpy() { import_array1(); }
} // namespace python
} // namespace tflite

View File

@ -0,0 +1,62 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
#ifdef PyArray_Type
#error "Numpy cannot be included before numpy.h."
#endif
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
// To handle PyArray_* calls, numpy defines a static lookup table called
// PyArray_API, or PY_ARRAY_UNIQUE_SYMBOL, if defined. This causes the
// PyArray_* pointers to be different for different translation units, unless
// we take care to selectively define NO_IMPORT_ARRAY.
//
// Virtually every usage will define NO_IMPORT_ARRAY, and will have access to
// the lookup table via:
// extern void **PyArray_API;
// In numpy.cc we will define TFLITE_IMPORT_NUMPY, effectively disabling that
// and instead using:
// void **PyArray_API;
// which is initialized when ImportNumpy() is called.
//
// If we don't define PY_ARRAY_UNIQUE_SYMBOL then PyArray_API is a static
// variable, which causes strange crashes when the pointers are used across
// translation unit boundaries.
//
// For more info see https://sourceforge.net/p/numpy/mailman/message/5700519
// See also tensorflow/python/lib/core/numpy.h for a similar approach.
#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api
#ifndef TFLITE_IMPORT_NUMPY
#define NO_IMPORT_ARRAY
#endif
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
namespace tflite {
namespace python {
// Initializes the numpy C-API pointer table (defined in numpy.cc, which
// calls import_array1()). Call before using any PyArray_* function.
void ImportNumpy();
} // namespace python
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_

View File

@ -15,9 +15,19 @@ limitations under the License.
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include <memory>
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
namespace tflite {
namespace python_utils {
struct PyObjectDereferencer {
void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
};
using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
switch (tf_lite_type) {
case kTfLiteFloat32:
@ -33,7 +43,7 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
case kTfLiteInt64:
return NPY_INT64;
case kTfLiteString:
return NPY_OBJECT;
return NPY_STRING;
case kTfLiteBool:
return NPY_BOOL;
case kTfLiteComplex64:
@ -73,5 +83,82 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteNoType;
}
#if PY_VERSION_HEX >= 0x03030000
// Encodes the Python unicode object `value` as UTF-8 and appends the bytes
// to `dynamic_buffer`. Returns false with a Python ValueError set on failure.
bool FillStringBufferFromPyUnicode(PyObject* value,
                                   DynamicBuffer* dynamic_buffer) {
  Py_ssize_t len = -1;
  // The returned buffer is cached inside `value` and must not be freed.
  // Since CPython 3.7, PyUnicode_AsUTF8AndSize returns `const char*`, so the
  // result must be stored const-qualified to compile on modern Pythons.
  const char* buf = PyUnicode_AsUTF8AndSize(value, &len);
  if (buf == NULL) {
    PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
    return false;
  }
  dynamic_buffer->AddString(buf, len);
  return true;
}
#else
// Pre-3.3 fallback: round-trip through a temporary UTF-8 bytes object.
bool FillStringBufferFromPyUnicode(PyObject* value,
                                   DynamicBuffer* dynamic_buffer) {
  UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
  if (!utemp) {
    PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
    return false;
  }
  char* buf = nullptr;
  Py_ssize_t len = -1;
  if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
    PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
    return false;
  }
  dynamic_buffer->AddString(buf, len);
  return true;
}
#endif
// Appends the contents of a Python str or bytes object to `dynamic_buffer`.
// Returns false (with a Python error set) when `value` is neither.
bool FillStringBufferFromPyString(PyObject* value,
                                  DynamicBuffer* dynamic_buffer) {
  // Unicode objects are routed through the UTF-8 conversion helper.
  if (PyUnicode_Check(value)) {
    return FillStringBufferFromPyUnicode(value, dynamic_buffer);
  }
  // Otherwise interpret the object as a raw bytes buffer.
  char* data = nullptr;
  Py_ssize_t num_bytes = -1;
  if (PyBytes_AsStringAndSize(value, &data, &num_bytes) != 0) {
    PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
    return false;
  }
  dynamic_buffer->AddString(data, num_bytes);
  return true;
}
// Appends every element of the numpy array `value` to `dynamic_buffer` as a
// string. Supports object, bytes (NPY_STRING) and unicode arrays; any other
// dtype sets a Python ValueError and returns false.
bool FillStringBufferWithPyArray(PyObject* value,
DynamicBuffer* dynamic_buffer) {
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
switch (PyArray_TYPE(array)) {
case NPY_OBJECT:
case NPY_STRING:
case NPY_UNICODE: {
// Walk the array in element order; PyArray_GETITEM boxes each element
// as a new Python object, released by UniquePyObjectRef.
UniquePyObjectRef iter(PyArray_IterNew(value));
while (PyArray_ITER_NOTDONE(iter.get())) {
UniquePyObjectRef item(PyArray_GETITEM(
array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
// Delegates str/bytes handling; a false return leaves the Python
// error from the helper in place.
if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
return false;
}
PyArray_ITER_NEXT(iter.get());
}
return true;
}
default:
break;
}
PyErr_Format(PyExc_ValueError,
"Cannot use numpy array of type %d for string tensor.",
PyArray_TYPE(array));
return false;
}
} // namespace python_utils
} // namespace tflite

View File

@ -17,14 +17,8 @@ limitations under the License.
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
#include "tensorflow/lite/context.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace python_utils {
@ -33,6 +27,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
bool FillStringBufferWithPyArray(PyObject* value,
DynamicBuffer* dynamic_buffer);
} // namespace python_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_

View File

@ -131,13 +131,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.object_, input_details[0]['dtype'])
self.assertEqual(np.string_, input_details[0]['dtype'])
self.assertTrue(([4] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('Reshape', output_details[0]['name'])
self.assertEqual(np.object_, output_details[0]['dtype'])
self.assertEqual(np.string_, output_details[0]['dtype'])
self.assertTrue(([2, 2] == output_details[0]['shape']).all())
# TODO(b/122659643): Test setting/getting string data via the python
# interpreter API after support has been added.

View File

@ -13,6 +13,7 @@ cc_library(
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/python/interpreter_wrapper:numpy",
"//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
"//tensorflow/lite/tools/optimize:calibration_reader",

View File

@ -22,20 +22,13 @@ limitations under the License.
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/tools/optimize/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibrator.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"
// Disallow Numpy 1.7 deprecated symbols.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
#if PY_MAJOR_VERSION >= 3
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
#define CPP_TO_PYSTRING PyBytes_FromStringAndSize

View File

@ -32,9 +32,20 @@ tf_to_tflite(
],
)
tf_to_tflite(
name = "gather_string",
src = "gather.pbtxt",
out = "gather_string.tflite",
options = [
"--input_arrays=input,indices",
"--output_arrays=output",
],
)
filegroup(
name = "interpreter_test_data",
srcs = [
":gather_string",
":permute_float",
":permute_uint8",
],

View File

@ -0,0 +1,93 @@
node {
name: "input"
op: "Placeholder"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 10
}
}
}
}
}
node {
name: "indices"
op: "Placeholder"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_INT64
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
}
}
}
}
node {
name: "axis"
op: "Const"
device: "/device:CPU:0"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "output"
op: "GatherV2"
input: "input"
input: "indices"
input: "axis"
device: "/device:CPU:0"
attr {
key: "Taxis"
value {
type: DT_INT32
}
}
attr {
key: "Tindices"
value {
type: DT_INT64
}
}
attr {
key: "Tparams"
value {
type: DT_STRING
}
}
}
versions {
producer: 27
}