From 7aa5009b359a0704ec23021187a687d2476361e5 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Sat, 22 Feb 2020 16:28:58 +0000 Subject: [PATCH] fix --- tensorflow/c/eager/dlpack.cc | 82 +++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index ce36a5f3a10..fdc439da4b5 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -27,12 +27,17 @@ namespace tensorflow { namespace { +// Managing context for the DLManagedTensor, will manage the lifetime of +// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the +// original framework of destruction, and this context will be deleted also. struct TfDlManagedTensorCtx { - TensorReference* reference; + TensorReference reference; std::vector shape; + std::vector strides; DLManagedTensor tensor; - TfDlManagedTensorCtx() + TfDlManagedTensorCtx(const TensorReference& ref) + : reference(ref), shape(), tensor() {} }; const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { @@ -61,8 +66,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { void DLManagedTensorDeleter(DLManagedTensor* arg) { TfDlManagedTensorCtx* owner = static_cast(arg->manager_ctx); - owner->reference->Unref(); - delete owner->reference; + owner->reference.Unref(); delete owner; } @@ -129,31 +133,41 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx; + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() - TensorReference* tensor_ref = - new TensorReference(*tensor); // This will call buf_->Ref() + auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); tf_dlm_tensor_ctx->reference = tensor_ref; - tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx; - tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter; - tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status); + + DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; + dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; + dlm_tensor->deleter = &DLManagedTensorDeleter; + dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); - tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim; - tf_dlm_tensor_ctx->tensor.dl_tensor.data = - TFE_TensorHandleDevicePointer(h, status); - tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); + dlm_tensor->dl_tensor.ndim = ndim; + dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); + dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; + std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; shape_arr->resize(ndim); + stride_arr->resize(ndim, 1); for (int i = 0; i < ndim; i++) { (*shape_arr)[i] = tensor->dim_size(i); } + for (int i = ndim - 2; i >= 0; --i) { + (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; + } - tf_dlm_tensor_ctx->tensor.dl_tensor.shape = + dlm_tensor->dl_tensor.shape = reinterpret_cast(shape_arr->data()); - tf_dlm_tensor_ctx->tensor.dl_tensor.strides = - nullptr; // nullptr indicates tensor is compact and row-majored. - tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = + // There are two ways to represent compact row-major data + // 1) nullptr indicates tensor is compact and row-majored. + // 2) fill in the strides array as the real case for compact row-major data + // Here we choose option 2, since some framework didn't handle the strides + // argument properly + dlm_tensor->dl_tensor.strides = + reinterpret_cast(stride_arr->data()); + dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; } @@ -250,6 +264,15 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { dlmt->deleter(const_cast(dlmt)); } +bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, + int ndim) { + for (int i = ndim - 2; i >= 0; --i) { + if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { + return false; + }; + } + return true; +} } // namespace void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { @@ -268,23 +291,32 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); DLManagedTensor* dlmt = static_cast(dlm); - + DLTensor* dl_tensor = &dlmt->dl_tensor; absl::optional device_name = - DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status); + DeviceNameFromDlContext(dl_tensor->ctx, status); if (!device_name.has_value()) { status->status = tensorflow::errors::InvalidArgument("Unsupported Device Type"); return nullptr; } - TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status); - int num_dims = dlmt->dl_tensor.ndim; - const int64_t* dims = dlmt->dl_tensor.shape; - void* data = dlmt->dl_tensor.data; + TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status); + int num_dims = dl_tensor->ndim; + const int64_t* dims = dl_tensor->shape; + void* data = dl_tensor->data; - size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8; + size_t total_bytes = dl_tensor->dtype.bits / 8; for (int i = 0; i < num_dims; i++) { total_bytes *= dims[i]; } + + if ((dl_tensor->strides != nullptr) && + !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, + num_dims)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid strides array from DLPack"); + return nullptr; + } + TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( ctx, device_name.value().c_str(), dtype, dims, num_dims, data, total_bytes, &DeallocatorWrapperFunc, &dlmt, status);