address comments

VoVAllen 2020-02-19 20:44:59 +00:00
parent 883dcc553a
commit e61323b14e
3 changed files with 80 additions and 116 deletions

View File

@@ -32,7 +32,6 @@ struct TFDLManagedTensorCtx {
   DLManagedTensor tensor;
 };
 
 const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   if (h == nullptr || !h->handle->IsValid(&status->status)) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -45,8 +44,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
-        "TFE_TensorHandleDevicePointer may not be called on a remote tensor "
-        "handle.");
+        "DLPack doesn't support remote tensor");
     return nullptr;
   }
   const tensorflow::Tensor* tensor;
@@ -58,7 +56,8 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
 };
 
 void DLManagedTensorDeleter(DLManagedTensor* arg) {
-  TFDLManagedTensorCtx* owner = static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
+  TFDLManagedTensorCtx* owner =
+      static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
   owner->handle->Unref();
   delete owner;
 }
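Note: the manager_ctx / deleter pair handled above is the standard DLPack ownership hand-off — the producer stashes whatever keeps the buffer alive in manager_ctx, and the consumer calls deleter exactly once when it is done with the tensor. A minimal sketch of the same pattern outside TensorFlow (assuming dlpack.h from the DLPack project is on the include path; HeapBufferCtx, HeapBufferDeleter and MakeManagedVector are made-up names, not part of this change):

#include <cstdint>
#include <cstdlib>

#include "dlpack/dlpack.h"  // assumed include path for the DLPack header

// Hypothetical holder playing the role TFDLManagedTensorCtx plays above.
struct HeapBufferCtx {
  DLManagedTensor tensor;
  float* buffer;
  int64_t shape[1];
};

// The consumer calls this exactly once; it releases everything the producer
// still owns, mirroring DLManagedTensorDeleter's Unref() + delete above.
void HeapBufferDeleter(DLManagedTensor* arg) {
  HeapBufferCtx* owner = static_cast<HeapBufferCtx*>(arg->manager_ctx);
  std::free(owner->buffer);
  delete owner;
}

DLManagedTensor* MakeManagedVector(int64_t n) {
  HeapBufferCtx* ctx = new HeapBufferCtx;
  ctx->buffer = static_cast<float*>(std::malloc(n * sizeof(float)));
  ctx->shape[0] = n;
  ctx->tensor.manager_ctx = ctx;
  ctx->tensor.deleter = &HeapBufferDeleter;
  ctx->tensor.dl_tensor.data = ctx->buffer;
  ctx->tensor.dl_tensor.ctx = {kDLCPU, 0};
  ctx->tensor.dl_tensor.ndim = 1;
  ctx->tensor.dl_tensor.dtype = {kDLFloat, 32, 1};
  ctx->tensor.dl_tensor.shape = ctx->shape;
  ctx->tensor.dl_tensor.strides = nullptr;  // compact, row-major
  ctx->tensor.dl_tensor.byte_offset = 0;
  return &ctx->tensor;
}

A consumer importing such a tensor is expected to invoke tensor->deleter(tensor) once, which is exactly the contract DLManagedTensorDeleter fulfils for the TensorReference held by TFDLManagedTensorCtx.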
@@ -68,87 +67,31 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
   dtype.lanes = 1;
   dtype.bits = TF_DataTypeSize(data_type) * 8;
   switch (data_type) {
+    case TF_DataType::TF_HALF:
     case TF_DataType::TF_FLOAT:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
     case TF_DataType::TF_DOUBLE:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
-    case TF_DataType::TF_INT32:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_UINT8:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
     case TF_DataType::TF_INT8:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
     case TF_DataType::TF_INT16:
-      dtype.code = DLDataTypeCode::kDLInt;
-      break;
-    case TF_DataType::TF_STRING:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_COMPLEX64:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_COMPLEX64 is not supported by dlpack");
-      break;
+    case TF_DataType::TF_INT32:
     case TF_DataType::TF_INT64:
       dtype.code = DLDataTypeCode::kDLInt;
       break;
     case TF_DataType::TF_BOOL:
+    case TF_DataType::TF_UINT8:
+    case TF_DataType::TF_UINT16:
+    case TF_DataType::TF_UINT32:
+    case TF_DataType::TF_UINT64:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
-    case TF_DataType::TF_QINT8:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT8 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QUINT8:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QUINT8 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QINT32:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT32 is not supported by dlpack");
-      break;
     case TF_DataType::TF_BFLOAT16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
-    case TF_DataType::TF_QINT16:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QINT16 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_QUINT16:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_QUINT16 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_UINT16:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_COMPLEX128:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_COMPLEX128 is not supported by dlpack");
-      break;
-    case TF_DataType::TF_HALF:
-      dtype.code = DLDataTypeCode::kDLFloat;
-      break;
-    case TF_DataType::TF_RESOURCE:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_RESOURCE is not supported by dlpack");
-      break;
-    case TF_DataType::TF_VARIANT:
-      status->status = tensorflow::errors::InvalidArgument(
-          "TF_VARIANT is not supported by dlpack");
-      break;
-    case TF_DataType::TF_UINT32:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
-    case TF_DataType::TF_UINT64:
-      dtype.code = DLDataTypeCode::kDLUInt;
-      break;
     default:
       status->status = tensorflow::errors::InvalidArgument(
-          "Unsupported TF_DataType is not supported by dlpack");
+          DataType_Name(static_cast<DataType>(data_type)),
+          " is not supported by dlpack");
       break;
   }
   return dtype;
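Note: the grouped cases above reduce every supported TF dtype to DLPack's (code, bits, lanes) triple, with bits always derived from TF_DataTypeSize. A small sketch of the triples this mapping yields (dlpack.h types only; MakeDType and the constants below are hypothetical, not part of the patch):

#include <cstdint>

#include "dlpack/dlpack.h"  // assumed include path

// Hypothetical helper: builds the triple the way GetDLDataType does.
DLDataType MakeDType(DLDataTypeCode code, uint8_t bits) {
  DLDataType t;
  t.code = code;
  t.bits = bits;  // TF_DataTypeSize(data_type) * 8 in the real code
  t.lanes = 1;    // TF tensors never pack multiple lanes per element
  return t;
}

// Examples of what the switch above produces:
//   float32 -> {kDLFloat, 32, 1}   int64    -> {kDLInt, 64, 1}
//   bool    -> {kDLUInt, 8, 1}     bfloat16 -> {kDLBfloat, 16, 1}
const DLDataType kDlFloat32 = MakeDType(kDLFloat, 32);
const DLDataType kDlInt64 = MakeDType(kDLInt, 64);
const DLDataType kDlBool = MakeDType(kDLUInt, 8);
const DLDataType kDlBfloat16 = MakeDType(kDLBfloat, 16);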
@@ -158,13 +101,12 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
   DLContext ctx;
   const char* device_name = h->handle->DeviceName(&status->status);
   DeviceNameUtils::ParsedName parsed_name;
-  tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name),
-                                             &parsed_name);
+  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
   std::string device_type = parsed_name.type;
   int device_id = -1;
   if (parsed_name.has_id) {
     device_id = parsed_name.id;
-  }  // Question? device_id?=-1
+  }  // Question: Is it possible that it doesn't have id?
   ctx.device_id = device_id;
   if (device_type == "CPU") {
@@ -173,7 +115,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
     ctx.device_type = DLDeviceType::kDLGPU;
   } else {
     status->status = tensorflow::errors::InvalidArgument(
-        "Unsupported Device Type for DLPack");
+        "Unsupported Device Type for dlpack");
   }
   return ctx;
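Note: once DeviceNameUtils has parsed the handle's device name (e.g. "/job:localhost/replica:0/task:0/device:GPU:0"), all GetDLContext has left to do is map the (type, id) pair onto a DLContext. An illustrative stand-alone version of that last step, assuming dlpack.h is available (DeviceTypeToDLContext is a made-up helper, not TF code):

#include <string>

#include "dlpack/dlpack.h"  // assumed include path

// Illustrative only: mirrors the CPU/GPU mapping done in GetDLContext.
// Returns false for device types dlpack has no equivalent for.
bool DeviceTypeToDLContext(const std::string& device_type, int device_id,
                           DLContext* out) {
  if (device_type == "CPU") {
    out->device_type = DLDeviceType::kDLCPU;
  } else if (device_type == "GPU") {
    out->device_type = DLDeviceType::kDLGPU;
  } else {
    return false;  // caller should surface an InvalidArgument error
  }
  out->device_id = device_id;  // stays -1 when the parsed name carries no id
  return true;
}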
@@ -183,43 +125,45 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
                                                  TF_Status* status) {
   const Tensor* tensor = GetTensorFromHandle(h, status);
   TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
-  TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx);
+  TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx);
   TensorReference* tensor_ref =
       new TensorReference(*tensor);  // This will call buf_->Ref()
-  tfDLMTensor->handle = tensor_ref;
-  tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
-  tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter;
-  tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status);
+  tf_dlm_tensor_ctx->handle = tensor_ref;
+  tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
+  tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
   int ndim = tensor->dims();
-  tfDLMTensor->tensor.dl_tensor.ndim = ndim;
-  tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
-  tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
+  tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.data =
+      TFE_TensorHandleDevicePointer(h, status);
+  tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
   int64_t* shape_arr = new int64_t[ndim];
   for (int i = 0; i < ndim; i++) {
     shape_arr[i] = tensor->dim_size(i);
   }
-  tfDLMTensor->tensor.dl_tensor.shape = shape_arr;
-  tfDLMTensor->tensor.dl_tensor.strides =
-      nullptr;  // Whether this is null at all the time?
-  tfDLMTensor->tensor.dl_tensor.byte_offset =
-      0;  // Whether this is 0 at all the time?
-  return &tfDLMTensor->tensor;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr;
+  tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
+      0;  // TF doesn't handle the strides and byte_offsets here
+  return &tf_dlm_tensor_ctx->tensor;
 }
 
-std::string FromDLContext(const DLContext& ctx, TF_Status* status) {
+absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
+                                                    TF_Status* status) {
   switch (ctx.device_type) {
     case DLDeviceType::kDLCPU:
       return "CPU:0";
     case DLDeviceType::kDLGPU:
       return absl::StrCat("GPU:", ctx.device_id);
     default:
-      return "";
+      return absl::nullopt;
   };
 }
 
-TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
+TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
+                                     TF_Status* status) {
   TF_DataType tf_dtype;
   switch (dtype.code) {
     case DLDataTypeCode::kDLUInt:
@@ -241,7 +185,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported UInt bits", dtype.bits);
+              "Unsupported UInt bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLInt:
@@ -260,7 +204,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported Int bits", dtype.bits);
+              "Unsupported Int bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLFloat:
@@ -276,7 +220,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported Float bits", dtype.bits);
+              "Unsupported Float bits: ", dtype.bits);
       }
       break;
     case DLDataTypeCode::kDLBfloat:
@@ -286,12 +230,12 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
           break;
         default:
           status->status = tensorflow::errors::InvalidArgument(
-              "Unsupported BFloat bits", dtype.bits);
+              "Unsupported BFloat bits: ", dtype.bits);
       }
       break;
     default:
       status->status = tensorflow::errors::InvalidArgument(
-          "Unsupported Type Codes", dtype.code);
+          "Unsupported Type Codes: ", dtype.code);
   }
   return tf_dtype;
@@ -321,8 +265,14 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
   TFE_Context* ctx = TFE_NewContext(opts, status);
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
-  std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status);
-  TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status);
+  absl::optional<std::string> device_name =
+      DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
+  if (!device_name.has_value()) {
+    status->status =
+        tensorflow::errors::InvalidArgument("Unsupported Device Type");
+    return nullptr;
+  }
+  TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
   int num_dims = dlmt->dl_tensor.ndim;
   const int64_t* dims = dlmt->dl_tensor.shape;
   void* data = dlmt->dl_tensor.data;
@@ -332,8 +282,8 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
     total_bytes *= dims[i];
   }
   TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
-      ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes,
-      &DeallocatorWrapperFunc, &dlmt, status);
+      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
+      total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
   return handle;
 };
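Note: on the import path TFE_HandleFromDLPack wraps the incoming DLPack buffer without copying, so it needs the total byte count of the tensor; the loop above multiplies the element size by every dimension before handing the pointer to TFE_NewTensorHandleFromDeviceMemory. A sketch of that size computation for a compact (strides == nullptr) DLTensor, with dlpack.h assumed available (DenseDLTensorBytes is a made-up name):

#include <cstddef>
#include <cstdint>

#include "dlpack/dlpack.h"  // assumed include path

// Number of bytes backing a dense, unstrided DLTensor: element size in bytes
// (rounded up from bits * lanes) times the product of all dimensions.
size_t DenseDLTensorBytes(const DLTensor& t) {
  size_t total_bytes =
      (static_cast<size_t>(t.dtype.bits) * t.dtype.lanes + 7) / 8;
  for (int i = 0; i < t.ndim; ++i) {
    total_bytes *= static_cast<size_t>(t.shape[i]);
  }
  return total_bytes;
}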

View File

@@ -16,7 +16,6 @@ int_dtypes = [
 float_dtypes = [np.float16, np.float32, np.float64]
 complex_dtypes = [np.complex64, np.complex128]
 dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
-standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
 
 testcase_shapes = [
@@ -62,6 +61,22 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
         ".*a DLPack tensor may be consumed at most once.*",
         ConsumeDLPackTensor)
 
+  def testUnsupportedType(self):
+    def case1():
+      tf_tensor = constant_op.constant(
+          [[1, 4], [5, 2]], dtype=dtypes.qint16)
+      dlcapsule = to_dlpack(tf_tensor)
+
+    def case2():
+      tf_tensor = constant_op.constant(
+          [[1, 4], [5, 2]], dtype=dtypes.complex64)
+      dlcapsule = to_dlpack(tf_tensor)
+
+    self.assertRaisesRegex(
+        Exception, ".* is not supported by dlpack", case1)
+    self.assertRaisesRegex(
+        Exception, ".* is not supported by dlpack", case2)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()

View File

@@ -1041,21 +1041,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());
     void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     py::capsule capsule(
         dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
+          if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
            void* dlm_rptr =
                PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
            if (dlm_rptr) {
              tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
              PyCapsule_SetDestructor(capsule, nullptr);
-            } else {
-              // The tensor has been deleted. Clear any error from
-              // PyCapsule_GetPointer.
-              PyErr_Clear();
+            }
            }
         });
-    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     return capsule;
   });
@@ -1073,9 +1071,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
     TFE_TensorHandle* thandle =
        tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
     PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
-    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
     return py::handle(EagerTensorFromHandle(thandle));
   });
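Note: the two bindings above follow the DLPack capsule convention used across frameworks: a fresh capsule is named "dltensor", and whoever consumes it renames it to "used_dltensor" and clears its destructor so the DLManagedTensor is freed exactly once. A minimal sketch of that consume-once step using only the CPython capsule API (TakeDLPackPointer is a made-up helper, not part of this change):

#include <Python.h>

// Extracts the DLManagedTensor* from a "dltensor" capsule and marks the
// capsule as consumed, so its destructor will not run the deleter again.
// Returns nullptr if the capsule was already consumed (or is not a dltensor).
void* TakeDLPackPointer(PyObject* capsule) {
  if (!PyCapsule_IsValid(capsule, "dltensor")) {
    return nullptr;
  }
  void* dlm = PyCapsule_GetPointer(capsule, "dltensor");
  PyCapsule_SetName(capsule, "used_dltensor");  // consume-once marker
  PyCapsule_SetDestructor(capsule, nullptr);    // ownership moved to caller
  return dlm;
}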