fix
This commit is contained in:
parent
90a55447a7
commit
7aa5009b35
@ -27,12 +27,17 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Managing context for the DLManagedTensor, will manage the lifetime of
|
||||||
|
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
|
||||||
|
// original framework of destruction, and this context will be deleted also.
|
||||||
struct TfDlManagedTensorCtx {
|
struct TfDlManagedTensorCtx {
|
||||||
TensorReference* reference;
|
TensorReference reference;
|
||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
|
std::vector<int64_t> strides;
|
||||||
DLManagedTensor tensor;
|
DLManagedTensor tensor;
|
||||||
|
|
||||||
TfDlManagedTensorCtx()
|
TfDlManagedTensorCtx(const TensorReference& ref)
|
||||||
|
: reference(ref), shape(), tensor() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
@ -61,8 +66,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
|||||||
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||||
TfDlManagedTensorCtx* owner =
|
TfDlManagedTensorCtx* owner =
|
||||||
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
|
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
|
||||||
owner->reference->Unref();
|
owner->reference.Unref();
|
||||||
delete owner->reference;
|
|
||||||
delete owner;
|
delete owner;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,31 +133,41 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h,
|
|||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||||
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx;
|
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||||
|
|
||||||
TensorReference* tensor_ref =
|
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||||
new TensorReference(*tensor); // This will call buf_->Ref()
|
|
||||||
tf_dlm_tensor_ctx->reference = tensor_ref;
|
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||||
tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx;
|
|
||||||
tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter;
|
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status);
|
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||||
|
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||||
|
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
|
||||||
int ndim = tensor->dims();
|
int ndim = tensor->dims();
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim;
|
dlm_tensor->dl_tensor.ndim = ndim;
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.data =
|
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||||
TFE_TensorHandleDevicePointer(h, status);
|
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status);
|
|
||||||
|
|
||||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||||
|
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||||
shape_arr->resize(ndim);
|
shape_arr->resize(ndim);
|
||||||
|
stride_arr->resize(ndim, 1);
|
||||||
for (int i = 0; i < ndim; i++) {
|
for (int i = 0; i < ndim; i++) {
|
||||||
(*shape_arr)[i] = tensor->dim_size(i);
|
(*shape_arr)[i] = tensor->dim_size(i);
|
||||||
}
|
}
|
||||||
|
for (int i = ndim - 2; i >= 0; --i) {
|
||||||
|
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.shape =
|
dlm_tensor->dl_tensor.shape =
|
||||||
reinterpret_cast<std::int64_t*>(shape_arr->data());
|
reinterpret_cast<std::int64_t*>(shape_arr->data());
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.strides =
|
// There are two ways to represent compact row-major data
|
||||||
nullptr; // nullptr indicates tensor is compact and row-majored.
|
// 1) nullptr indicates the tensor is compact and row-major.
|
||||||
tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset =
|
// 2) fill in the strides array as the real case for compact row-major data
|
||||||
|
// Here we choose option 2, since some frameworks do not handle the strides
|
||||||
|
// argument properly
|
||||||
|
dlm_tensor->dl_tensor.strides =
|
||||||
|
reinterpret_cast<std::int64_t*>(stride_arr->data());
|
||||||
|
dlm_tensor->dl_tensor.byte_offset =
|
||||||
0; // TF doesn't handle the strides and byte_offsets here
|
0; // TF doesn't handle the strides and byte_offsets here
|
||||||
return &tf_dlm_tensor_ctx->tensor;
|
return &tf_dlm_tensor_ctx->tensor;
|
||||||
}
|
}
|
||||||
@ -250,6 +264,15 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
|||||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true when `stride_arr` describes a compact (densely packed),
// row-major layout for a tensor whose dimensions are given by `shape_arr`.
//
// Strides are expressed in elements, not bytes. A layout is compact row-major
// when the innermost stride is 1 and every outer stride equals the product of
// the extent and stride of the dimension immediately inside it.
//
// shape_arr:  array of `ndim` dimension extents.
// stride_arr: array of `ndim` strides.
// ndim:       number of dimensions; a value <= 0 is treated as trivially
//             valid and neither array is read.
bool IsValidStrideCompactRowMajorData(const int64_t* shape_arr,
                                      const int64_t* stride_arr, int ndim) {
  if (ndim <= 0) {
    return true;
  }
  // The innermost dimension must be contiguous. Without this check a layout
  // such as strides {24, 8, 2} for shape {2, 3, 4} would satisfy the pairwise
  // relation below and be wrongly accepted.
  if (stride_arr[ndim - 1] != 1) {
    return false;
  }
  for (int i = ndim - 2; i >= 0; --i) {
    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
      return false;
    }
  }
  return true;
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||||
@ -268,23 +291,32 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
|||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||||
|
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||||
absl::optional<std::string> device_name =
|
absl::optional<std::string> device_name =
|
||||||
DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status);
|
DeviceNameFromDlContext(dl_tensor->ctx, status);
|
||||||
if (!device_name.has_value()) {
|
if (!device_name.has_value()) {
|
||||||
status->status =
|
status->status =
|
||||||
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status);
|
TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status);
|
||||||
int num_dims = dlmt->dl_tensor.ndim;
|
int num_dims = dl_tensor->ndim;
|
||||||
const int64_t* dims = dlmt->dl_tensor.shape;
|
const int64_t* dims = dl_tensor->shape;
|
||||||
void* data = dlmt->dl_tensor.data;
|
void* data = dl_tensor->data;
|
||||||
|
|
||||||
size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8;
|
size_t total_bytes = dl_tensor->dtype.bits / 8;
|
||||||
for (int i = 0; i < num_dims; i++) {
|
for (int i = 0; i < num_dims; i++) {
|
||||||
total_bytes *= dims[i];
|
total_bytes *= dims[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((dl_tensor->strides != nullptr) &&
|
||||||
|
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||||
|
num_dims)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Invalid strides array from DLPack");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||||
|
Loading…
Reference in New Issue
Block a user