Minor refactoring of the TF_Tensor <-> PyArray conversion functions.
PiperOrigin-RevId: 163802822
This commit is contained in:
parent
618f913bbd
commit
6209b4b524
@ -208,45 +208,52 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Determine the pointer and offset of the string at offset 'i' in the string
|
||||
// tensor 'src', whose total length is 'num_elements'.
|
||||
static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src,
|
||||
tensorflow::int64 num_elements,
|
||||
tensorflow::int64 i,
|
||||
const char** ptr,
|
||||
tensorflow::uint64* len) {
|
||||
const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
|
||||
const size_t src_size = TF_TensorByteSize(src);
|
||||
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
|
||||
const char* limit = input + src_size;
|
||||
tensorflow::uint64 offset =
|
||||
reinterpret_cast<const tensorflow::uint64*>(input)[i];
|
||||
const char* p =
|
||||
tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len);
|
||||
if (static_cast<int64>(offset) >= (limit - data_start) || !p ||
|
||||
static_cast<int64>(*len) > (limit - p)) {
|
||||
return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i,
|
||||
" out of range");
|
||||
}
|
||||
*ptr = p;
|
||||
return Status::OK();
|
||||
}
|
||||
// Copy the string at offset 'i' in the (linearized) string tensor 'tensor' into
// 'pyarray' at the position the 'i_ptr' array iterator currently points to.
//
// 'num_elements' is the number of strings in 'tensor'. Returns the error from
// TF_StringTensor_GetPtrAndLen if the element cannot be located, or Internal
// if the Python bytes object cannot be created or stored.
static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
                                         TF_Tensor* tensor,
                                         tensorflow::int64 num_elements,
                                         tensorflow::int64 i) {
  const char* ptr = nullptr;
  tensorflow::uint64 len = 0;
  TF_RETURN_IF_ERROR(
      TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
  auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len));
  // Do not hand a null object to PyArray_SETITEM.
  if (py_string.get() == nullptr) {
    return errors::Internal(
        "failed to create a python byte array when converting element #", i,
        " of a TF_STRING tensor to a numpy ndarray");
  }
  int success = PyArray_SETITEM(
      pyarray, static_cast<char*>(PyArray_ITER_DATA(i_ptr)), py_string.get());
  if (success != 0) {
    return errors::Internal("Error setting element ", i);
  }
  return Status::OK();
}

// Copies every string of the TF_STRING tensor 'src' into the numpy array
// 'dst', one Python bytes object per element, in iteration order of 'dst'.
//
// 'nelems' is the number of strings in 'src'. The buffer of 'src' is read as
// a table of 'nelems' uint64 offsets followed by the encoded strings; each
// offset is relative to the end of the table and is decoded with
// TF_StringDecode.
//
// Returns InvalidArgument if the offset table or an element offset does not
// fit in the tensor buffer or an element fails to decode, and Internal if a
// Python object cannot be created or stored.
Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
                                     PyArrayObject* dst) {
  const void* tensor_data = TF_TensorData(src);
  const size_t tensor_size = TF_TensorByteSize(src);
  DCHECK(tensor_data != nullptr);
  DCHECK_EQ(TF_STRING, TF_TensorType(src));

  // The offset table must fit inside the tensor buffer. Dividing instead of
  // multiplying keeps the comparison free of overflow.
  if (nelems > tensor_size / sizeof(uint64)) {
    return errors::InvalidArgument(
        "Invalid/corrupt TF_STRING tensor: the offset table for ", nelems,
        " strings does not fit in the ", tensor_size,
        " bytes the tensor is encoded in");
  }

  const uint64* offsets = static_cast<const uint64*>(tensor_data);
  const size_t offsets_size = sizeof(uint64) * nelems;
  const char* data = static_cast<const char*>(tensor_data) + offsets_size;
  const size_t data_size = tensor_size - offsets_size;
  const char* limit = data + data_size;

  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
  if (iter.get() == nullptr) {
    return errors::Internal(
        "failed to create an iterator over the numpy ndarray");
  }

  for (uint64 i = 0; i < nelems; ++i) {
    // Validate the offset before forming a pointer from it; otherwise
    // 'limit - start' below could go negative and wrap to a huge length.
    if (offsets[i] >= data_size) {
      return errors::InvalidArgument(
          "Invalid/corrupt TF_STRING tensor: offset of element #", i,
          " is out of range");
    }
    const char* start = data + offsets[i];
    const char* ptr = nullptr;
    size_t len = 0;

    TF_StringDecode(start, limit - start, &ptr, &len, status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      return errors::InvalidArgument(TF_Message(status.get()));
    }

    auto py_string = make_safe(PyBytes_FromStringAndSize(ptr, len));
    if (py_string.get() == nullptr) {
      return errors::Internal(
          "failed to create a python byte array when converting element #", i,
          " of a TF_STRING tensor to a numpy ndarray");
    }

    if (PyArray_SETITEM(dst, static_cast<char*>(PyArray_ITER_DATA(iter.get())),
                        py_string.get()) != 0) {
      return errors::Internal("Error settings element #", i,
                              " in the numpy ndarray");
    }
    PyArray_ITER_NEXT(iter.get());
  }
  return Status::OK();
}
|
||||
@ -304,7 +311,7 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
|
||||
|
||||
// Converts the given TF_Tensor to a numpy ndarray.
|
||||
// If the returned status is OK, the caller becomes the owner of *out_array.
|
||||
Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
// A fetched operation will correspond to a null tensor, and a None
|
||||
// in Python.
|
||||
if (tensor == nullptr) {
|
||||
@ -312,15 +319,16 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
*out_ndarray = Py_None;
|
||||
return Status::OK();
|
||||
}
|
||||
if (TF_TensorData(tensor.get()) == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"TF_Tensor must be in host memory (not device memory) in order to "
|
||||
"create a numpy ndarray");
|
||||
}
|
||||
|
||||
tensorflow::int64 nelems = -1;
|
||||
int64 nelems = -1;
|
||||
gtl::InlinedVector<npy_intp, 4> dims =
|
||||
GetPyArrayDimensionsForTensor(tensor.get(), &nelems);
|
||||
|
||||
// Convert TensorFlow dtype to numpy type descriptor.
|
||||
PyArray_Descr* descr = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
|
||||
|
||||
// If the type is neither string nor resource we can reuse the Tensor memory.
|
||||
TF_Tensor* original = tensor.get();
|
||||
TF_Tensor* moved = TF_TensorMaybeMove(tensor.release());
|
||||
@ -335,6 +343,8 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
tensor.reset(original);
|
||||
|
||||
// Copy the TF_TensorData into a newly-created ndarray and return it.
|
||||
PyArray_Descr* descr = nullptr;
|
||||
TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
|
||||
Safe_PyObjectPtr safe_out_array =
|
||||
tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
|
||||
if (!safe_out_array) {
|
||||
@ -343,15 +353,9 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
PyArrayObject* py_array =
|
||||
reinterpret_cast<PyArrayObject*>(safe_out_array.get());
|
||||
if (TF_TensorType(tensor.get()) == TF_STRING) {
|
||||
// Copy element by element.
|
||||
auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get()));
|
||||
for (tensorflow::int64 i = 0; i < nelems; ++i) {
|
||||
auto s = CopyStringToPyArrayElement(py_array, iter.get(), tensor.get(),
|
||||
nelems, i);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
PyArray_ITER_NEXT(iter.get());
|
||||
Status s = CopyTF_TensorStringsToPyArray(tensor.get(), nelems, py_array);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
} else if (static_cast<size_t>(PyArray_NBYTES(py_array)) !=
|
||||
TF_TensorByteSize(tensor.get())) {
|
||||
@ -370,12 +374,12 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
||||
}
|
||||
|
||||
// Converts the given numpy ndarray to a (safe) TF_Tensor. The returned
|
||||
// TF_Tensor in `out_tensor` will have its own Python reference to `ndarray`s
|
||||
// TF_Tensor in `out_tensor` may have its own Python reference to `ndarray`s
|
||||
// data. After `out_tensor` is destroyed, this reference must (eventually) be
|
||||
// decremented via ClearDecrefCache().
|
||||
//
|
||||
// `out_tensor` must be non-null. Caller retains ownership of `ndarray`.
|
||||
Status PyArrayToTFTensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
||||
Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
||||
DCHECK(out_tensor != nullptr);
|
||||
|
||||
// Make sure we dereference this array object in case of error, etc.
|
||||
@ -468,7 +472,7 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
|
||||
input_names.push_back(key_string);
|
||||
|
||||
inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
|
||||
s = PyArrayToTFTensor(value, &inputs_safe.back());
|
||||
s = PyArrayToTF_Tensor(value, &inputs_safe.back());
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
@ -521,7 +525,7 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
|
||||
std::vector<Safe_PyObjectPtr> py_outputs_safe;
|
||||
for (size_t i = 0; i < output_names.size(); ++i) {
|
||||
PyObject* py_array;
|
||||
s = TFTensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
|
||||
s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
@ -601,7 +605,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
||||
// cleaned up properly.
|
||||
//
|
||||
// Memory management:
|
||||
// PyArrayToTFTensor() creates a new ndarray PyObject from the input
|
||||
// PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
|
||||
// ndarray. We manage the new ndarray's lifetime in order to keep the
|
||||
// underlying data buffer alive (the new ndarray also guarantees a contiguous
|
||||
// data buffer). The new ndarray's data buffer is used to create the
|
||||
@ -616,7 +620,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
||||
std::vector<Safe_TF_TensorPtr> input_vals_safe;
|
||||
for (PyObject* ndarray : input_ndarrays) {
|
||||
input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
|
||||
s = PyArrayToTFTensor(ndarray, &input_vals_safe.back());
|
||||
s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
@ -654,7 +658,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
||||
std::vector<Safe_PyObjectPtr> py_outputs_safe;
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
PyObject* py_array;
|
||||
s = TFTensorToPyArray(std::move(output_vals_safe[i]), &py_array);
|
||||
s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
|
Loading…
Reference in New Issue
Block a user