From b90808b7b4568a7a58992248a8247405667c5c36 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 25 Feb 2020 16:45:19 +0000 Subject: [PATCH] fix --- tensorflow/c/eager/dlpack.cc | 38 ++++++++++++++++++------------------ tensorflow/c/eager/dlpack.h | 8 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 0ec8321230c..6276371bd68 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -36,11 +36,10 @@ struct TfDlManagedTensorCtx { std::vector strides; DLManagedTensor tensor; - TfDlManagedTensorCtx(const TensorReference& ref) - : reference(ref), shape(), tensor() {} + TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {} }; -// Get tensor from eager tensor handle +// Gets tensor from eager tensor handle. const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( @@ -72,7 +71,8 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) { delete owner; } -DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { +// Converts TF_DATAType to DLPack data type. +DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { DLDataType dtype; dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; @@ -107,16 +107,17 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { return dtype; } -DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { +// Gets DLPack's DLContext from eager tensor handle. +DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; - int device_id = -1; + int device_id = 0; if (parsed_name.has_id) { device_id = parsed_name.id; - } // Question: Is it possible that it doens't have id? + } ctx.device_id = device_id; if (device_type == "CPU") { @@ -131,7 +132,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -// Convert DLContext to TF device name +// Converts DLContext to TF device name. absl::optional DeviceNameFromDlContext(const DLContext& ctx, TF_Status* status) { switch (ctx.device_type) { @@ -144,7 +145,7 @@ absl::optional DeviceNameFromDlContext(const DLContext& ctx, }; } -// Convert DLPack data type to TF_DATATYPE +// Converts DLPack data type to TF_DATATYPE. TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, TF_Status* status) { TF_DataType tf_dtype; @@ -221,16 +222,15 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, return tf_dtype; } -// Wrapper function to match the function signature -// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the -// DLManagedTensor +// Wraps the deleter function of DLManagedTensor to match the function signature +// TFE_NewTensorHandleFromDeviceMemory. void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { DLManagedTensor* dlmt = static_cast(dlmt_vptr); dlmt->deleter(const_cast(dlmt)); } -// Check whether the stride array matches the layout of compact, row-majored -// data +// Checks whether the stride array matches the layout of compact, row-majored +// data. bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, int ndim) { if (ndim >= 1 && stride_arr[ndim - 1] != 1) { @@ -263,11 +263,11 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); + dlm_tensor->dl_tensor.ctx = GetDlContext(h, status); int ndim = tensor->dims(); dlm_tensor->dl_tensor.ndim = ndim; dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); + dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status); std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; @@ -283,9 +283,9 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. - // 2) fill in the strides array as the real case for compact row-major data - // Here we choose option 2, since some framework didn't handle the strides - // argument properly + // 2) fill in the strides array as the real case for compact row-major data. + // Here we choose option 2, since some frameworks didn't handle the strides + // argument properly. dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 21ee37b78d8..cf83b79b573 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -24,16 +24,16 @@ namespace tensorflow { // PyCapsule name for DLPack Tensor const char* const kDlTensorCapsuleName = "dltensor"; -// Convert eager tensor handle to DLPack (DLManagedTensor*), and return the -// void* for further PyCapsule construction +// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the +// void* for further PyCapsule construction. TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); -// Convert DLPack (DLManagedTensor*) to eager tensor handle +// Converts DLPack (DLManagedTensor*) to eager tensor handle. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); -// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule +// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow