This commit is contained in:
VoVAllen 2020-02-24 11:21:23 +00:00
parent 61da5aaff3
commit 7c3ac77ee1
2 changed files with 56 additions and 47 deletions

View File

@ -40,6 +40,7 @@ struct TfDlManagedTensorCtx {
: reference(ref), shape(), tensor() {} : reference(ref), shape(), tensor() {}
}; };
// Get tensor from eager tensor handle
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) { if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
@ -63,6 +64,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
return tensor; return tensor;
}; };
// Deleter for DLManagedTensor
void DLManagedTensorDeleter(DLManagedTensor* arg) { void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner = TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx); static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
@ -129,47 +131,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) {
return ctx; return ctx;
} }
DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, // Convert DLContext to TF device name
TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data
// Here we choose option 2, since some framework didn't handle the strides
// argument properly
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return &tf_dlm_tensor_ctx->tensor;
}
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx, absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) { TF_Status* status) {
switch (ctx.device_type) { switch (ctx.device_type) {
@ -181,6 +143,8 @@ absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
return absl::nullopt; return absl::nullopt;
}; };
} }
// Convert DLPack data type to TF_DATATYPE
TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_Status* status) { TF_Status* status) {
TF_DataType tf_dtype; TF_DataType tf_dtype;
@ -257,11 +221,16 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype,
return tf_dtype; return tf_dtype;
} }
// Wrapper function to match the function signature
// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the
// DLManagedTensor
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr); DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt)); dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
} }
// Check whether the stride array matches the layout of compact, row-majored
// data
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) { int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) { if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
@ -284,8 +253,43 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
} }
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
DLManagedTensor* tfdlmtensor = TFEHandleToTfDlManagedTensorCtx(h, status); const Tensor* tensor = GetTensorFromHandle(h, status);
return static_cast<void*>(tfdlmtensor); TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDLContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data
// Here we choose option 2, since some framework didn't handle the strides
// argument properly
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
} }
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {

View File

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_DLPACK_H_ #ifndef TENSORFLOW_C_DLPACK_H_
#define TENSORFLOW_C_DLPACK_H_ #define TENSORFLOW_C_DLPACK_H_
@ -22,14 +21,20 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor"; const char* const kDlTensorCapsuleName = "dltensor";
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); // Convert eager tensor handle to DLPack (DLManagedTensor*), and return the
// void* for further PyCapsule construction
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); // Convert DLPack (DLManagedTensor*) to eager tensor handle
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_DLPACK_H_ #endif // TENSORFLOW_C_DLPACK_H_