fix leak
This commit is contained in:
parent
0dad831803
commit
48a353cdaf
@ -29,6 +29,8 @@ namespace {
|
|||||||
|
|
||||||
struct TFDLManagedTensorCtx {
|
struct TFDLManagedTensorCtx {
|
||||||
TensorReference* handle;
|
TensorReference* handle;
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
std::vector<int64_t> strides;
|
||||||
DLManagedTensor tensor;
|
DLManagedTensor tensor;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -60,7 +62,6 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
|||||||
static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
|
static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
|
||||||
owner->handle->Unref();
|
owner->handle->Unref();
|
||||||
delete owner->handle;
|
delete owner->handle;
|
||||||
delete owner->tensor.dl_tensor.shape;
|
|
||||||
delete owner;
|
delete owner;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,14 +142,17 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
|
|||||||
TFE_TensorHandleDevicePointer(h, status);
|
TFE_TensorHandleDevicePointer(h, status);
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
||||||
|
|
||||||
int64_t* shape_arr = new int64_t[ndim];
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
|
shape_arr->resize(ndim);
|
||||||
|
stride_arr->resize(ndim);
|
||||||
for (int i = 0; i < ndim; i++) {
|
for (int i = 0; i < ndim; i++) {
|
||||||
shape_arr[i] = tensor->dim_size(i);
|
(*shape_arr)[i] = tensor->dim_size(i);
|
||||||
|
(*stride_arr)[i] = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr;
|
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(shape_arr->data());
|
||||||
|
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(stride_arr->data());
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr;
|
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
return &tf_dlm_tensor_ctx->tensor;
|
return &tf_dlm_tensor_ctx->tensor;
|
||||||
@ -171,9 +175,6 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
|
|||||||
switch (dtype.code) {
|
switch (dtype.code) {
|
||||||
case DLDataTypeCode::kDLUInt:
|
case DLDataTypeCode::kDLUInt:
|
||||||
switch (dtype.bits) {
|
switch (dtype.bits) {
|
||||||
case 1:
|
|
||||||
tf_dtype = TF_DataType::TF_BOOL;
|
|
||||||
break;
|
|
||||||
case 8:
|
case 8:
|
||||||
tf_dtype = TF_DataType::TF_UINT8;
|
tf_dtype = TF_DataType::TF_UINT8;
|
||||||
break;
|
break;
|
||||||
@ -253,8 +254,8 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
|||||||
|
|
||||||
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||||
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
|
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
|
||||||
if (dlMTensor) {
|
if (dlMTensor->deleter != nullptr) {
|
||||||
dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor));
|
dlMTensor->deleter(dlMTensor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tfe
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix?
|
||||||
@tf_export("dlpack.to_dlpack")
|
@tf_export("dlpack.to_dlpack")
|
||||||
def to_dlpack(tf_tensor):
|
def to_dlpack(tf_tensor):
|
||||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||||
|
Loading…
Reference in New Issue
Block a user