Removed PyFunc-specific conversion functions

Note that unlike their Convert* counterparts, NdarrayToTensor and TensorToNdarray

* Convert to and from NumPy arrays through TF_Tensor, which means that
  py_func with DT_STRING/DT_OBJECT inputs/outputs will become slower
  due to *two* copies of the data.
* Do not preserve zero-padding when converting a np.str_ array to
  Tensor. NumPy zero-pads np.str_ arrays and strips all zero bytes on
  element access. The correct NumPy dtype for potentially zero-terminated
  data is np.object_ (see the sketch below).
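
A minimal Python sketch of the second bullet (illustration only, not part of
this change; the example strings are made up):

  import numpy as np

  # Illustration: np.str_ arrays use fixed-width, zero-padded storage, and
  # NumPy strips *all* trailing NUL bytes (padding or not) on element access.
  a = np.array(["this\0", "is\0\0"], dtype=np.str_)
  print(len(a[0]))         # 4 -- the explicit trailing '\0' is gone
  print(a[0] == "this")    # True

  # np.object_ arrays hold ordinary Python str objects, so NUL bytes survive.
  b = np.array(["this\0", "is\0\0"], dtype=np.object_)
  print(len(b[0]))         # 5 -- the trailing '\0' is preserved
  print(b[0] == "this\0")  # True

This is why testNulTerminatedStrings below expects the NULs to be dropped for
np.str_ input, and why callers that need embedded NUL bytes should pass an
np.object_ array to py_func.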

PiperOrigin-RevId: 253559726
Sergei Lebedev 2019-06-17 05:14:21 -07:00 committed by TensorFlower Gardener
parent ee17b07b70
commit 6cf83ea14d
8 changed files with 39 additions and 254 deletions


@@ -636,9 +636,6 @@ class Tensor {
Tensor* parent, Tensor* element,
int64 index); // For access to RefCountIsOne().
friend class NumpyTensorBuffer; // For access to the private constructor
// taking the buffer.
// Creates a tensor with the input datatype, shape and buf.
//
// Acquires a ref on buf that belongs to this Tensor.


@@ -419,6 +419,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:script_ops_op_lib",
"//tensorflow/python:ndarray_tensor",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
@@ -4623,6 +4624,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/python:ndarray_tensor",
],
)


@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/py_func.h"
namespace tensorflow {
@@ -106,8 +107,7 @@ Status RunCppShapeInferenceImpl(
if (py_val == Py_None) {
input_tensors.push_back(nullptr);
} else {
TF_RETURN_IF_ERROR(
ConvertNdarrayToTensor(py_val, &input_tensor_values[i]));
TF_RETURN_IF_ERROR(NdarrayToTensor(py_val, &input_tensor_values[i]));
input_tensors.push_back(&input_tensor_values[i]);
}
}


@@ -236,6 +236,14 @@ class PyFuncTest(test.TestCase):
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s.eval(), correct)

  @test_util.run_v1_only("b/120545219")
  def testNulTerminatedStrings(self):
    inp = np.array(["this\0", "is\0\0", "a\0", "test\0\0"], dtype=np.str_)
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s.eval(), correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    with self.cached_session() as sess:
@@ -280,8 +288,8 @@ class PyFuncTest(test.TestCase):
      y, = script_ops.py_func(bad, [], [dtypes.float32])
      with self.assertRaisesRegexp(errors.UnimplementedError,
                                   "Unsupported numpy type"):
      with self.assertRaisesRegexp(errors.InternalError,
                                   "Unsupported numpy data type"):
        self.evaluate(y)
@test_util.run_v1_only("b/120545219")
@@ -294,7 +302,7 @@ class PyFuncTest(test.TestCase):
      z, = script_ops.py_func(bad, [], [dtypes.int64])
      with self.assertRaisesRegexp(errors.UnimplementedError,
      with self.assertRaisesRegexp(errors.InternalError,
                                   "Unsupported object type"):
        self.evaluate(z)


@@ -178,28 +178,31 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
PyObject** ptr_owner) {
*ptr_owner = nullptr;
if (!PyUnicode_Check(obj)) {
if (PyBytes_Check(obj)) {
char* buf;
if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) {
return errors::Internal("Unable to get element as bytes.");
}
*ptr = buf;
return Status::OK();
}
} else if (PyUnicode_Check(obj)) {
#if (PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3))
*ptr = PyUnicode_AsUTF8AndSize(obj, len);
if (*ptr != nullptr) return Status::OK();
*ptr = PyUnicode_AsUTF8AndSize(obj, len);
if (*ptr != nullptr) return Status::OK();
#else
PyObject* utemp = PyUnicode_AsUTF8String(obj);
char* buf;
if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) {
*ptr = buf;
*ptr_owner = utemp;
return Status::OK();
}
Py_XDECREF(utemp);
PyObject* utemp = PyUnicode_AsUTF8String(obj);
char* buf;
if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) {
*ptr = buf;
*ptr_owner = utemp;
return Status::OK();
}
Py_XDECREF(utemp);
#endif
return errors::Internal("Unable to convert element to UTF-8.");
return errors::Internal("Unable to convert element to UTF-8");
} else {
return errors::Internal("Unsupported object type ", obj->ob_type->tp_name);
}
}
// Iterate over the string array 'array', extract the ptr and len of each string
@@ -216,7 +219,7 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) {
}
Py_ssize_t len;
const char* ptr;
PyObject* ptr_owner;
PyObject* ptr_owner = nullptr;
TF_RETURN_IF_ERROR(PyObjectToString(item.get(), &ptr, &len, &ptr_owner));
f(ptr, len);
Py_XDECREF(ptr_owner);


@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_func.h"
#include <array>
#include <Python.h>
#include <array>
#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
@@ -90,7 +91,7 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
} else {
Status s = ConvertTensorToNdarray(t, &arg);
Status s = TensorToNdarray(t, &arg);
if (!s.ok()) {
Py_DECREF(lst);
return s;
@@ -106,53 +107,6 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
return Status::OK();
}
// Returns the corresponding tf dtype in 'tf' for numpy data type
// 'np'. Returns an error if the type is not supported by this
// module.
Status NumericNpDTypeToTfDType(const int np, DataType* tf) {
switch (np) {
case NPY_FLOAT16:
*tf = DT_HALF;
break;
case NPY_FLOAT32:
*tf = DT_FLOAT;
break;
case NPY_FLOAT64:
*tf = DT_DOUBLE;
break;
case NPY_INT32:
*tf = DT_INT32;
break;
case NPY_UINT8:
*tf = DT_UINT8;
break;
case NPY_INT8:
*tf = DT_INT8;
break;
case NPY_UINT16:
*tf = DT_UINT16;
break;
case NPY_INT16:
*tf = DT_INT16;
break;
case NPY_INT64:
*tf = DT_INT64;
break;
case NPY_BOOL:
*tf = DT_BOOL;
break;
case NPY_COMPLEX64:
*tf = DT_COMPLEX64;
break;
case NPY_COMPLEX128:
*tf = DT_COMPLEX128;
break;
default:
return errors::Unimplemented("Unsupported numpy type ", np);
}
return Status::OK();
}
bool IsSingleNone(PyObject* obj) {
if (!PyArray_Check(obj)) {
return false;
@@ -268,7 +222,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
Py_TYPE(item)->tp_name);
}
} else {
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
s = NdarrayToTensor(PyList_GetItem(result, i), &t);
}
if (!s.ok()) {
@@ -289,7 +243,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
DCHECK(!call->eager);
if (!IsSingleNone(result)) {
Tensor t;
s = ConvertNdarrayToTensor(result, &t);
s = NdarrayToTensor(result, &t);
if (s.ok()) {
call->out.push_back(t);
}
@@ -304,177 +258,6 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
} // end namespace
// Outside anonymous namespace just to make the friend declaration in
// tensorflow::Tensor apply.
class NumpyTensorBuffer : public TensorBuffer {
public:
NumpyTensorBuffer(PyArrayObject* array, size_t len, void* data)
: TensorBuffer(data), array_(array), len_(len) {}
~NumpyTensorBuffer() override {
// Note: The session::run wrapper is responsible for freeing this while
// holding the GIL.
DelayedNumpyDecref(data(), len_, array_);
}
size_t size() const override { return len_; }
TensorBuffer* root_buffer() override { return this; }
void FillAllocationDescription(AllocationDescription* proto) const override {
tensorflow::int64 rb = size();
proto->set_requested_bytes(rb);
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
Tensor MakeTensor(DataType dtype, const TensorShape& shape) {
CHECK_EQ(len_, shape.num_elements() * DataTypeSize(dtype));
return Tensor(dtype, shape, this);
}
// Prevents input forwarding from overwriting this buffer.
bool OwnsMemory() const override { return false; }
private:
PyArrayObject* array_;
size_t len_;
};
Status PyObjectToString(PyObject* obj, string* str) {
char* py_bytes;
Py_ssize_t size;
if (PyBytes_AsStringAndSize(obj, &py_bytes, &size) != -1) {
str->assign(py_bytes, size);
return Status::OK();
}
#if PY_MAJOR_VERSION >= 3
const char* ptr = PyUnicode_AsUTF8AndSize(obj, &size);
if (ptr != nullptr) {
str->assign(ptr, size);
return Status::OK();
}
#else
if (PyUnicode_Check(obj)) {
PyObject* unicode = PyUnicode_AsUTF8String(obj);
char* ptr;
if (unicode && PyString_AsStringAndSize(unicode, &ptr, &size) != -1) {
str->assign(ptr, size);
Py_DECREF(unicode);
return Status::OK();
}
Py_XDECREF(unicode);
}
#endif
return errors::Unimplemented("Unsupported object type ",
obj->ob_type->tp_name);
}
Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
PyArrayObject* input = reinterpret_cast<PyArrayObject*>(obj);
DataType dtype = DT_INVALID;
TensorShape shape;
for (int i = 0; i < PyArray_NDIM(input); ++i) {
shape.AddDim(PyArray_SHAPE(input)[i]);
}
const int np_type = PyArray_TYPE(input);
switch (np_type) {
case NPY_OBJECT: {
dtype = DT_STRING;
Tensor t(dtype, shape);
auto tflat = t.flat<string>();
PyObject** input_data = reinterpret_cast<PyObject**>(PyArray_DATA(input));
for (int i = 0; i < tflat.dimension(0); ++i) {
TF_RETURN_IF_ERROR(PyObjectToString(input_data[i], &tflat(i)));
}
*ret = t;
break;
}
case NPY_STRING: {
dtype = DT_STRING;
Tensor t(dtype, shape);
auto tflat = t.flat<string>();
char* input_data = PyArray_BYTES(input);
Py_ssize_t el_size = PyArray_ITEMSIZE(input);
for (int i = 0; i < tflat.dimension(0); ++i) {
tflat(i) = string(input_data + i * el_size, el_size);
}
*ret = t;
break;
}
default: {
TF_RETURN_IF_ERROR(NumericNpDTypeToTfDType(PyArray_TYPE(input), &dtype));
CHECK(DataTypeCanUseMemcpy(dtype));
if (reinterpret_cast<intptr_t>(PyArray_DATA(input)) %
std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
0) {
Tensor t(dtype, shape);
StringPiece p = t.tensor_data();
memcpy(const_cast<char*>(p.data()), PyArray_DATA(input), p.size());
*ret = t;
} else {
// Incref the array as the calling context will decref it when we
// return and we want to keep a handle to this memory.
Py_INCREF(input);
NumpyTensorBuffer* buf = new NumpyTensorBuffer(
input, shape.num_elements() * DataTypeSize(dtype),
PyArray_DATA(input));
*ret = buf->MakeTensor(dtype, shape);
buf->Unref();
}
}
}
return Status::OK();
}
// Creates a numpy array in 'ret' which either aliases the content of 't' or has
// a copy.
Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret) {
int typenum = -1;
TF_RETURN_IF_ERROR(TF_DataType_to_PyArray_TYPE(
static_cast<TF_DataType>(t.dtype()), &typenum));
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
CHECK(descr);
std::vector<npy_intp> dims;
dims.reserve(t.dims());
for (int i = 0; i < t.dims(); ++i) {
dims.push_back(t.dim_size(i));
}
Tensor* copy = new Tensor(t);
if (ArrayFromMemory(dims.size(), dims.data(),
const_cast<char*>(copy->tensor_data().data()), t.dtype(),
[copy]() { delete copy; }, ret)
.ok()) {
return Status::OK();
}
delete copy;
PyObject* obj = PyArray_Empty(dims.size(), dims.data(), descr, 0);
if (obj == nullptr) {
return errors::Internal("Failed to allocate np array: ",
t.shape().DebugString());
}
PyArrayObject* np_array = reinterpret_cast<PyArrayObject*>(obj);
if (typenum == NPY_OBJECT) {
CHECK_EQ(DT_STRING, t.dtype());
auto tflat = t.flat<string>();
PyObject** out = reinterpret_cast<PyObject**>(PyArray_DATA(np_array));
for (int i = 0; i < tflat.dimension(0); ++i) {
const string& el = tflat(i);
out[i] = PyBytes_FromStringAndSize(el.data(), el.size());
if (out[i] == nullptr) {
for (int j = 0; j < i; ++j) {
Py_DECREF(out[j]);
}
Py_DECREF(obj);
return errors::Internal("Failed to allocate a copy of string ", i);
}
}
} else {
CHECK(DataTypeCanUseMemcpy(t.dtype()));
StringPiece p = t.tensor_data();
memcpy(PyArray_DATA(np_array), p.data(), p.size());
}
*ret = reinterpret_cast<PyObject*>(np_array);
return Status::OK();
}
void InitializePyTrampoline(PyObject* trampoline) {
mutex_lock l(mu);
if (py_trampoline == nullptr) {


@@ -49,14 +49,6 @@ namespace tensorflow {
// TODO(zhifengc): Support distributed runtime.
void InitializePyTrampoline(PyObject* trampoline);
// Creates a numpy array in 'ret' and copies the content of tensor 't'
// into 'ret'.
Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret);
// Given an numpy ndarray object 'obj', creates a corresponding tf
// Tensor in '*ret'.
Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret);
} // end namespace tensorflow
#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_


@@ -110,7 +110,7 @@ static PyObject* CheckpointReader_GetTensor(
reader->GetTensor(name, &tensor, status);
if (TF_GetCode(status) == TF_OK) {
tensorflow::Status s =
tensorflow::ConvertTensorToNdarray(*tensor.get(), &py_obj);
tensorflow::TensorToNdarray(*tensor.get(), &py_obj);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}