This commit is contained in:
VoVAllen 2020-02-21 14:14:51 +00:00
parent 0dad831803
commit 48a353cdaf
2 changed files with 13 additions and 11 deletions

View File

@ -29,6 +29,8 @@ namespace {
struct TFDLManagedTensorCtx { struct TFDLManagedTensorCtx {
TensorReference* handle; TensorReference* handle;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor; DLManagedTensor tensor;
}; };
@ -60,7 +62,6 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) {
static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx); static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
owner->handle->Unref(); owner->handle->Unref();
delete owner->handle; delete owner->handle;
delete owner->tensor.dl_tensor.shape;
delete owner; delete owner;
} }
@ -141,14 +142,17 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
TFE_TensorHandleDevicePointer(h, status); TFE_TensorHandleDevicePointer(h, status);
tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
int64_t* shape_arr = new int64_t[ndim]; std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim);
for (int i = 0; i < ndim; i++) { for (int i = 0; i < ndim; i++) {
shape_arr[i] = tensor->dim_size(i); (*shape_arr)[i] = tensor->dim_size(i);
(*stride_arr)[i] = 1;
} }
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr; tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast<std::int64_t*>(shape_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast<std::int64_t*>(stride_arr->data());
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr;
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here 0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor; return &tf_dlm_tensor_ctx->tensor;
@ -171,9 +175,6 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
switch (dtype.code) { switch (dtype.code) {
case DLDataTypeCode::kDLUInt: case DLDataTypeCode::kDLUInt:
switch (dtype.bits) { switch (dtype.bits) {
case 1:
tf_dtype = TF_DataType::TF_BOOL;
break;
case 8: case 8:
tf_dtype = TF_DataType::TF_UINT8; tf_dtype = TF_DataType::TF_UINT8;
break; break;
@ -253,8 +254,8 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr); DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor) { if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor)); dlMTensor->deleter(dlMTensor);
} }
} }

View File

@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tfe
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix?
@tf_export("dlpack.to_dlpack") @tf_export("dlpack.to_dlpack")
def to_dlpack(tf_tensor): def to_dlpack(tf_tensor):
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)