address comments

VoVAllen 2020-02-19 20:44:59 +00:00
parent 883dcc553a
commit e61323b14e
3 changed files with 80 additions and 116 deletions

View File

@@ -32,7 +32,6 @@ struct TFDLManagedTensorCtx {
   DLManagedTensor tensor;
 };
-
 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -45,8 +44,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
-        "TFE_TensorHandleDevicePointer may not be called on a remote tensor "
-        "handle.");
+        "DLPack doesn't support remote tensor");
     return nullptr;
   }
   const tensorflow::Tensor* tensor;
@@ -58,7 +56,8 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
 };
 void DLManagedTensorDeleter(DLManagedTensor* arg) {
-  TFDLManagedTensorCtx* owner = static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
+  TFDLManagedTensorCtx* owner =
+      static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
   owner->handle->Unref();
   delete owner;
 }
@@ -68,103 +67,46 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
   dtype.lanes = 1;
   dtype.bits = TF_DataTypeSize(data_type) * 8;
   switch (data_type) {
-    case TF_DataType::TF_FLOAT:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_DOUBLE:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_INT32:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_UINT8:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_INT8:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_INT16:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_STRING:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_COMPLEX64:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_COMPLEX64 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_INT64:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_BOOL:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_QINT8:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT8 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QUINT8:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QUINT8 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QINT32:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT32 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_BFLOAT16:
-      dtype.code = DLDataTypeCode::kDLBfloat;
-      break;
-    case TF_DataType::TF_QINT16:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT16 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QUINT16:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QUINT16 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_UINT16:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_COMPLEX128:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_COMPLEX128 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_HALF:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_RESOURCE:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_RESOURCE is not supported by dlpack");
-      break;
-    case TF_DataType::TF_VARIANT:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_VARIANT is not supported by dlpack");
-      break;
-    case TF_DataType::TF_UINT32:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_UINT64:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
+    case TF_DataType::TF_HALF:
+    case TF_DataType::TF_FLOAT:
+    case TF_DataType::TF_DOUBLE:
+      dtype.code = DLDataTypeCode::kDLFloat;
+      break;
+    case TF_DataType::TF_INT8:
+    case TF_DataType::TF_INT16:
+    case TF_DataType::TF_INT32:
+    case TF_DataType::TF_INT64:
+      dtype.code = DLDataTypeCode::kDLInt;
+      break;
+    case TF_DataType::TF_BOOL:
+    case TF_DataType::TF_UINT8:
+    case TF_DataType::TF_UINT16:
+    case TF_DataType::TF_UINT32:
+    case TF_DataType::TF_UINT64:
+      dtype.code = DLDataTypeCode::kDLUInt;
+      break;
+    case TF_DataType::TF_BFLOAT16:
+      dtype.code = DLDataTypeCode::kDLBfloat;
+      break;
     default:
       status->status = tensorflow::errors::InvalidArgument(
-          "Unsupported TF_DataType is not supported by dlpack");
+          DataType_Name(static_cast<DataType>(data_type)),
+          " is not supported by dlpack");
       break;
   }
   return dtype;
 }
 DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
   DLContext ctx;
   const char* device_name = h->handle->DeviceName(&status->status);
   DeviceNameUtils::ParsedName parsed_name;
-  tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name),
-                                             &parsed_name);
+  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
   std::string device_type = parsed_name.type;
   int device_id = -1;
   if (parsed_name.has_id) {
     device_id = parsed_name.id;
-  }  // Question? device_id?=-1
+  }  // Question: Is it possible that it doesn't have an id?
   ctx.device_id = device_id;
   if (device_type == "CPU") {
@@ -173,53 +115,55 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
     ctx.device_type = DLDeviceType::kDLGPU;
   } else {
     status->status = tensorflow::errors::InvalidArgument(
-        "Unsupported Device Type for DLPack");
+        "Unsupported Device Type for dlpack");
   }
   return ctx;
 }
 DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
                                                  TF_Status* status) {
   const Tensor* tensor = GetTensorFromHandle(h, status);
   TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
-  TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx);
+  TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx);
   TensorReference* tensor_ref =
       new TensorReference(*tensor);  // This will call buf_->Ref()
-  tfDLMTensor->handle = tensor_ref;
-  tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
-  tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter;
-  tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status);
+  tf_dlm_tensor_ctx->handle = tensor_ref;
+  tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
+  tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
   int ndim = tensor->dims();
-  tfDLMTensor->tensor.dl_tensor.ndim = ndim;
-  tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
-  tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
+  tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.data =
+      TFE_TensorHandleDevicePointer(h, status);
+  tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
   int64_t* shape_arr = new int64_t[ndim];
   for (int i = 0; i < ndim; i++) {
     shape_arr[i] = tensor->dim_size(i);
   }
-  tfDLMTensor->tensor.dl_tensor.shape = shape_arr;
-  tfDLMTensor->tensor.dl_tensor.strides =
-      nullptr;  // Whether this is null at all the time?
-  tfDLMTensor->tensor.dl_tensor.byte_offset =
-      0;  // Whether this is 0 at all the time?
-  return &tfDLMTensor->tensor;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
+      0;  // TF doesn't handle the strides and byte_offsets here
+  return &tf_dlm_tensor_ctx->tensor;
 }
-std::string FromDLContext(const DLContext& ctx, TF_Status* status) {
+absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
+                                                    TF_Status* status) {
   switch (ctx.device_type) {
     case DLDeviceType::kDLCPU:
       return "CPU:0";
     case DLDeviceType::kDLGPU:
       return absl::StrCat("GPU:", ctx.device_id);
     default:
-      return "";
+      return absl::nullopt;
   };
 }
-TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
+TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
+                                     TF_Status* status) {
   TF_DataType tf_dtype;
   switch (dtype.code) {
     case DLDataTypeCode::kDLUInt:
@@ -241,7 +185,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported UInt bits", dtype.bits);
+              "Unsupported UInt bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLInt:
@@ -260,7 +204,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported Int bits", dtype.bits);
+              "Unsupported Int bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLFloat:
@@ -276,7 +220,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported Float bits", dtype.bits);
+              "Unsupported Float bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLBfloat:
@@ -286,20 +230,20 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported BFloat bits", dtype.bits);
+              "Unsupported BFloat bits: ", dtype.bits);
       }
       break;
     default:
       status->status = tensorflow::errors::InvalidArgument(
-          "Unsupported Type Codes", dtype.code);
+          "Unsupported Type Codes: ", dtype.code);
   }
   return tf_dtype;
 }
-void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr){
-  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
-  dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
+void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
+  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
+  dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
 }
} // namespace
@@ -321,8 +265,14 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
   TFE_Context* ctx = TFE_NewContext(opts, status);
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
-  std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status);
-  TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status);
+  absl::optional<std::string> device_name =
+      DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
+  if (!device_name.has_value()) {
+    status->status =
+        tensorflow::errors::InvalidArgument("Unsupported Device Type");
+    return nullptr;
+  }
+  TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
   int num_dims = dlmt->dl_tensor.ndim;
   const int64_t* dims = dlmt->dl_tensor.shape;
   void* data = dlmt->dl_tensor.data;
@@ -332,8 +282,8 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
     total_bytes *= dims[i];
   }
   TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
-      ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes,
-      &DeallocatorWrapperFunc, &dlmt, status);
+      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
+      total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
   return handle;
 };
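
For orientation, here is a minimal round-trip sketch of the Python API these C functions back. The import paths are assumptions inferred from the test file below, not something this diff confirms:

import numpy as np
# Assumed module path for the bindings exercised by the test file below.
from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack
from tensorflow.python.framework import constant_op, ops

ops.enable_eager_execution()
t = constant_op.constant(np.arange(6, dtype=np.float32).reshape(2, 3))
capsule = to_dlpack(t)     # wraps TFE_HandleToDLPack; yields a "dltensor" PyCapsule
t2 = from_dlpack(capsule)  # wraps TFE_HandleFromDLPack; capsule becomes "used_dltensor"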

View File

@@ -16,7 +16,6 @@ int_dtypes = [
 float_dtypes = [np.float16, np.float32, np.float64]
 complex_dtypes = [np.complex64, np.complex128]
 dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
-standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
 testcase_shapes = [
@@ -55,13 +54,29 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
     dlcapsule = to_dlpack(tf_tensor)
     del tf_tensor  # should still work
     tf_tensor2 = from_dlpack(dlcapsule)
+
+    def ConsumeDLPackTensor():
+      from_dlpack(dlcapsule)  # Should be consumed only once
+
+    self.assertRaisesRegex(Exception,
+                           ".*a DLPack tensor may be consumed at most once.*",
+                           ConsumeDLPackTensor)
+
+  def testUnsupportedType(self):
+
+    def case1():
+      tf_tensor = constant_op.constant(
+          [[1, 4], [5, 2]], dtype=dtypes.qint16)
+      dlcapsule = to_dlpack(tf_tensor)
+
+    def case2():
+      tf_tensor = constant_op.constant(
+          [[1, 4], [5, 2]], dtype=dtypes.complex64)
+      dlcapsule = to_dlpack(tf_tensor)
+
+    self.assertRaisesRegex(
+        Exception, ".* is not supported by dlpack", case1)
+    self.assertRaisesRegex(
+        Exception, ".* is not supported by dlpack", case2)
 if __name__ == '__main__':
   ops.enable_eager_execution()
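
To summarize the dtype coverage these tests and the reworked GetDLDataType switch imply (a sketch only; the kDL* code values come from dlpack.h, everything else from the diff above):

# TF dtype -> DLPack type-code mapping implemented by the new switch.
# dlpack.h enum values: kDLInt=0, kDLUInt=1, kDLFloat=2, kDLBfloat=4.
DLPACK_CODE_BY_TF_DTYPE = {
    "float16": "kDLFloat", "float32": "kDLFloat", "float64": "kDLFloat",
    "int8": "kDLInt", "int16": "kDLInt", "int32": "kDLInt", "int64": "kDLInt",
    "bool": "kDLUInt", "uint8": "kDLUInt", "uint16": "kDLUInt",
    "uint32": "kDLUInt", "uint64": "kDLUInt",
    "bfloat16": "kDLBfloat",
}
# Everything else (string, complex64/128, quantized, resource, variant)
# now falls through to: "<DataType_Name> is not supported by dlpack".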

View File

@@ -1041,21 +1041,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         tensorflow::Safe_TF_StatusPtr status =
             tensorflow::make_safe(TF_NewStatus());
         void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         py::capsule capsule(
             dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
-              void* dlm_rptr =
-                  PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
-              if (dlm_rptr) {
-                tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
-                PyCapsule_SetDestructor(capsule, nullptr);
-              } else {
-                // The tensor has been deleted. Clear any error from
-                // PyCapsule_GetPointer.
-                PyErr_Clear();
-              }
+              if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
+                void* dlm_rptr =
+                    PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
+                if (dlm_rptr) {
+                  tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
+                  PyCapsule_SetDestructor(capsule, nullptr);
+                }
+              }
             });
-        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         return capsule;
       });
@@ -1072,10 +1070,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
         }
         TFE_TensorHandle* thandle =
             tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
+        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
         PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
-        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
         return py::handle(EagerTensorFromHandle(thandle));
       });
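
The net effect of the capsule bookkeeping above, seen from the Python side (a sketch continuing the assumed imports from the earlier snippet; the exact error text is approximated from the test's regex):

# Consume-once contract enforced by the pybind code above.
capsule = to_dlpack(constant_op.constant([1.0, 2.0]))
tensor = from_dlpack(capsule)   # capsule renamed to "used_dltensor" and its
                                # destructor cleared, so the deleter runs once
try:
    from_dlpack(capsule)        # second consumption fails the capsule-name check
except Exception as err:
    print(err)                  # e.g. "...DLPack tensor may be consumed at most once..."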