Python API support for string tensors.
PiperOrigin-RevId: 232726064
This commit is contained in:
parent
ba9f4c017c
commit
ca62689feb
tensorflow/lite/python
@ -13,7 +13,6 @@ py_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -91,6 +91,41 @@ class InterpreterTest(test_util.TensorFlowTestCase):
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
def testString(self):
|
||||
interpreter = interpreter_wrapper.Interpreter(
|
||||
model_path=resource_loader.get_path_to_datafile(
|
||||
'testdata/gather_string.tflite'))
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(2, len(input_details))
|
||||
self.assertEqual('input', input_details[0]['name'])
|
||||
self.assertEqual(np.string_, input_details[0]['dtype'])
|
||||
self.assertTrue(([10] == input_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), input_details[0]['quantization'])
|
||||
self.assertEqual('indices', input_details[1]['name'])
|
||||
self.assertEqual(np.int64, input_details[1]['dtype'])
|
||||
self.assertTrue(([3] == input_details[1]['shape']).all())
|
||||
self.assertEqual((0.0, 0), input_details[1]['quantization'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual('output', output_details[0]['name'])
|
||||
self.assertEqual(np.string_, output_details[0]['dtype'])
|
||||
self.assertTrue(([3] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0.0, 0), output_details[0]['quantization'])
|
||||
|
||||
test_input = np.array([1, 2, 3], dtype=np.int64)
|
||||
interpreter.set_tensor(input_details[1]['index'], test_input)
|
||||
|
||||
test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
|
||||
expected_output = np.array([b'b', b'c', b'd'])
|
||||
interpreter.set_tensor(input_details[0]['index'], test_input)
|
||||
interpreter.invoke()
|
||||
|
||||
output_data = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertTrue((expected_output == output_data).all())
|
||||
|
||||
|
||||
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -6,14 +6,26 @@ licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
|
||||
cc_library(
|
||||
name = "numpy",
|
||||
srcs = ["numpy.cc"],
|
||||
hdrs = ["numpy.h"],
|
||||
deps = [
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "interpreter_wrapper_lib",
|
||||
srcs = ["interpreter_wrapper.cc"],
|
||||
hdrs = ["interpreter_wrapper.h"],
|
||||
deps = [
|
||||
":numpy",
|
||||
":python_error_reporter",
|
||||
":python_utils",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
@ -36,7 +48,9 @@ cc_library(
|
||||
srcs = ["python_utils.cc"],
|
||||
hdrs = ["python_utils.h"],
|
||||
deps = [
|
||||
":numpy",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//third_party/python_runtime:headers",
|
||||
],
|
||||
|
@ -21,16 +21,10 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
|
||||
@ -64,12 +58,6 @@ namespace interpreter_wrapper {
|
||||
|
||||
namespace {
|
||||
|
||||
// Calls PyArray's initialization to initialize all the API pointers. Note that
|
||||
// this usage implies only this translation unit can use the pointers. See
|
||||
// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend
|
||||
// this further.
|
||||
void ImportNumpy() { import_array1(); }
|
||||
|
||||
std::unique_ptr<tflite::Interpreter> CreateInterpreter(
|
||||
const tflite::FlatBufferModel* model,
|
||||
const tflite::ops::builtin::BuiltinOpResolver& resolver) {
|
||||
@ -77,7 +65,7 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ImportNumpy();
|
||||
::tflite::python::ImportNumpy();
|
||||
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
|
||||
@ -267,7 +255,7 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
|
||||
}
|
||||
|
||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
|
||||
const TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||
TfLiteTensor* tensor = interpreter_->tensor(i);
|
||||
|
||||
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
@ -279,26 +267,41 @@ PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
|
||||
}
|
||||
|
||||
if (PyArray_NDIM(array) != tensor->dims->size) {
|
||||
PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Cannot set tensor: Dimension mismatch."
|
||||
" Got %d"
|
||||
" but expected %d for input %d.",
|
||||
PyArray_NDIM(array), tensor->dims->size, i);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (int j = 0; j < PyArray_NDIM(array); j++) {
|
||||
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"Cannot set tensor: Dimension mismatch");
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Cannot set tensor: Dimension mismatch."
|
||||
" Got %ld"
|
||||
" but expected %d for dimension %d of input %d.",
|
||||
PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = PyArray_NBYTES(array);
|
||||
if (size != tensor->bytes) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"numpy array had %zu bytes but expected %zu bytes.", size,
|
||||
tensor->bytes);
|
||||
return nullptr;
|
||||
if (tensor->type != kTfLiteString) {
|
||||
size_t size = PyArray_NBYTES(array);
|
||||
if (size != tensor->bytes) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"numpy array had %zu bytes but expected %zu bytes.", size,
|
||||
tensor->bytes);
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(tensor->data.raw, PyArray_DATA(array), size);
|
||||
} else {
|
||||
DynamicBuffer dynamic_buffer;
|
||||
if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
|
||||
return nullptr;
|
||||
}
|
||||
dynamic_buffer.WriteToTensor(tensor, nullptr);
|
||||
}
|
||||
memcpy(tensor->data.raw, PyArray_DATA(array), size);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
@ -345,19 +348,51 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
|
||||
|
||||
std::vector<npy_intp> dims(tensor->dims->data,
|
||||
tensor->dims->data + tensor->dims->size);
|
||||
// Make a buffer copy but we must tell Numpy It owns that data or else
|
||||
// it will leak.
|
||||
void* data = malloc(tensor->bytes);
|
||||
if (!data) {
|
||||
PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
|
||||
return nullptr;
|
||||
if (tensor->type != kTfLiteString) {
|
||||
// Make a buffer copy but we must tell Numpy It owns that data or else
|
||||
// it will leak.
|
||||
void* data = malloc(tensor->bytes);
|
||||
if (!data) {
|
||||
PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(data, tensor->data.raw, tensor->bytes);
|
||||
PyObject* np_array =
|
||||
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
|
||||
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
|
||||
NPY_ARRAY_OWNDATA);
|
||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||
} else {
|
||||
// Create a C-order array so the data is contiguous in memory.
|
||||
const int32_t kCOrder = 0;
|
||||
PyObject* py_object =
|
||||
PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
|
||||
|
||||
if (py_object == nullptr) {
|
||||
PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
|
||||
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
|
||||
auto num_strings = GetStringCount(tensor->data.raw);
|
||||
for (int j = 0; j < num_strings; ++j) {
|
||||
auto ref = GetString(tensor->data.raw, j);
|
||||
|
||||
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
|
||||
if (bytes == nullptr) {
|
||||
Py_DECREF(py_object);
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Could not create PyBytes from string %d of input %d.", j,
|
||||
i);
|
||||
return nullptr;
|
||||
}
|
||||
// PyArray_EMPTY produces an array full of Py_None, which we must decref.
|
||||
Py_DECREF(data[j]);
|
||||
data[j] = bytes;
|
||||
}
|
||||
return py_object;
|
||||
}
|
||||
memcpy(data, tensor->data.raw, tensor->bytes);
|
||||
PyObject* np_array =
|
||||
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
|
||||
PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
|
||||
NPY_ARRAY_OWNDATA);
|
||||
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
|
||||
}
|
||||
|
||||
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
|
||||
|
25
tensorflow/lite/python/interpreter_wrapper/numpy.cc
Normal file
25
tensorflow/lite/python/interpreter_wrapper/numpy.cc
Normal file
@ -0,0 +1,25 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define TFLITE_IMPORT_NUMPY // See numpy.h for explanation.
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python {
|
||||
|
||||
void ImportNumpy() { import_array1(); }
|
||||
|
||||
} // namespace python
|
||||
} // namespace tflite
|
62
tensorflow/lite/python/interpreter_wrapper/numpy.h
Normal file
62
tensorflow/lite/python/interpreter_wrapper/numpy.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
|
||||
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
|
||||
|
||||
#ifdef PyArray_Type
|
||||
#error "Numpy cannot be included before numpy.h."
|
||||
#endif
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
// To handle PyArray_* calles, numpy defines a static lookup table called
|
||||
// PyArray_API, or PY_ARRAY_UNIQUE_SYMBOL, if defined. This causes the
|
||||
// PyArray_* pointers to be different for different translation units, unless
|
||||
// we take care of selectivel defined NO_IMPORT_ARRAY.
|
||||
//
|
||||
// Virtually every usage will define NO_IMPORT_ARRAY, and will have access to
|
||||
// the lookup table via:
|
||||
// extern void **PyArray_API;
|
||||
// In numpy.cc we will define TFLITE_IMPORT_NUMPY, effectively disabling that
|
||||
// and instead using:
|
||||
// void **PyArray_API;
|
||||
// which is initialized when ImportNumpy() is called.
|
||||
//
|
||||
// If we don't define PY_ARRAY_UNIQUE_SYMBOL then PyArray_API is a static
|
||||
// variable, which causes strange crashes when the pointers are used across
|
||||
// translation unit boundaries.
|
||||
//
|
||||
// For mone info see https://sourceforge.net/p/numpy/mailman/message/5700519
|
||||
// See also tensorflow/python/lib/core/numpy.h for a similar approach.
|
||||
#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api
|
||||
#ifndef TFLITE_IMPORT_NUMPY
|
||||
#define NO_IMPORT_ARRAY
|
||||
#endif
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python {
|
||||
|
||||
void ImportNumpy();
|
||||
|
||||
} // namespace python
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
|
@ -15,9 +15,19 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python_utils {
|
||||
|
||||
struct PyObjectDereferencer {
|
||||
void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
|
||||
};
|
||||
|
||||
using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
|
||||
|
||||
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
switch (tf_lite_type) {
|
||||
case kTfLiteFloat32:
|
||||
@ -33,7 +43,7 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
|
||||
case kTfLiteInt64:
|
||||
return NPY_INT64;
|
||||
case kTfLiteString:
|
||||
return NPY_OBJECT;
|
||||
return NPY_STRING;
|
||||
case kTfLiteBool:
|
||||
return NPY_BOOL;
|
||||
case kTfLiteComplex64:
|
||||
@ -73,5 +83,82 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
|
||||
return kTfLiteNoType;
|
||||
}
|
||||
|
||||
#if PY_VERSION_HEX >= 0x03030000
|
||||
bool FillStringBufferFromPyUnicode(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer) {
|
||||
Py_ssize_t len = -1;
|
||||
char* buf = PyUnicode_AsUTF8AndSize(value, &len);
|
||||
if (buf == NULL) {
|
||||
PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
|
||||
return false;
|
||||
}
|
||||
dynamic_buffer->AddString(buf, len);
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
bool FillStringBufferFromPyUnicode(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer) {
|
||||
UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
|
||||
if (!utemp) {
|
||||
PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
|
||||
return false;
|
||||
}
|
||||
char* buf = nullptr;
|
||||
Py_ssize_t len = -1;
|
||||
if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
|
||||
PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
|
||||
return false;
|
||||
}
|
||||
dynamic_buffer->AddString(buf, len);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool FillStringBufferFromPyString(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer) {
|
||||
if (PyUnicode_Check(value)) {
|
||||
return FillStringBufferFromPyUnicode(value, dynamic_buffer);
|
||||
}
|
||||
|
||||
char* buf = nullptr;
|
||||
Py_ssize_t len = -1;
|
||||
if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
|
||||
PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
|
||||
return false;
|
||||
}
|
||||
dynamic_buffer->AddString(buf, len);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FillStringBufferWithPyArray(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer) {
|
||||
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
|
||||
switch (PyArray_TYPE(array)) {
|
||||
case NPY_OBJECT:
|
||||
case NPY_STRING:
|
||||
case NPY_UNICODE: {
|
||||
UniquePyObjectRef iter(PyArray_IterNew(value));
|
||||
while (PyArray_ITER_NOTDONE(iter.get())) {
|
||||
UniquePyObjectRef item(PyArray_GETITEM(
|
||||
array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
|
||||
|
||||
if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
PyArray_ITER_NEXT(iter.get());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Cannot use numpy array of type %d for string tensor.",
|
||||
PyArray_TYPE(array));
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
||||
|
@ -17,14 +17,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||
|
||||
#include "tensorflow/lite/context.h"
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace python_utils {
|
||||
@ -33,6 +27,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type);
|
||||
|
||||
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
|
||||
|
||||
bool FillStringBufferWithPyArray(PyObject* value,
|
||||
DynamicBuffer* dynamic_buffer);
|
||||
|
||||
} // namespace python_utils
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
|
||||
|
@ -131,13 +131,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertEqual(1, len(input_details))
|
||||
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||
self.assertEqual(np.object_, input_details[0]['dtype'])
|
||||
self.assertEqual(np.string_, input_details[0]['dtype'])
|
||||
self.assertTrue(([4] == input_details[0]['shape']).all())
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(1, len(output_details))
|
||||
self.assertEqual('Reshape', output_details[0]['name'])
|
||||
self.assertEqual(np.object_, output_details[0]['dtype'])
|
||||
self.assertEqual(np.string_, output_details[0]['dtype'])
|
||||
self.assertTrue(([2, 2] == output_details[0]['shape']).all())
|
||||
# TODO(b/122659643): Test setting/getting string data via the python
|
||||
# interpreter API after support has been added.
|
||||
|
@ -13,6 +13,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:numpy",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
|
||||
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
|
||||
"//tensorflow/lite/tools/optimize:calibration_reader",
|
||||
|
@ -22,20 +22,13 @@ limitations under the License.
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
|
||||
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
|
||||
#include "tensorflow/lite/tools/optimize/calibration_reader.h"
|
||||
#include "tensorflow/lite/tools/optimize/calibrator.h"
|
||||
#include "tensorflow/lite/tools/optimize/quantize_model.h"
|
||||
|
||||
// Disallow Numpy 1.7 deprecated symbols.
|
||||
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
|
||||
#define CPP_TO_PYSTRING PyBytes_FromStringAndSize
|
||||
|
11
tensorflow/lite/python/testdata/BUILD
vendored
11
tensorflow/lite/python/testdata/BUILD
vendored
@ -32,9 +32,20 @@ tf_to_tflite(
|
||||
],
|
||||
)
|
||||
|
||||
tf_to_tflite(
|
||||
name = "gather_string",
|
||||
src = "gather.pbtxt",
|
||||
out = "gather_string.tflite",
|
||||
options = [
|
||||
"--input_arrays=input,indices",
|
||||
"--output_arrays=output",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "interpreter_test_data",
|
||||
srcs = [
|
||||
":gather_string",
|
||||
":permute_float",
|
||||
":permute_uint8",
|
||||
],
|
||||
|
93
tensorflow/lite/python/testdata/gather.pbtxt
vendored
Normal file
93
tensorflow/lite/python/testdata/gather.pbtxt
vendored
Normal file
@ -0,0 +1,93 @@
|
||||
node {
|
||||
name: "input"
|
||||
op: "Placeholder"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "indices"
|
||||
op: "Placeholder"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "axis"
|
||||
op: "Const"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "output"
|
||||
op: "GatherV2"
|
||||
input: "input"
|
||||
input: "indices"
|
||||
input: "axis"
|
||||
device: "/device:CPU:0"
|
||||
attr {
|
||||
key: "Taxis"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tindices"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tparams"
|
||||
value {
|
||||
type: DT_STRING
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 27
|
||||
}
|
Loading…
Reference in New Issue
Block a user