Implement Numpy to tensor conversion for TFRT.
PiperOrigin-RevId: 310657168 Change-Id: I3133a28194f41586f377d688dc64bff52f120d33
This commit is contained in:
parent
4204c5f856
commit
837b493f3c
|
@ -59,6 +59,20 @@ class AbstractContextInterface {
|
|||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
typedef void (*MemoryReleaser)(void* data, size_t len, void* arg);
|
||||
|
||||
// Create a tensor instance from the given data buffer and description.
|
||||
// `memory_releaser` will be called on destruction, and it's responsible for
|
||||
// cleaning up the underlying buffer. `convert_string` indicates whether it
|
||||
// has to handle tstring conversion. Expected to be removed once tstring
|
||||
// migration is done.
|
||||
virtual AbstractTensorInterface* CreateTensor(DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims, void* data,
|
||||
size_t len, bool convert_string,
|
||||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
|
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
@ -168,6 +169,28 @@ AbstractTensorInterface* EagerContext::CreateTensor(
|
|||
return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
AbstractTensorInterface* EagerContext::CreateTensor(
|
||||
DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len,
|
||||
bool convert_string, MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) {
|
||||
TF_Tensor* tensor_wrapper =
|
||||
TF_NewTensor(static_cast<TF_DataType>(dtype), dims, num_dims, data, len,
|
||||
memory_releaser, memory_releaser_arg);
|
||||
|
||||
if (convert_string) {
|
||||
tensorflow::Tensor tensor;
|
||||
Status status = TF_TensorToTensor(tensor_wrapper, &tensor);
|
||||
TF_DeleteTensor(tensor_wrapper);
|
||||
if (!status.ok()) return nullptr;
|
||||
return new TensorInterface(std::move(tensor));
|
||||
} else {
|
||||
AbstractTensorInterface* result = nullptr;
|
||||
std::swap(result, tensor_wrapper->tensor);
|
||||
TF_DeleteTensor(tensor_wrapper);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SavedModelAPI> EagerContext::LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
|
|
|
@ -173,6 +173,11 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
|
||||
AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) override;
|
||||
AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims,
|
||||
int num_dims, void* data, size_t len,
|
||||
bool convert_string,
|
||||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) override;
|
||||
|
||||
AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) override;
|
||||
|
|
|
@ -996,6 +996,8 @@ cc_library(
|
|||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
|
|
|
@ -89,7 +89,8 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle,
|
|||
input_names.push_back(key_string);
|
||||
|
||||
inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
|
||||
s = PyArrayToTF_Tensor(value, &inputs_safe.back());
|
||||
s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back(),
|
||||
true /*convert_to_string*/);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
|
@ -367,7 +368,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
|||
// cleaned up properly.
|
||||
//
|
||||
// Memory management:
|
||||
// PyArrayToTF_Tensor() creates a new ndarray PyObject from the input
|
||||
// NdarrayToTensor() creates a new ndarray PyObject from the input
|
||||
// ndarray. We manage the new ndarray's lifetime in order to keep the
|
||||
// underlying data buffer alive (the new ndarray also guarantees a contiguous
|
||||
// data buffer). The new ndarray's data buffer is used to create the
|
||||
|
@ -382,7 +383,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
|||
std::vector<Safe_TF_TensorPtr> input_vals_safe;
|
||||
for (PyObject* ndarray : input_ndarrays) {
|
||||
input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr)));
|
||||
s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back());
|
||||
s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back(), true);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
|
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
|
@ -488,8 +490,9 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
||||
DCHECK(out_tensor != nullptr);
|
||||
Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
|
||||
Safe_TF_TensorPtr* ret, bool convert_string) {
|
||||
DCHECK(ret != nullptr);
|
||||
|
||||
// Make sure we dereference this array object in case of error, etc.
|
||||
Safe_PyObjectPtr array_safe(make_safe(
|
||||
|
@ -515,26 +518,52 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
|
|||
if (dtype == TF_RESOURCE) {
|
||||
size_t size = PyArray_NBYTES(array);
|
||||
array_safe.release();
|
||||
*out_tensor = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array),
|
||||
size, &DelayedNumpyDecref, array));
|
||||
|
||||
if (ctx) {
|
||||
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), {}, 0, PyArray_DATA(array),
|
||||
size, convert_string, &DelayedNumpyDecref, array)});
|
||||
} else {
|
||||
*ret = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), size,
|
||||
&DelayedNumpyDecref, array));
|
||||
}
|
||||
|
||||
} else if (dtype != TF_STRING) {
|
||||
size_t size = PyArray_NBYTES(array);
|
||||
array_safe.release();
|
||||
*out_tensor = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
|
||||
PyArray_DATA(array), size,
|
||||
&DelayedNumpyDecref, array));
|
||||
if (ctx) {
|
||||
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
|
||||
PyArray_DATA(array), size, convert_string, &DelayedNumpyDecref,
|
||||
array)});
|
||||
} else {
|
||||
*ret = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
|
||||
PyArray_DATA(array), size,
|
||||
&DelayedNumpyDecref, array));
|
||||
}
|
||||
|
||||
} else {
|
||||
size_t size = 0;
|
||||
void* encoded = nullptr;
|
||||
TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
|
||||
*out_tensor = make_safe(TF_NewTensor(
|
||||
dtype, dims.data(), dims.size(), encoded, size,
|
||||
[](void* data, size_t len, void* arg) {
|
||||
delete[] reinterpret_cast<char*>(data);
|
||||
},
|
||||
nullptr));
|
||||
if (ctx) {
|
||||
*ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
|
||||
encoded, size, convert_string,
|
||||
[](void* data, size_t len, void* arg) {
|
||||
delete[] reinterpret_cast<char*>(data);
|
||||
},
|
||||
nullptr)});
|
||||
} else {
|
||||
*ret = make_safe(TF_NewTensor(
|
||||
dtype, dims.data(), dims.size(), encoded, size,
|
||||
[](void* data, size_t len, void* arg) {
|
||||
delete[] reinterpret_cast<char*>(data);
|
||||
},
|
||||
nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -543,7 +572,8 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status);
|
|||
|
||||
Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
|
||||
Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
|
||||
Status s = PyArrayToTF_Tensor(obj, &tf_tensor);
|
||||
Status s = NdarrayToTensor(nullptr /*ctx*/, obj, &tf_tensor,
|
||||
false /*convert_string*/);
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
}
|
||||
|
|
|
@ -28,15 +28,21 @@ Status TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,
|
|||
|
||||
Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray);
|
||||
|
||||
// Converts the given numpy ndarray to a (safe) TF_Tensor. The returned
|
||||
// TF_Tensor in `out_tensor` may have its own Python reference to `ndarray`s
|
||||
// data. After `out_tensor` is destroyed, this reference must (eventually) be
|
||||
// decremented via ClearDecrefCache().
|
||||
//
|
||||
// `out_tensor` must be non-null. Caller retains ownership of `ndarray`.
|
||||
Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor);
|
||||
// Creates a tensor in 'ret' from the input `ndarray`. The returned TF_Tensor
|
||||
// in `ret` may have its own Python reference to `ndarray`s data. After `ret`
|
||||
// is destroyed, this reference must (eventually) be decremented via
|
||||
// ClearDecrefCache().
|
||||
// `convert_string` indicates whether it has to handle tstring conversion.
|
||||
// Expected to be removed once tstring migration is done.
|
||||
ABSL_MUST_USE_RESULT
|
||||
Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
|
||||
Safe_TF_TensorPtr* ret, bool convert_string);
|
||||
|
||||
// Creates a tensor in 'ret' from the input Ndarray.
|
||||
// TODO(kkb): This is an old conversion function that does not support TFRT.
|
||||
// Currently it's used for session, py_func, and an internal project. Migrate
|
||||
// them.
|
||||
ABSL_MUST_USE_RESULT
|
||||
Status NdarrayToTensor(PyObject* obj, Tensor* ret);
|
||||
|
||||
// Creates a numpy array in 'ret' which either aliases the content of 't' or has
|
||||
|
|
|
@ -681,9 +681,11 @@ typedef Converter<bool> BoolConverter;
|
|||
// The two may share underlying storage so changes to one may reflect in the
|
||||
// other.
|
||||
TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
|
||||
tensorflow::Tensor tensor;
|
||||
tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &tensor);
|
||||
if (!status.ok()) {
|
||||
Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
|
||||
Status status = tensorflow::NdarrayToTensor(ctx, obj, &tf_tensor,
|
||||
true /*convert_string*/);
|
||||
|
||||
if (TF_PREDICT_FALSE(!status.ok())) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Failed to convert a NumPy array to a Tensor (",
|
||||
|
@ -692,8 +694,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
TensorInterface t(std::move(tensor));
|
||||
return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t));
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(tf_tensor->tensor));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue