This commit is contained in:
VoVAllen 2020-02-25 16:45:19 +00:00
parent 7c3ac77ee1
commit b90808b7b4
2 changed files with 23 additions and 23 deletions

View File

@ -36,11 +36,10 @@ struct TfDlManagedTensorCtx {
std::vector<int64_t> strides; std::vector<int64_t> strides;
DLManagedTensor tensor; DLManagedTensor tensor;
TfDlManagedTensorCtx(const TensorReference& ref) TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
: reference(ref), shape(), tensor() {}
}; };
// Get tensor from eager tensor handle // Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) { if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
@ -72,7 +71,8 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) {
delete owner; delete owner;
} }
DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { // Converts TF_DATAType to DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype; DLDataType dtype;
dtype.lanes = 1; dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8; dtype.bits = TF_DataTypeSize(data_type) * 8;
@ -107,16 +107,17 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
return dtype; return dtype;
} }
DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { // Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx; DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status); const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type; std::string device_type = parsed_name.type;
int device_id = -1; int device_id = 0;
if (parsed_name.has_id) { if (parsed_name.has_id) {
device_id = parsed_name.id; device_id = parsed_name.id;
} // Question: Is it possible that it doens't have id? }
ctx.device_id = device_id; ctx.device_id = device_id;
if (device_type == "CPU") { if (device_type == "CPU") {
@ -131,7 +132,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
return ctx; return ctx;
} }
// Convert DLContext to TF device name // Converts DLContext to TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx, absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) { TF_Status* status) {
switch (ctx.device_type) { switch (ctx.device_type) {
@ -144,7 +145,7 @@ absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
}; };
} }
// Convert DLPack data type to TF_DATATYPE // Converts DLPack data type to TF_DATATYPE.
TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_Status* status) { TF_Status* status) {
TF_DataType tf_dtype; TF_DataType tf_dtype;
@ -221,16 +222,15 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
return tf_dtype; return tf_dtype;
} }
// Wrapper function to match the function signature // Wraps the deleter function of DLManagedTensor to match the function signature
// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the // TFE_NewTensorHandleFromDeviceMemory.
// DLManagedTensor
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr); DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt)); dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
} }
// Check whether the stride array matches the layout of compact, row-majored // Checks whether the stride array matches the layout of compact, row-majored
// data // data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) { int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) { if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
@ -263,11 +263,11 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter; dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
int ndim = tensor->dims(); int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim; dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides; std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
@ -283,9 +283,9 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data // There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored. // 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data // 2) fill in the strides array as the real case for compact row-major data.
// Here we choose option 2, since some framework didn't handle the strides // Here we choose option 2, since some frameworks didn't handle the strides
// argument properly // argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset = dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here 0; // TF doesn't handle the strides and byte_offsets here

View File

@ -24,16 +24,16 @@ namespace tensorflow {
// PyCapsule name for DLPack Tensor // PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor"; const char* const kDlTensorCapsuleName = "dltensor";
// Convert eager tensor handle to DLPack (DLManagedTensor*), and return the // Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
// void* for further PyCapsule construction // void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status); TF_Status* status);
// Convert DLPack (DLManagedTensor*) to eager tensor handle // Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status); TF_Status* status);
// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule // Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow } // namespace tensorflow