From 48a353cdaf5697b843aa37299a49201d7f4541e8 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 14:14:51 +0000 Subject: [PATCH] fix leak --- tensorflow/c/eager/dlpack.cc | 23 ++++++++++++----------- tensorflow/python/dlpack/dlpack.py | 1 + 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 8c4c70bf453..f982e483bbc 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -29,6 +29,8 @@ namespace { struct TFDLManagedTensorCtx { TensorReference* handle; + std::vector shape; + std::vector strides; DLManagedTensor tensor; }; @@ -60,7 +62,6 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) { static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner->handle; - delete owner->tensor.dl_tensor.shape; delete owner; } @@ -141,14 +142,17 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, TFE_TensorHandleDevicePointer(h, status); tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); - int64_t* shape_arr = new int64_t[ndim]; + std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; + std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; + shape_arr->resize(ndim); + stride_arr->resize(ndim); for (int i = 0; i < ndim; i++) { - shape_arr[i] = tensor->dim_size(i); + (*shape_arr)[i] = tensor->dim_size(i); + (*stride_arr)[i] = 1; } - tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr; - - tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr; + tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast(shape_arr->data()); + tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast(stride_arr->data()); tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; @@ -171,9 +175,6 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { - case 1: - tf_dtype = TF_DataType::TF_BOOL; - break; case 8: tf_dtype = TF_DataType::TF_UINT8; break; @@ -253,8 +254,8 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { DLManagedTensor* dlMTensor = static_cast(dlm_ptr); - if (dlMTensor) { - dlMTensor->deleter(const_cast(dlMTensor)); + if (dlMTensor->deleter != nullptr) { + dlMTensor->deleter(dlMTensor); } } diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 601dffad847..7a04fca3933 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tfe from tensorflow.python.util.tf_export import tf_export +# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? @tf_export("dlpack.to_dlpack") def to_dlpack(tf_tensor): return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)