Implement Numpy to tensor conversion for TFRT.

PiperOrigin-RevId: 310657168
Change-Id: I3133a28194f41586f377d688dc64bff52f120d33
This commit is contained in:
Kibeom Kim 2020-05-08 17:13:23 -07:00 committed by TensorFlower Gardener
parent 4204c5f856
commit 837b493f3c
8 changed files with 112 additions and 29 deletions

View File

@ -59,6 +59,20 @@ class AbstractContextInterface {
virtual AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
// Create a tensor instance from the given data buffer and description.
// `memory_releaser` will be called on destruction, and it's responsible for
// cleaning up the underlying buffer. `convert_string` indicates whether it
// has to handle tstring conversion. Expected to be removed once tstring
// migration is done.
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
const int64_t* dims,
int num_dims, void* data,
size_t len, bool convert_string,
MemoryReleaser memory_releaser,
void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) = 0;

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
@ -168,6 +169,28 @@ AbstractTensorInterface* EagerContext::CreateTensor(
return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
}
AbstractTensorInterface* EagerContext::CreateTensor(
DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
bool convert_string, MemoryReleaser memory_releaser,
void* memory_releaser_arg) {
TF_Tensor* tensor_wrapper =
TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
memory_releaser, memory_releaser_arg);
if (convert_string) {
tensorflow::Tensor tensor;
Status status = TF_TensorToTensor(tensor_wrapper, &tensor);
TF_DeleteTensor(tensor_wrapper);
if (!status.ok()) return nullptr;
return new TensorInterface(std::move(tensor));
} else {
AbstractTensorInterface* result = nullptr;
std::swap(result, tensor_wrapper->tensor);
TF_DeleteTensor(tensor_wrapper);
return result;
}
}
std::unique_ptr<SavedModelAPI> EagerContext::LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,

View File

@ -173,6 +173,11 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
AbstractTensorInterface* CreateTensor(
DataType dtype, absl::Span<const int64> dim_sizes) override;
AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
int num_dims, void* data, size_t len,
bool convert_string,
MemoryReleaser memory_releaser,
void* memory_releaser_arg) override;
AbstractTensorHandleInterface* CreateLocalHandle(
AbstractTensorInterface* t) override;

View File

@ -996,6 +996,8 @@ cc_library(
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],

View File

@ -89,7 +89,8 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
input_names.push_back(key_string);
inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
s = PyArrayToTF_Tensor(value, &inputs_safe.back());
s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back(),
true /*convert_to_string*/);
if (!s.ok()) {
Set_TF_Status_from_Status(out_status, s);
return;
@ -367,7 +368,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
// cleaned up properly.
//
// Memory management:
// PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
// NdarrayToTensor() creates a new ndarray PyObject from the input
// ndarray. We manage the new ndarray's lifetime in order to keep the
// underlying data buffer alive (the new ndarray also guarantees a contiguous
// data buffer). The new ndarray's data buffer is used to create the
@ -382,7 +383,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
std::vector<Safe_TF_TensorPtr> input_vals_safe;
for (PyObject* ndarray : input_ndarrays) {
input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back(), true);
if (!s.ok()) {
Set_TF_Status_from_Status(out_status, s);
return;

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <cstring>
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -488,8 +490,9 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
return Status::OK();
}
Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
DCHECK(out_tensor != nullptr);
Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
Safe_TF_TensorPtr* ret, bool convert_string) {
DCHECK(ret != nullptr);
// Make sure we dereference this array object in case of error, etc.
Safe_PyObjectPtr array_safe(make_safe(
@ -515,26 +518,52 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
if (dtype == TF_RESOURCE) {
size_t size = PyArray_NBYTES(array);
array_safe.release();
*out_tensor = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array),
size, &DelayedNumpyDecref, array));
if (ctx) {
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), {}, 0, PyArray_DATA(array),
size, convert_string, &DelayedNumpyDecref, array)});
} else {
*ret = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), size,
&DelayedNumpyDecref, array));
}
} else if (dtype != TF_STRING) {
size_t size = PyArray_NBYTES(array);
array_safe.release();
*out_tensor = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
PyArray_DATA(array), size,
&DelayedNumpyDecref, array));
if (ctx) {
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
PyArray_DATA(array), size, convert_string, &DelayedNumpyDecref,
array)});
} else {
*ret = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
PyArray_DATA(array), size,
&DelayedNumpyDecref, array));
}
} else {
size_t size = 0;
void* encoded = nullptr;
TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
*out_tensor = make_safe(TF_NewTensor(
dtype, dims.data(), dims.size(), encoded, size,
[](void* data, size_t len, void* arg) {
delete[] reinterpret_cast<char*>(data);
},
nullptr));
if (ctx) {
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
encoded, size, convert_string,
[](void* data, size_t len, void* arg) {
delete[] reinterpret_cast<char*>(data);
},
nullptr)});
} else {
*ret = make_safe(TF_NewTensor(
dtype, dims.data(), dims.size(), encoded, size,
[](void* data, size_t len, void* arg) {
delete[] reinterpret_cast<char*>(data);
},
nullptr));
}
}
return Status::OK();
}
@ -543,7 +572,8 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status);
Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
Status s = PyArrayToTF_Tensor(obj, &tf_tensor);
Status s = NdarrayToTensor(nullptr /*ctx*/, obj, &tf_tensor,
false /*convert_string*/);
if (!s.ok()) {
return s;
}

View File

@ -28,15 +28,21 @@ Status TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,
Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray);
// Converts the given numpy ndarray to a (safe) TF_Tensor. The returned
// TF_Tensor in `out_tensor` may have its own Python reference to `ndarray`s
// data. After `out_tensor` is destroyed, this reference must (eventually) be
// decremented via ClearDecrefCache().
//
// `out_tensor` must be non-null. Caller retains ownership of `ndarray`.
Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor);
// Creates a tensor in 'ret' from the input `ndarray`. The returned TF_Tensor
// in `ret` may have its own Python reference to `ndarray`s data. After `ret`
// is destroyed, this reference must (eventually) be decremented via
// ClearDecrefCache().
// `convert_string` indicates whether it has to handle tstring conversion.
// Expected to be removed once tstring migration is done.
ABSL_MUST_USE_RESULT
Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
Safe_TF_TensorPtr* ret, bool convert_string);
// Creates a tensor in 'ret' from the input Ndarray.
// TODO(kkb): This is an old conversion function that does not support TFRT.
// Currently it's used for session, py_func, and an internal project. Migrate
// them.
ABSL_MUST_USE_RESULT
Status NdarrayToTensor(PyObject* obj, Tensor* ret);
// Creates a numpy array in 'ret' which either aliases the content of 't' or has

View File

@ -681,9 +681,11 @@ typedef Converter<bool> BoolConverter;
// The two may share underlying storage so changes to one may reflect in the
// other.
TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
tensorflow::Tensor tensor;
tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &tensor);
if (!status.ok()) {
Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
Status status = tensorflow::NdarrayToTensor(ctx, obj, &tf_tensor,
true /*convert_string*/);
if (TF_PREDICT_FALSE(!status.ok())) {
PyErr_SetString(PyExc_ValueError,
tensorflow::strings::StrCat(
"Failed to convert a NumPy array to a Tensor (",
@ -692,8 +694,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
return nullptr;
}
TensorInterface t(std::move(tensor));
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(tf_tensor->tensor));
}
} // namespace