diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index e0624ac4ca1..fa6fb77a2d7 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -32,7 +32,6 @@ struct TFDLManagedTensorCtx { DLManagedTensor tensor; }; - const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( @@ -45,8 +44,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (handle->IsRemote()) { status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDevicePointer may not be called on a remote tensor " - "handle."); + "DLPack doesn't support remote tensor"); return nullptr; } const tensorflow::Tensor* tensor; @@ -58,7 +56,8 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { }; void DLManagedTensorDeleter(DLManagedTensor* arg) { - TFDLManagedTensorCtx* owner = static_cast(arg->manager_ctx); + TFDLManagedTensorCtx* owner = + static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner; } @@ -68,103 +67,46 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; switch (data_type) { + case TF_DataType::TF_HALF: case TF_DataType::TF_FLOAT: - dtype.code = DLDataTypeCode::kDLFloat; - break; case TF_DataType::TF_DOUBLE: dtype.code = DLDataTypeCode::kDLFloat; break; - case TF_DataType::TF_INT32: - dtype.code = DLDataTypeCode::kDLInt; - break; - case TF_DataType::TF_UINT8: - dtype.code = DLDataTypeCode::kDLUInt; - break; case TF_DataType::TF_INT8: - dtype.code = DLDataTypeCode::kDLInt; - break; case TF_DataType::TF_INT16: - dtype.code = DLDataTypeCode::kDLInt; - break; - case TF_DataType::TF_STRING: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case TF_DataType::TF_COMPLEX64: - status->status = tensorflow::errors::InvalidArgument( - "TF_COMPLEX64 is not supported by dlpack"); - break; + case TF_DataType::TF_INT32: case TF_DataType::TF_INT64: dtype.code = DLDataTypeCode::kDLInt; break; case TF_DataType::TF_BOOL: + case TF_DataType::TF_UINT8: + case TF_DataType::TF_UINT16: + case TF_DataType::TF_UINT32: + case TF_DataType::TF_UINT64: dtype.code = DLDataTypeCode::kDLUInt; break; - case TF_DataType::TF_QINT8: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT8 is not supported by dlpack"); - break; - case TF_DataType::TF_QUINT8: - status->status = tensorflow::errors::InvalidArgument( - "TF_QUINT8 is not supported by dlpack"); - break; - case TF_DataType::TF_QINT32: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT32 is not supported by dlpack"); - break; case TF_DataType::TF_BFLOAT16: dtype.code = DLDataTypeCode::kDLBfloat; break; - case TF_DataType::TF_QINT16: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT16 is not supported by dlpack"); - break; - case TF_DataType::TF_QUINT16: - status->status = tensorflow::errors::InvalidArgument( - "TF_QUINT16 is not supported by dlpack"); - break; - case TF_DataType::TF_UINT16: - dtype.code = DLDataTypeCode::kDLUInt; - break; - case TF_DataType::TF_COMPLEX128: - status->status = tensorflow::errors::InvalidArgument( - "TF_COMPLEX128 is not supported by dlpack"); - break; - case TF_DataType::TF_HALF: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case TF_DataType::TF_RESOURCE: - status->status = tensorflow::errors::InvalidArgument( - "TF_RESOURCE is not supported by dlpack"); - break; - case TF_DataType::TF_VARIANT: - status->status = tensorflow::errors::InvalidArgument( - "TF_VARIANT is not supported by dlpack"); - break; - case TF_DataType::TF_UINT32: - dtype.code = DLDataTypeCode::kDLUInt; - break; - case TF_DataType::TF_UINT64: - dtype.code = DLDataTypeCode::kDLUInt; - break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported TF_DataType is not supported by dlpack"); + DataType_Name(static_cast(data_type)), + " is not supported by dlpack"); break; } return dtype; } -DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { +DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; - tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name), - &parsed_name); + tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; int device_id = -1; if (parsed_name.has_id) { device_id = parsed_name.id; - } // Question? device_id?=-1 + } // Question: Is it possible that it doens't have id? ctx.device_id = device_id; if (device_type == "CPU") { @@ -173,53 +115,55 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { ctx.device_type = DLDeviceType::kDLGPU; } else { status->status = tensorflow::errors::InvalidArgument( - "Unsupported Device Type for DLPack"); + "Unsupported Device Type for dlpack"); } return ctx; } DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, - TF_Status* status) { + TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx); + TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx); TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() - tfDLMTensor->handle = tensor_ref; - tfDLMTensor->tensor.manager_ctx = tfDLMTensor; - tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter; - tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status); + tf_dlm_tensor_ctx->handle = tensor_ref; + tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx; + tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter; + tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); - tfDLMTensor->tensor.dl_tensor.ndim = ndim; - tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); + tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim; + tf_dlm_tensor_ctx->tensor.dl_tensor.data = + TFE_TensorHandleDevicePointer(h, status); + tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); int64_t* shape_arr = new int64_t[ndim]; for (int i = 0; i < ndim; i++) { shape_arr[i] = tensor->dim_size(i); } - tfDLMTensor->tensor.dl_tensor.shape = shape_arr; + tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr; - tfDLMTensor->tensor.dl_tensor.strides = - nullptr; // Whether this is null at all the time? - tfDLMTensor->tensor.dl_tensor.byte_offset = - 0; // Whether this is 0 at all the time? - return &tfDLMTensor->tensor; + tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr; + tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = + 0; // TF doesn't handle the strides and byte_offsets here + return &tf_dlm_tensor_ctx->tensor; } -std::string FromDLContext(const DLContext& ctx, TF_Status* status) { +absl::optional DeviceNameFromDlContext(const DLContext& ctx, + TF_Status* status) { switch (ctx.device_type) { case DLDeviceType::kDLCPU: return "CPU:0"; case DLDeviceType::kDLGPU: return absl::StrCat("GPU:", ctx.device_id); default: - return ""; + return absl::nullopt; }; } -TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { +TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, + TF_Status* status) { TF_DataType tf_dtype; switch (dtype.code) { case DLDataTypeCode::kDLUInt: @@ -241,7 +185,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported UInt bits", dtype.bits); + "Unsupported UInt bits: ", dtype.bits); } break; case DLDataTypeCode::kDLInt: @@ -260,7 +204,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Int bits", dtype.bits); + "Unsupported Int bits: ", dtype.bits); } break; case DLDataTypeCode::kDLFloat: @@ -276,7 +220,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Float bits", dtype.bits); + "Unsupported Float bits: ", dtype.bits); } break; case DLDataTypeCode::kDLBfloat: @@ -286,20 +230,20 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported BFloat bits", dtype.bits); + "Unsupported BFloat bits: ", dtype.bits); } break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Type Codes", dtype.code); + "Unsupported Type Codes: ", dtype.code); } return tf_dtype; } -void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr){ - DLManagedTensor* dlmt = static_cast(dlmt_vptr); - dlmt->deleter(const_cast(dlmt)); +void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { + DLManagedTensor* dlmt = static_cast(dlmt_vptr); + dlmt->deleter(const_cast(dlmt)); } } // namespace @@ -321,8 +265,14 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_Context* ctx = TFE_NewContext(opts, status); DLManagedTensor* dlmt = static_cast(dlm); - std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status); - TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status); + absl::optional device_name = + DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status); + if (!device_name.has_value()) { + status->status = + tensorflow::errors::InvalidArgument("Unsupported Device Type"); + return nullptr; + } + TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status); int num_dims = dlmt->dl_tensor.ndim; const int64_t* dims = dlmt->dl_tensor.shape; void* data = dlmt->dl_tensor.data; @@ -332,8 +282,8 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { total_bytes *= dims[i]; } TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( - ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes, - &DeallocatorWrapperFunc, &dlmt, status); + ctx, device_name.value().c_str(), dtype, dims, num_dims, data, + total_bytes, &DeallocatorWrapperFunc, &dlmt, status); return handle; }; diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8384dfeadea..8a4f1788446 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -16,7 +16,6 @@ int_dtypes = [ float_dtypes = [np.float16, np.float32, np.float64] complex_dtypes = [np.complex64, np.complex128] dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16] -standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] testcase_shapes = [ @@ -55,13 +54,29 @@ class DLPackTest(parameterized.TestCase, test.TestCase): dlcapsule = to_dlpack(tf_tensor) del tf_tensor # should still work tf_tensor2 = from_dlpack(dlcapsule) - + def ConsumeDLPackTensor(): from_dlpack(dlcapsule) # Should can be consumed only once self.assertRaisesRegex(Exception, ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) + def testUnsupportedType(self): + def case1(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.qint16) + dlcapsule = to_dlpack(tf_tensor) + + def case2(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.complex64) + dlcapsule = to_dlpack(tf_tensor) + + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case1) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case2) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 0b801b4d51e..4e837c765c4 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1041,21 +1041,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); py::capsule capsule( dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) { - void* dlm_rptr = - PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); - if (dlm_rptr) { - tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); - PyCapsule_SetDestructor(capsule, nullptr); - } else { - // The tensor has been deleted. Clear any error from - // PyCapsule_GetPointer. - PyErr_Clear(); + if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) { + void* dlm_rptr = + PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); + if (dlm_rptr) { + tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); + PyCapsule_SetDestructor(capsule, nullptr); + } } }); - tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return capsule; }); @@ -1072,10 +1070,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) { } TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); - tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return py::handle(EagerTensorFromHandle(thandle)); });