address comments
This commit is contained in:
parent
883dcc553a
commit
e61323b14e
@ -32,7 +32,6 @@ struct TFDLManagedTensorCtx {
|
||||
DLManagedTensor tensor;
|
||||
};
|
||||
|
||||
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -45,8 +44,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
||||
"handle.");
|
||||
"DLPack doesn't support remote tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -58,7 +56,8 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
};
|
||||
|
||||
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||
TFDLManagedTensorCtx* owner = static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
|
||||
TFDLManagedTensorCtx* owner =
|
||||
static_cast<TFDLManagedTensorCtx*>(arg->manager_ctx);
|
||||
owner->handle->Unref();
|
||||
delete owner;
|
||||
}
|
||||
@ -68,87 +67,31 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
||||
dtype.lanes = 1;
|
||||
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||
switch (data_type) {
|
||||
case TF_DataType::TF_HALF:
|
||||
case TF_DataType::TF_FLOAT:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_DOUBLE:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_INT32:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_UINT8:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_INT8:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_INT16:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_STRING:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_COMPLEX64:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_COMPLEX64 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_INT32:
|
||||
case TF_DataType::TF_INT64:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_BOOL:
|
||||
case TF_DataType::TF_UINT8:
|
||||
case TF_DataType::TF_UINT16:
|
||||
case TF_DataType::TF_UINT32:
|
||||
case TF_DataType::TF_UINT64:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_QINT8:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QINT8 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_QUINT8:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QUINT8 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_QINT32:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QINT32 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_BFLOAT16:
|
||||
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||
break;
|
||||
case TF_DataType::TF_QINT16:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QINT16 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_QUINT16:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_QUINT16 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_UINT16:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_COMPLEX128:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_COMPLEX128 is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_HALF:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_RESOURCE:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_RESOURCE is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_VARIANT:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TF_VARIANT is not supported by dlpack");
|
||||
break;
|
||||
case TF_DataType::TF_UINT32:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_UINT64:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported TF_DataType is not supported by dlpack");
|
||||
DataType_Name(static_cast<DataType>(data_type)),
|
||||
" is not supported by dlpack");
|
||||
break;
|
||||
}
|
||||
return dtype;
|
||||
@ -158,13 +101,12 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name),
|
||||
&parsed_name);
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
int device_id = -1;
|
||||
if (parsed_name.has_id) {
|
||||
device_id = parsed_name.id;
|
||||
} // Question? device_id?=-1
|
||||
} // Question: Is it possible that it doens't have id?
|
||||
|
||||
ctx.device_id = device_id;
|
||||
if (device_type == "CPU") {
|
||||
@ -173,7 +115,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
ctx.device_type = DLDeviceType::kDLGPU;
|
||||
} else {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Device Type for DLPack");
|
||||
"Unsupported Device Type for dlpack");
|
||||
}
|
||||
|
||||
return ctx;
|
||||
@ -183,43 +125,45 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||
TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx);
|
||||
TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx);
|
||||
TensorReference* tensor_ref =
|
||||
new TensorReference(*tensor); // This will call buf_->Ref()
|
||||
tfDLMTensor->handle = tensor_ref;
|
||||
tfDLMTensor->tensor.manager_ctx = tfDLMTensor;
|
||||
tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter;
|
||||
tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status);
|
||||
tf_dlm_tensor_ctx->handle = tensor_ref;
|
||||
tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
|
||||
tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
|
||||
int ndim = tensor->dims();
|
||||
tfDLMTensor->tensor.dl_tensor.ndim = ndim;
|
||||
tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||
tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.data =
|
||||
TFE_TensorHandleDevicePointer(h, status);
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
||||
|
||||
int64_t* shape_arr = new int64_t[ndim];
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
shape_arr[i] = tensor->dim_size(i);
|
||||
}
|
||||
|
||||
tfDLMTensor->tensor.dl_tensor.shape = shape_arr;
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr;
|
||||
|
||||
tfDLMTensor->tensor.dl_tensor.strides =
|
||||
nullptr; // Whether this is null at all the time?
|
||||
tfDLMTensor->tensor.dl_tensor.byte_offset =
|
||||
0; // Whether this is 0 at all the time?
|
||||
return &tfDLMTensor->tensor;
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr;
|
||||
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
||||
0; // TF doesn't handle the strides and byte_offsets here
|
||||
return &tf_dlm_tensor_ctx->tensor;
|
||||
}
|
||||
|
||||
std::string FromDLContext(const DLContext& ctx, TF_Status* status) {
|
||||
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
||||
TF_Status* status) {
|
||||
switch (ctx.device_type) {
|
||||
case DLDeviceType::kDLCPU:
|
||||
return "CPU:0";
|
||||
case DLDeviceType::kDLGPU:
|
||||
return absl::StrCat("GPU:", ctx.device_id);
|
||||
default:
|
||||
return "";
|
||||
return absl::nullopt;
|
||||
};
|
||||
}
|
||||
TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
|
||||
TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
|
||||
TF_Status* status) {
|
||||
TF_DataType tf_dtype;
|
||||
switch (dtype.code) {
|
||||
case DLDataTypeCode::kDLUInt:
|
||||
@ -241,7 +185,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported UInt bits", dtype.bits);
|
||||
"Unsupported UInt bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLInt:
|
||||
@ -260,7 +204,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Int bits", dtype.bits);
|
||||
"Unsupported Int bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat:
|
||||
@ -276,7 +220,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Float bits", dtype.bits);
|
||||
"Unsupported Float bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLBfloat:
|
||||
@ -286,18 +230,18 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) {
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported BFloat bits", dtype.bits);
|
||||
"Unsupported BFloat bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Type Codes", dtype.code);
|
||||
"Unsupported Type Codes: ", dtype.code);
|
||||
}
|
||||
|
||||
return tf_dtype;
|
||||
}
|
||||
|
||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr){
|
||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||
}
|
||||
@ -321,8 +265,14 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
|
||||
std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status);
|
||||
TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status);
|
||||
absl::optional<std::string> device_name =
|
||||
DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
|
||||
if (!device_name.has_value()) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
||||
return nullptr;
|
||||
}
|
||||
TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
|
||||
int num_dims = dlmt->dl_tensor.ndim;
|
||||
const int64_t* dims = dlmt->dl_tensor.shape;
|
||||
void* data = dlmt->dl_tensor.data;
|
||||
@ -332,8 +282,8 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
total_bytes *= dims[i];
|
||||
}
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes,
|
||||
&DeallocatorWrapperFunc, &dlmt, status);
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
|
||||
return handle;
|
||||
};
|
||||
|
@ -16,7 +16,6 @@ int_dtypes = [
|
||||
float_dtypes = [np.float16, np.float32, np.float64]
|
||||
complex_dtypes = [np.complex64, np.complex128]
|
||||
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
|
||||
standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
|
||||
|
||||
|
||||
testcase_shapes = [
|
||||
@ -62,6 +61,22 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||
".*a DLPack tensor may be consumed at most once.*",
|
||||
ConsumeDLPackTensor)
|
||||
|
||||
def testUnsupportedType(self):
|
||||
def case1():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.qint16)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
|
||||
def case2():
|
||||
tf_tensor = constant_op.constant(
|
||||
[[1, 4], [5, 2]], dtype=dtypes.complex64)
|
||||
dlcapsule = to_dlpack(tf_tensor)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case1)
|
||||
self.assertRaisesRegex(
|
||||
Exception, ".* is not supported by dlpack", case2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -1041,21 +1041,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
py::capsule capsule(
|
||||
dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
|
||||
if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
|
||||
void* dlm_rptr =
|
||||
PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
|
||||
if (dlm_rptr) {
|
||||
tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
|
||||
PyCapsule_SetDestructor(capsule, nullptr);
|
||||
} else {
|
||||
// The tensor has been deleted. Clear any error from
|
||||
// PyCapsule_GetPointer.
|
||||
PyErr_Clear();
|
||||
}
|
||||
}
|
||||
});
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return capsule;
|
||||
});
|
||||
|
||||
@ -1073,9 +1071,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
TFE_TensorHandle* thandle =
|
||||
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
||||
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
return py::handle(EagerTensorFromHandle(thandle));
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user