fix
This commit is contained in:
parent
7c3ac77ee1
commit
b90808b7b4
@ -36,11 +36,10 @@ struct TfDlManagedTensorCtx {
|
|||||||
std::vector<int64_t> strides;
|
std::vector<int64_t> strides;
|
||||||
DLManagedTensor tensor;
|
DLManagedTensor tensor;
|
||||||
|
|
||||||
TfDlManagedTensorCtx(const TensorReference& ref)
|
TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
|
||||||
: reference(ref), shape(), tensor() {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get tensor from eager tensor handle
|
// Gets tensor from eager tensor handle.
|
||||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
@ -72,7 +71,8 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
|||||||
delete owner;
|
delete owner;
|
||||||
}
|
}
|
||||||
|
|
||||||
DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
// Converts TF_DATAType to DLPack data type.
|
||||||
|
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||||
DLDataType dtype;
|
DLDataType dtype;
|
||||||
dtype.lanes = 1;
|
dtype.lanes = 1;
|
||||||
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||||
@ -107,16 +107,17 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) {
|
|||||||
return dtype;
|
return dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
// Gets DLPack's DLContext from eager tensor handle.
|
||||||
|
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
DLContext ctx;
|
DLContext ctx;
|
||||||
const char* device_name = h->handle->DeviceName(&status->status);
|
const char* device_name = h->handle->DeviceName(&status->status);
|
||||||
DeviceNameUtils::ParsedName parsed_name;
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||||
std::string device_type = parsed_name.type;
|
std::string device_type = parsed_name.type;
|
||||||
int device_id = -1;
|
int device_id = 0;
|
||||||
if (parsed_name.has_id) {
|
if (parsed_name.has_id) {
|
||||||
device_id = parsed_name.id;
|
device_id = parsed_name.id;
|
||||||
} // Question: Is it possible that it doens't have id?
|
}
|
||||||
|
|
||||||
ctx.device_id = device_id;
|
ctx.device_id = device_id;
|
||||||
if (device_type == "CPU") {
|
if (device_type == "CPU") {
|
||||||
@ -131,7 +132,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert DLContext to TF device name
|
// Converts DLContext to TF device name.
|
||||||
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
switch (ctx.device_type) {
|
switch (ctx.device_type) {
|
||||||
@ -144,7 +145,7 @@ absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert DLPack data type to TF_DATATYPE
|
// Converts DLPack data type to TF_DATATYPE.
|
||||||
TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
|
TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
TF_DataType tf_dtype;
|
TF_DataType tf_dtype;
|
||||||
@ -221,16 +222,15 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
|
|||||||
return tf_dtype;
|
return tf_dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrapper function to match the function signature
|
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||||
// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the
|
// TFE_NewTensorHandleFromDeviceMemory.
|
||||||
// DLManagedTensor
|
|
||||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether the stride array matches the layout of compact, row-majored
|
// Checks whether the stride array matches the layout of compact, row-majored
|
||||||
// data
|
// data.
|
||||||
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||||
@ -263,11 +263,11 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||||
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||||
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||||
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
|
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
||||||
int ndim = tensor->dims();
|
int ndim = tensor->dims();
|
||||||
dlm_tensor->dl_tensor.ndim = ndim;
|
dlm_tensor->dl_tensor.ndim = ndim;
|
||||||
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||||
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
|
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
||||||
|
|
||||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
@ -283,9 +283,9 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||||
// There are two ways to represent compact row-major data
|
// There are two ways to represent compact row-major data
|
||||||
// 1) nullptr indicates tensor is compact and row-majored.
|
// 1) nullptr indicates tensor is compact and row-majored.
|
||||||
// 2) fill in the strides array as the real case for compact row-major data
|
// 2) fill in the strides array as the real case for compact row-major data.
|
||||||
// Here we choose option 2, since some framework didn't handle the strides
|
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||||
// argument properly
|
// argument properly.
|
||||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||||
dlm_tensor->dl_tensor.byte_offset =
|
dlm_tensor->dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
|
@ -24,16 +24,16 @@ namespace tensorflow {
|
|||||||
// PyCapsule name for DLPack Tensor
|
// PyCapsule name for DLPack Tensor
|
||||||
const char* const kDlTensorCapsuleName = "dltensor";
|
const char* const kDlTensorCapsuleName = "dltensor";
|
||||||
|
|
||||||
// Convert eager tensor handle to DLPack (DLManagedTensor*), and return the
|
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
|
||||||
// void* for further PyCapsule construction
|
// void* for further PyCapsule construction.
|
||||||
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// Convert DLPack (DLManagedTensor*) to eager tensor handle
|
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule
|
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user