diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc
index 794550e840a..0ec8321230c 100644
--- a/tensorflow/c/eager/dlpack.cc
+++ b/tensorflow/c/eager/dlpack.cc
@@ -40,6 +40,7 @@ struct TfDlManagedTensorCtx {
       : reference(ref), shape(), tensor() {}
 };
 
+// Get tensor from eager tensor handle
 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -63,6 +64,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   return tensor;
 };
 
+// Deleter for DLManagedTensor
 void DLManagedTensorDeleter(DLManagedTensor* arg) {
   TfDlManagedTensorCtx* owner =
       static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
@@ -129,47 +131,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
   return ctx;
 }
 
-DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
-                                                 TF_Status* status) {
-  const Tensor* tensor = GetTensorFromHandle(h, status);
-  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
-  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
-
-  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
-  tf_dlm_tensor_ctx->reference = tensor_ref;
-
-  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
-  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
-  dlm_tensor->deleter = &DLManagedTensorDeleter;
-  dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
-  int ndim = tensor->dims();
-  dlm_tensor->dl_tensor.ndim = ndim;
-  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
-  dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
-
-  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
-  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
-  shape_arr->resize(ndim);
-  stride_arr->resize(ndim, 1);
-  for (int i = 0; i < ndim; i++) {
-    (*shape_arr)[i] = tensor->dim_size(i);
-  }
-  for (int i = ndim - 2; i >= 0; --i) {
-    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
-  }
-
-  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
-  // There are two ways to represent compact row-major data
-  // 1) nullptr indicates tensor is compact and row-majored.
-  // 2) fill in the strides array as the real case for compact row-major data
-  // Here we choose option 2, since some framework didn't handle the strides
-  // argument properly
-  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
-  dlm_tensor->dl_tensor.byte_offset =
-      0;  // TF doesn't handle the strides and byte_offsets here
-  return &tf_dlm_tensor_ctx->tensor;
-}
-
+// Convert DLContext to TF device name
 absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
                                                     TF_Status* status) {
   switch (ctx.device_type) {
@@ -181,6 +143,8 @@ absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
       return absl::nullopt;
   };
 }
+
+// Convert DLPack data type to TF_DataType
 TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
                                      TF_Status* status) {
   TF_DataType tf_dtype;
@@ -257,11 +221,16 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
   return tf_dtype;
 }
 
+// Wrapper function matching the deallocator signature expected by
+// TFE_NewTensorHandleFromDeviceMemory; calls the deleter of the
+// DLManagedTensor
 void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
   dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
 }
 
+// Check whether the stride array matches the layout of compact, row-major
+// data
 bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                       int ndim) {
   if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
@@ -284,8 +253,43 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
 }
 
 void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
-  DLManagedTensor* tfdlmtensor = TFEHandleToTfDlManagedTensorCtx(h, status);
-  return static_cast<void*>(tfdlmtensor);
+  const Tensor* tensor = GetTensorFromHandle(h, status);
+  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
+  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
+
+  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
+  tf_dlm_tensor_ctx->reference = tensor_ref;
+
+  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
+  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
+  dlm_tensor->deleter = &DLManagedTensorDeleter;
+  dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
+  int ndim = tensor->dims();
+  dlm_tensor->dl_tensor.ndim = ndim;
+  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
+  dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
+
+  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
+  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
+  shape_arr->resize(ndim);
+  stride_arr->resize(ndim, 1);
+  for (int i = 0; i < ndim; i++) {
+    (*shape_arr)[i] = tensor->dim_size(i);
+  }
+  for (int i = ndim - 2; i >= 0; --i) {
+    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
+  }
+
+  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
+  // There are two ways to represent compact row-major data:
+  // 1) nullptr indicates the tensor is compact and row-major.
+  // 2) fill in the strides array as the real case for compact row-major data
+  // Here we choose option 2, since some frameworks don't handle the strides
+  // argument properly
+  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
+  dlm_tensor->dl_tensor.byte_offset =
+      0;  // TF doesn't handle the strides and byte_offsets here
+  return static_cast<void*>(dlm_tensor);
 }
 
 TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h
index f656f4393f6..21ee37b78d8 100644
--- a/tensorflow/c/eager/dlpack.h
+++ b/tensorflow/c/eager/dlpack.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #ifndef TENSORFLOW_C_DLPACK_H_
 #define TENSORFLOW_C_DLPACK_H_
 
@@ -22,14 +21,20 @@ limitations under the License.
 
 namespace tensorflow {
 
+// PyCapsule name for DLPack tensors
 const char* const kDlTensorCapsuleName = "dltensor";
 
-TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status);
+// Convert an eager tensor handle to DLPack (DLManagedTensor*), and return the
+// void* for further PyCapsule construction
+TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
+                                               TF_Status* status);
 
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status);
+// Convert a DLPack tensor (DLManagedTensor*) to an eager tensor handle
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
+                                                             TF_Status* status);
 
+// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule
 TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
 
 }  // namespace tensorflow
-
 #endif  // TENSORFLOW_C_DLPACK_H_
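Reviewer note: a minimal sketch of how the exported C API might be exercised end to end. This program is not part of the change; it assumes the public TF/TFE C API headers, fills only one tensor element for brevity, and omits the per-call TF_GetCode(status) checks a real caller would do.

```c++
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/dlpack.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Build a small float tensor and wrap it in an eager handle.
  int64_t dims[2] = {2, 3};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 6 * sizeof(float));
  static_cast<float*>(TF_TensorData(t))[0] = 1.0f;  // fill as needed
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);

  // Export: the returned void* points at a DLManagedTensor whose context
  // holds a TensorReference, keeping the buffer alive until the deleter runs.
  void* dlm = tensorflow::TFE_HandleToDLPack(h, status);

  // Import: round-trip the DLManagedTensor back into an eager handle. Per
  // this patch, the imported handle takes ownership of dlm; deleting h2
  // triggers DeallocatorWrapperFunc and thus the DLManagedTensor deleter.
  TFE_TensorHandle* h2 = tensorflow::TFE_HandleFromDLPack(dlm, status);

  TFE_DeleteTensorHandle(h2);
  TFE_DeleteTensorHandle(h);
  TF_DeleteTensor(t);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```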
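The stride-filling loop in TFE_HandleToDLPack is the standard row-major recurrence: the innermost stride is 1 and each outer stride is the product of all inner dimension sizes. A self-contained check of that recurrence, using a hypothetical shape not taken from the patch:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3, 4};  // hypothetical example shape
  int ndim = static_cast<int>(shape.size());
  std::vector<int64_t> strides(ndim, 1);
  // Same recurrence as in TFE_HandleToDLPack.
  for (int i = ndim - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  // A {2, 3, 4} tensor stored compactly row-major has strides {12, 4, 1}.
  assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
  return 0;
}
```

IsValidStrideCompactRowMajorData checks exactly this property in the import direction, which is why the export side fills the strides explicitly instead of passing nullptr.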
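The header comments mention PyCapsule construction; for context, here is a hedged sketch of how a binding layer could wrap the returned void*. MakeDLPackCapsule and the destructor are hypothetical names, not part of this change, and the "used_dltensor" rename is the DLPack consumer convention rather than something this patch implements.

```c++
#include <Python.h>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/dlpack.h"

// Capsule destructor: free the DLManagedTensor only if the capsule was never
// consumed (consumers conventionally rename it to "used_dltensor").
static void DLPackCapsuleDestructor(PyObject* capsule) {
  if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
    void* dlm_ptr =
        PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
    tensorflow::TFE_CallDLManagedTensorDeleter(dlm_ptr);
  }
}

// Hypothetical helper: wraps the exported DLManagedTensor* in a PyCapsule for
// Python-side interchange.
PyObject* MakeDLPackCapsule(TFE_TensorHandle* handle, TF_Status* status) {
  void* dlm_ptr = tensorflow::TFE_HandleToDLPack(handle, status);
  return PyCapsule_New(dlm_ptr, tensorflow::kDlTensorCapsuleName,
                       &DLPackCapsuleDestructor);
}
```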