diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index 9377bf0be12..d21ab45e579 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -59,6 +59,20 @@ class AbstractContextInterface { virtual AbstractTensorInterface* CreateTensor( DataType dtype, absl::Span dim_sizes) = 0; + typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); + + // Create a tensor instance from the given data buffer and description. + // `memory_releaser` will be called on destruction, and it's responsible for + // cleaning up the underlying buffer. `convert_string` indicates whether it + // has to handle tstring conversion. Expected to be removed once tstring + // migration is done. + virtual AbstractTensorInterface* CreateTensor(DataType dtype, + const int64_t* dims, + int num_dims, void* data, + size_t len, bool convert_string, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) = 0; + // Create a handle to wrap and manage a Tensor virtual AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) = 0; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 35780077aa8..b8dfe92aac6 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" // clang-format on +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" @@ -168,6 +169,28 @@ AbstractTensorInterface* EagerContext::CreateTensor( return new TensorInterface(Tensor(dtype, TensorShape(dim_sizes))); } +AbstractTensorInterface* EagerContext::CreateTensor( + DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, + bool convert_string, MemoryReleaser memory_releaser, + void* memory_releaser_arg) { + TF_Tensor* tensor_wrapper = + TF_NewTensor(static_cast(dtype), dims, num_dims, data, len, + memory_releaser, memory_releaser_arg); + + if (convert_string) { + tensorflow::Tensor tensor; + Status status = TF_TensorToTensor(tensor_wrapper, &tensor); + TF_DeleteTensor(tensor_wrapper); + if (!status.ok()) return nullptr; + return new TensorInterface(std::move(tensor)); + } else { + AbstractTensorInterface* result = nullptr; + std::swap(result, tensor_wrapper->tensor); + TF_DeleteTensor(tensor_wrapper); + return result; + } +} + std::unique_ptr EagerContext::LoadSavedModelAPI( const std::string& directory, const absl::optional>& tags, diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index c5404773ba6..683425919d1 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -173,6 +173,11 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { AbstractTensorInterface* CreateTensor( DataType dtype, absl::Span dim_sizes) override; + AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims, + int num_dims, void* data, size_t len, + bool convert_string, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) override; AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) override; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4729ce9d743..0b046ea8d61 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -996,6 +996,8 @@ cc_library( "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/c/eager:tfe_context_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", ], diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 78a1613c86c..cb960fd599a 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -89,7 +89,8 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, input_names.push_back(key_string); inputs_safe.emplace_back(make_safe(static_cast(nullptr))); - s = PyArrayToTF_Tensor(value, &inputs_safe.back()); + s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back(), + true /*convert_to_string*/); if (!s.ok()) { Set_TF_Status_from_Status(out_status, s); return; @@ -367,7 +368,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, // cleaned up properly. // // Memory management: - // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input + // NdarrayToTensor() creates a new ndarray PyObject from the input // ndarray. We manage the new ndarray's lifetime in order to keep the // underlying data buffer alive (the new ndarray also guarantees a contiguous // data buffer). The new ndarray's data buffer is used to create the @@ -382,7 +383,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, std::vector input_vals_safe; for (PyObject* ndarray : input_ndarrays) { input_vals_safe.emplace_back(make_safe(static_cast(nullptr))); - s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back()); + s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back(), true); if (!s.ok()) { Set_TF_Status_from_Status(out_status, s); return; diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 2f9972c81bf..2afd2888e8f 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -488,8 +490,9 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) { return Status::OK(); } -Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) { - DCHECK(out_tensor != nullptr); +Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray, + Safe_TF_TensorPtr* ret, bool convert_string) { + DCHECK(ret != nullptr); // Make sure we dereference this array object in case of error, etc. Safe_PyObjectPtr array_safe(make_safe( @@ -515,26 +518,52 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) { if (dtype == TF_RESOURCE) { size_t size = PyArray_NBYTES(array); array_safe.release(); - *out_tensor = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), - size, &DelayedNumpyDecref, array)); + + if (ctx) { + *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor( + static_cast(dtype), {}, 0, PyArray_DATA(array), + size, convert_string, &DelayedNumpyDecref, array)}); + } else { + *ret = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), size, + &DelayedNumpyDecref, array)); + } } else if (dtype != TF_STRING) { size_t size = PyArray_NBYTES(array); array_safe.release(); - *out_tensor = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), - PyArray_DATA(array), size, - &DelayedNumpyDecref, array)); + if (ctx) { + *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor( + static_cast(dtype), dims.data(), dims.size(), + PyArray_DATA(array), size, convert_string, &DelayedNumpyDecref, + array)}); + } else { + *ret = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(), + PyArray_DATA(array), size, + &DelayedNumpyDecref, array)); + } + } else { size_t size = 0; void* encoded = nullptr; TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded)); - *out_tensor = make_safe(TF_NewTensor( - dtype, dims.data(), dims.size(), encoded, size, - [](void* data, size_t len, void* arg) { - delete[] reinterpret_cast(data); - }, - nullptr)); + if (ctx) { + *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor( + static_cast(dtype), dims.data(), dims.size(), + encoded, size, convert_string, + [](void* data, size_t len, void* arg) { + delete[] reinterpret_cast(data); + }, + nullptr)}); + } else { + *ret = make_safe(TF_NewTensor( + dtype, dims.data(), dims.size(), encoded, size, + [](void* data, size_t len, void* arg) { + delete[] reinterpret_cast(data); + }, + nullptr)); + } } + return Status::OK(); } @@ -543,7 +572,8 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status); Status NdarrayToTensor(PyObject* obj, Tensor* ret) { Safe_TF_TensorPtr tf_tensor = make_safe(static_cast(nullptr)); - Status s = PyArrayToTF_Tensor(obj, &tf_tensor); + Status s = NdarrayToTensor(nullptr /*ctx*/, obj, &tf_tensor, + false /*convert_string*/); if (!s.ok()) { return s; } diff --git a/tensorflow/python/lib/core/ndarray_tensor.h b/tensorflow/python/lib/core/ndarray_tensor.h index c5cd24cff2d..38c098417d5 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.h +++ b/tensorflow/python/lib/core/ndarray_tensor.h @@ -28,15 +28,21 @@ Status TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor, Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray); -// Converts the given numpy ndarray to a (safe) TF_Tensor. The returned -// TF_Tensor in `out_tensor` may have its own Python reference to `ndarray`s -// data. After `out_tensor` is destroyed, this reference must (eventually) be -// decremented via ClearDecrefCache(). -// -// `out_tensor` must be non-null. Caller retains ownership of `ndarray`. -Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor); +// Creates a tensor in 'ret' from the input `ndarray`. The returned TF_Tensor +// in `ret` may have its own Python reference to `ndarray`s data. After `ret` +// is destroyed, this reference must (eventually) be decremented via +// ClearDecrefCache(). +// `convert_string` indicates whether it has to handle tstring conversion. +// Expected to be removed once tstring migration is done. +ABSL_MUST_USE_RESULT +Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray, + Safe_TF_TensorPtr* ret, bool convert_string); // Creates a tensor in 'ret' from the input Ndarray. +// TODO(kkb): This is an old conversion function that does not support TFRT. +// Currently it's used for session, py_func, and an internal project. Migrate +// them. +ABSL_MUST_USE_RESULT Status NdarrayToTensor(PyObject* obj, Tensor* ret); // Creates a numpy array in 'ret' which either aliases the content of 't' or has diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index ecf4a92f0e7..22829f546b1 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -681,9 +681,11 @@ typedef Converter BoolConverter; // The two may share underlying storage so changes to one may reflect in the // other. TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { - tensorflow::Tensor tensor; - tensorflow::Status status = tensorflow::NdarrayToTensor(obj, &tensor); - if (!status.ok()) { + Safe_TF_TensorPtr tf_tensor = make_safe(static_cast(nullptr)); + Status status = tensorflow::NdarrayToTensor(ctx, obj, &tf_tensor, + true /*convert_string*/); + + if (TF_PREDICT_FALSE(!status.ok())) { PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat( "Failed to convert a NumPy array to a Tensor (", @@ -692,8 +694,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { return nullptr; } - TensorInterface t(std::move(tensor)); - return tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(&t)); + return tensorflow::wrap( + tensorflow::unwrap(ctx)->CreateLocalHandle(tf_tensor->tensor)); } } // namespace