Minor refactoring of the TF_Tensor <-> PyArray conversion functions.
PiperOrigin-RevId: 163802822
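
The refactor keeps the conversion API symmetric: one helper wraps a numpy ndarray in a TF_Tensor, the other materializes a TF_Tensor as an ndarray, and the per-element string copy moves into a single helper. As a minimal sketch of how the renamed entry points compose (the signatures, Safe_TF_TensorPtr, and make_safe come from the hunks below; RoundTripNdarray is a hypothetical caller written only for illustration, not part of this change):

// Hypothetical caller, for illustration only: converts a numpy ndarray to a
// TF_Tensor and back using the renamed helpers. Relies on the file's existing
// Safe_TF_TensorPtr alias and make_safe().
Status RoundTripNdarray(PyObject* ndarray, PyObject** out_ndarray) {
  Safe_TF_TensorPtr tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
  TF_RETURN_IF_ERROR(PyArrayToTF_Tensor(ndarray, &tensor));
  // TF_TensorToPyArray takes ownership of the tensor; on success the caller
  // owns *out_ndarray.
  return TF_TensorToPyArray(std::move(tensor), out_ndarray);
}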
parent 618f913bbd
commit 6209b4b524
@@ -208,45 +208,52 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
   return Status::OK();
 }
 
-// Determine the pointer and offset of the string at offset 'i' in the string
-// tensor 'src', whose total length is 'num_elements'.
-static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src,
-                                           tensorflow::int64 num_elements,
-                                           tensorflow::int64 i,
-                                           const char** ptr,
-                                           tensorflow::uint64* len) {
-  const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
-  const size_t src_size = TF_TensorByteSize(src);
-  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
-  const char* limit = input + src_size;
-  tensorflow::uint64 offset =
-      reinterpret_cast<const tensorflow::uint64*>(input)[i];
-  const char* p =
-      tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len);
-  if (static_cast<int64>(offset) >= (limit - data_start) || !p ||
-      static_cast<int64>(*len) > (limit - p)) {
-    return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i,
-                                   " out of range");
-  }
-  *ptr = p;
-  return Status::OK();
-}
-
-// Copy the string at offset 'i' in the (linearized) string tensor 'tensor' into
-// 'pyarray' at offset pointed by the 'i_ptr' iterator.
-static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr,
-                                         TF_Tensor* tensor,
-                                         tensorflow::int64 num_elements,
-                                         tensorflow::int64 i) {
-  const char* ptr = nullptr;
-  tensorflow::uint64 len = 0;
-  TF_RETURN_IF_ERROR(
-      TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len));
-  auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len));
-  int success = PyArray_SETITEM(
-      pyarray, static_cast<char*>(PyArray_ITER_DATA(i_ptr)), py_string.get());
-  if (success != 0) {
-    return errors::Internal("Error setting element ", i);
+Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
+                                     PyArrayObject* dst) {
+  const void* tensor_data = TF_TensorData(src);
+  const size_t tensor_size = TF_TensorByteSize(src);
+  const char* limit = static_cast<const char*>(tensor_data) + tensor_size;
+  DCHECK(tensor_data != nullptr);
+  DCHECK_EQ(TF_STRING, TF_TensorType(src));
+
+  const uint64* offsets = static_cast<const uint64*>(tensor_data);
+  const size_t offsets_size = sizeof(uint64) * nelems;
+  const char* data = static_cast<const char*>(tensor_data) + offsets_size;
+
+  const size_t expected_tensor_size =
+      (limit - static_cast<const char*>(tensor_data));
+  if (expected_tensor_size - tensor_size) {
+    return errors::InvalidArgument(
+        "Invalid/corrupt TF_STRING tensor: expected ", expected_tensor_size,
+        " bytes of encoded strings for the tensor containing ", nelems,
+        " strings, but the tensor is encoded in ", tensor_size, " bytes");
+  }
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
+  for (int64 i = 0; i < nelems; ++i) {
+    const char* start = data + offsets[i];
+    const char* ptr = nullptr;
+    size_t len = 0;
+
+    TF_StringDecode(start, limit - start, &ptr, &len, status.get());
+    if (TF_GetCode(status.get()) != TF_OK) {
+      return errors::InvalidArgument(TF_Message(status.get()));
+    }
+
+    auto py_string = make_safe(PyBytes_FromStringAndSize(ptr, len));
+    if (py_string.get() == nullptr) {
+      return errors::Internal(
+          "failed to create a python byte array when converting element #", i,
+          " of a TF_STRING tensor to a numpy ndarray");
+    }
+
+    if (PyArray_SETITEM(dst, static_cast<char*>(PyArray_ITER_DATA(iter.get())),
+                        py_string.get()) != 0) {
+      return errors::Internal("Error settings element #", i,
+                              " in the numpy ndarray");
+    }
+    PyArray_ITER_NEXT(iter.get());
   }
   return Status::OK();
 }
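
The new CopyTF_TensorStringsToPyArray walks the TF_STRING wire format directly: the tensor buffer begins with one uint64 offset per element, followed by the TF_StringEncode'd bytes, and each element is recovered with TF_StringDecode. A minimal sketch of reading a single element under that layout (DecodeOneString is a hypothetical helper written only for illustration; it assumes the same TensorFlow C API calls and error helpers used in the hunk above):

// Hypothetical helper mirroring the loop above: reads element `i` of a
// TF_STRING tensor whose buffer is laid out as
// [uint64 offsets[nelems]][TF_StringEncode'd bytes...].
Status DecodeOneString(const TF_Tensor* src, tensorflow::int64 nelems,
                       tensorflow::int64 i, const char** ptr, size_t* len) {
  const char* base = static_cast<const char*>(TF_TensorData(src));
  const char* limit = base + TF_TensorByteSize(src);
  const tensorflow::uint64* offsets =
      reinterpret_cast<const tensorflow::uint64*>(base);
  // The encoded strings start right after the offset table.
  const char* data = base + sizeof(tensorflow::uint64) * nelems;
  const char* start = data + offsets[i];
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_StringDecode(start, limit - start, ptr, len, status.get());
  if (TF_GetCode(status.get()) != TF_OK) {
    return errors::InvalidArgument(TF_Message(status.get()));
  }
  return Status::OK();
}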
@@ -304,7 +311,7 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
 
 // Converts the given TF_Tensor to a numpy ndarray.
 // If the returned status is OK, the caller becomes the owner of *out_array.
-Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
+Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
   // A fetched operation will correspond to a null tensor, and a None
   // in Python.
   if (tensor == nullptr) {
@@ -312,15 +319,16 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
     *out_ndarray = Py_None;
     return Status::OK();
   }
+  if (TF_TensorData(tensor.get()) == nullptr) {
+    return errors::InvalidArgument(
+        "TF_Tensor must be in host memory (not device memory) in order to "
+        "create a numpy ndarray");
+  }
 
-  tensorflow::int64 nelems = -1;
+  int64 nelems = -1;
   gtl::InlinedVector<npy_intp, 4> dims =
       GetPyArrayDimensionsForTensor(tensor.get(), &nelems);
 
-  // Convert TensorFlow dtype to numpy type descriptor.
-  PyArray_Descr* descr = nullptr;
-  TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
-
   // If the type is neither string nor resource we can reuse the Tensor memory.
   TF_Tensor* original = tensor.get();
   TF_Tensor* moved = TF_TensorMaybeMove(tensor.release());
@@ -335,6 +343,8 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
   tensor.reset(original);
 
   // Copy the TF_TensorData into a newly-created ndarray and return it.
+  PyArray_Descr* descr = nullptr;
+  TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
   Safe_PyObjectPtr safe_out_array =
       tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
   if (!safe_out_array) {
@@ -343,15 +353,9 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
   PyArrayObject* py_array =
       reinterpret_cast<PyArrayObject*>(safe_out_array.get());
   if (TF_TensorType(tensor.get()) == TF_STRING) {
-    // Copy element by element.
-    auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get()));
-    for (tensorflow::int64 i = 0; i < nelems; ++i) {
-      auto s = CopyStringToPyArrayElement(py_array, iter.get(), tensor.get(),
-                                          nelems, i);
-      if (!s.ok()) {
-        return s;
-      }
-      PyArray_ITER_NEXT(iter.get());
+    Status s = CopyTF_TensorStringsToPyArray(tensor.get(), nelems, py_array);
+    if (!s.ok()) {
+      return s;
     }
   } else if (static_cast<size_t>(PyArray_NBYTES(py_array)) !=
              TF_TensorByteSize(tensor.get())) {
@@ -370,12 +374,12 @@ Status TFTensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
 }
 
 // Converts the given numpy ndarray to a (safe) TF_Tensor. The returned
-// TF_Tensor in `out_tensor` will have its own Python reference to `ndarray`s
+// TF_Tensor in `out_tensor` may have its own Python reference to `ndarray`s
 // data. After `out_tensor` is destroyed, this reference must (eventually) be
 // decremented via ClearDecrefCache().
 //
 // `out_tensor` must be non-null. Caller retains ownership of `ndarray`.
-Status PyArrayToTFTensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
+Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
   DCHECK(out_tensor != nullptr);
 
   // Make sure we dereference this array object in case of error, etc.
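
The contract spelled out in the comment above shapes caller code: the returned TF_Tensor may alias the ndarray's buffer through a held Python reference, and that reference is only dropped once ClearDecrefCache() runs after the tensor is destroyed. A caller-side sketch of that lifetime (WithBorrowedNdarray is hypothetical; PyArrayToTF_Tensor, Safe_TF_TensorPtr, make_safe, and ClearDecrefCache come from this file):

// Illustration of the documented lifetime: build the tensor, let the safe
// pointer destroy it, then flush any deferred Py_DECREFs.
Status WithBorrowedNdarray(PyObject* ndarray) {
  Safe_TF_TensorPtr tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
  TF_RETURN_IF_ERROR(PyArrayToTF_Tensor(ndarray, &tensor));
  // ... hand tensor.get() to a session run, etc. ...
  tensor.reset();      // destroys the TF_Tensor; may queue a deferred DECREF
  ClearDecrefCache();  // drop the cached reference to the ndarray's data
  return Status::OK();
}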
@@ -468,7 +472,7 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
     input_names.push_back(key_string);
 
     inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
-    s = PyArrayToTFTensor(value, &inputs_safe.back());
+    s = PyArrayToTF_Tensor(value, &inputs_safe.back());
     if (!s.ok()) {
       Set_TF_Status_from_Status(out_status, s);
       return;
@@ -521,7 +525,7 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
   std::vector<Safe_PyObjectPtr> py_outputs_safe;
   for (size_t i = 0; i < output_names.size(); ++i) {
     PyObject* py_array;
-    s = TFTensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
+    s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array);
     if (!s.ok()) {
       Set_TF_Status_from_Status(out_status, s);
       return;
@@ -601,7 +605,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
   // cleaned up properly.
   //
   // Memory management:
-  // PyArrayToTFTensor() creates a new ndarray PyObject from the input
+  // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
   // ndarray. We manage the new ndarray's lifetime in order to keep the
   // underlying data buffer alive (the new ndarray also guarantees a contiguous
   // data buffer). The new ndarray's data buffer is used to create the
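
The memory-management comment above describes the usual zero-copy pattern: a contiguous ndarray's buffer is handed to a newly created TF_Tensor, and the Python reference that keeps that buffer alive is released only when TensorFlow is done with it. The sketch below illustrates that general pattern with the public C API; it is not the helper's actual implementation, which defers the DECREF through the cache cleared by ClearDecrefCache() because the tensor's deallocator may run without the GIL.

// Illustrative zero-copy wrap (not the real PyArrayToTF_Tensor): the TF_Tensor
// borrows the ndarray's buffer and the deallocator drops the Python reference.
static void ReleaseNdarrayRef(void* data, size_t len, void* arg) {
  // Assumes the GIL is held here; the real code defers this via a cache.
  Py_XDECREF(static_cast<PyObject*>(arg));
}

TF_Tensor* WrapContiguousNdarray(PyArrayObject* array, TF_DataType dtype,
                                 const int64_t* dims, int ndims) {
  Py_INCREF(reinterpret_cast<PyObject*>(array));  // keep the buffer alive
  return TF_NewTensor(dtype, dims, ndims, PyArray_DATA(array),
                      PyArray_NBYTES(array), &ReleaseNdarrayRef,
                      reinterpret_cast<PyObject*>(array));
}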
@@ -616,7 +620,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
   std::vector<Safe_TF_TensorPtr> input_vals_safe;
   for (PyObject* ndarray : input_ndarrays) {
     input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
-    s = PyArrayToTFTensor(ndarray, &input_vals_safe.back());
+    s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
     if (!s.ok()) {
       Set_TF_Status_from_Status(out_status, s);
       return;
@@ -654,7 +658,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
   std::vector<Safe_PyObjectPtr> py_outputs_safe;
   for (size_t i = 0; i < outputs.size(); ++i) {
     PyObject* py_array;
-    s = TFTensorToPyArray(std::move(output_vals_safe[i]), &py_array);
+    s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array);
     if (!s.ok()) {
       Set_TF_Status_from_Status(out_status, s);
       return;