Merge pull request #36862 from VoVAllen:dlpack
PiperOrigin-RevId: 297728301
Change-Id: I22a74c21f3459189f3e36a94ad521cdedb9b761b

commit 9cd1a63a74
@@ -95,6 +95,7 @@ filegroup(
     srcs = [
         "c_api_experimental.h",
         "c_api_internal.h",
+        "dlpack.h",
         "operation_interface.h",
         "tensor_handle_interface.h",
     ],
@@ -328,10 +329,33 @@ filegroup(
     srcs = [
         "c_api.h",
         "c_api_experimental.h",
+        "dlpack.h",
     ],
     visibility = ["//tensorflow:__subpackages__"],
 )
 
+cc_library(
+    name = "dlpack",
+    srcs = ["dlpack.cc"],
+    hdrs = ["dlpack.h"],
+    copts = [
+        "-fexceptions",
+        "-fno-strict-aliasing",
+    ],
+    features = ["-use_header_modules"],
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":c_api",
+        ":c_api_experimental",
+        ":c_api_internal",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "@dlpack",
+    ],
+)
+
 # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
 # right now, remove this public rule when no longer needed (it should be
 # replaced by TF Lite)
@@ -345,6 +369,7 @@ filegroup(
         exclude = [
             "c_api_experimental.cc",
             "*test*",
+            "*dlpack*",
         ],
     ),
     visibility = ["//visibility:public"],

tensorflow/c/eager/dlpack.cc (new file, 334 lines)
@@ -0,0 +1,334 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/dlpack.h"

#include "include/dlpack/dlpack.h"  // TF:dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {

// Manages the lifetime of a DLManagedTensor: when DLManagedTensor::deleter is
// called, the original framework is notified of the destruction and this
// context is deleted as well.
struct TfDlManagedTensorCtx {
  TensorReference reference;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  DLManagedTensor tensor;

  explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};

// Gets the underlying tensor from an eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || !h->handle->IsValid(&status->status)) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack doesn't support remote tensor");
    return nullptr;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return tensor;
}

// Deleter for DLManagedTensor.
void DLManagedTensorDeleter(DLManagedTensor* arg) {
  TfDlManagedTensorCtx* owner =
      static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
  owner->reference.Unref();
  delete owner;
}

// Converts a TF_DataType to the corresponding DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = TF_DataTypeSize(data_type) * 8;
  switch (data_type) {
    case TF_DataType::TF_HALF:
    case TF_DataType::TF_FLOAT:
    case TF_DataType::TF_DOUBLE:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case TF_DataType::TF_INT8:
    case TF_DataType::TF_INT16:
    case TF_DataType::TF_INT32:
    case TF_DataType::TF_INT64:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case TF_DataType::TF_BOOL:
    case TF_DataType::TF_UINT8:
    case TF_DataType::TF_UINT16:
    case TF_DataType::TF_UINT32:
    case TF_DataType::TF_UINT64:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case TF_DataType::TF_BFLOAT16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    default:
      status->status = tensorflow::errors::InvalidArgument(
          DataType_Name(static_cast<DataType>(data_type)),
          " is not supported by dlpack");
      break;
  }
  return dtype;
}

// Gets DLPack's DLContext from an eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLContext ctx;
  const char* device_name = h->handle->DeviceName(&status->status);
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
  int device_id = 0;
  if (parsed_name.has_id) {
    device_id = parsed_name.id;
  }

  ctx.device_id = device_id;
  if (device_type == "CPU") {
    ctx.device_type = DLDeviceType::kDLCPU;
  } else if (device_type == "GPU") {
    ctx.device_type = DLDeviceType::kDLGPU;
  } else {
    status->status = tensorflow::errors::InvalidArgument(
        "Unsupported Device Type for dlpack");
  }

  return ctx;
}

// Converts a DLContext to a TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
                                                    TF_Status* status) {
  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return "CPU:0";
    case DLDeviceType::kDLGPU:
      return absl::StrCat("GPU:", ctx.device_id);
    default:
      return absl::nullopt;
  }
}

// Converts a DLPack data type to a TF_DataType.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
                                TF_DataType* tf_dtype) {
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_UINT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_UINT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_UINT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_UINT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
                                                     dtype.bits);
      }
      return Status::OK();
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_INT8;
          return Status::OK();
        case 16:
          *tf_dtype = TF_DataType::TF_INT16;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_INT32;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_INT64;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
                                                     dtype.bits);
      }
      return Status::OK();
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_HALF;
          return Status::OK();
        case 32:
          *tf_dtype = TF_DataType::TF_FLOAT;
          return Status::OK();
        case 64:
          *tf_dtype = TF_DataType::TF_DOUBLE;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
                                                     dtype.bits);
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_BFLOAT16;
          return Status::OK();
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported BFloat bits: ", dtype.bits);
      }
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
                                                 dtype.code);
  }
}

// Wraps the deleter function of a DLManagedTensor to match the function
// signature of TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
  dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
}

// Checks whether the stride array matches the layout of compact, row-major
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                      int ndim) {
  if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
    return false;
  }
  for (int i = ndim - 2; i >= 0; --i) {
    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
      return false;
    }
  }
  return true;
}

}  // namespace

void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
  DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
  if (dlMTensor->deleter != nullptr) {
    dlMTensor->deleter(dlMTensor);
  }
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  const Tensor* tensor = GetTensorFromHandle(h, status);
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()

  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  tf_dlm_tensor_ctx->reference = tensor_ref;

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
  dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
  shape_arr->resize(ndim);
  stride_arr->resize(ndim, 1);
  for (int i = 0; i < ndim; i++) {
    (*shape_arr)[i] = tensor->dim_size(i);
  }
  for (int i = ndim - 2; i >= 0; --i) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

  dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
  // There are two ways to represent compact row-major data:
  // 1) nullptr indicates the tensor is compact and row-major.
  // 2) fill in the strides array as the real case for compact row-major data.
  // Here we choose option 2, since some frameworks didn't handle the strides
  // argument properly.
  dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
  dlm_tensor->dl_tensor.byte_offset =
      0;  // TF doesn't handle the strides and byte_offsets here
  return static_cast<void*>(dlm_tensor);
}

TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
  DLTensor* dl_tensor = &dlmt->dl_tensor;
  absl::optional<std::string> device_name =
      DeviceNameFromDlContext(dl_tensor->ctx, status);
  if (!device_name.has_value()) {
    status->status =
        tensorflow::errors::InvalidArgument("Unsupported Device Type");
    return nullptr;
  }
  TF_DataType dtype;
  Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
  if (!s.ok()) {
    status->status = std::move(s);
    return nullptr;
  }
  int num_dims = dl_tensor->ndim;
  const int64_t* dims = dl_tensor->shape;
  void* data = dl_tensor->data;

  size_t total_bytes = dl_tensor->dtype.bits / 8;
  for (int i = 0; i < num_dims; i++) {
    total_bytes *= dims[i];
  }

  if (dl_tensor->strides != nullptr &&
      !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
                                        num_dims)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid strides array from DLPack");
    return nullptr;
  }

  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, &dlmt, status);

  return handle;
}

}  // namespace tensorflow
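Note for readers: the stride handling above is the subtle part of TFE_HandleToDLPack. The sketch below is an illustrative Python rendering, not part of this change (the helper names are made up), of the two pieces of logic: filling compact row-major strides, and the validation performed by IsValidStrideCompactRowMajorData.

```python
# Illustrative sketch only; mirrors the C++ helpers above, not a TensorFlow API.

def row_major_strides(shape):
    """Compact row-major strides (in elements): stride[i] = shape[i+1] * stride[i+1]."""
    ndim = len(shape)
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = shape[i + 1] * strides[i + 1]
    return strides

def is_valid_stride_compact_row_major(shape, strides):
    """Same check as IsValidStrideCompactRowMajorData."""
    ndim = len(shape)
    if ndim >= 1 and strides[ndim - 1] != 1:
        return False
    for i in range(ndim - 2, -1, -1):
        if strides[i] != shape[i + 1] * strides[i + 1]:
            return False
    return True

assert row_major_strides([2, 3, 4]) == [12, 4, 1]
assert is_valid_stride_compact_row_major([2, 3, 4], [12, 4, 1])
assert not is_valid_stride_compact_row_major([2, 3, 4], [1, 2, 6])  # column-major
```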

tensorflow/c/eager/dlpack.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_

#include "tensorflow/c/eager/c_api.h"

namespace tensorflow {

// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor";

// Converts an eager tensor handle to DLPack (DLManagedTensor*), and returns
// the void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
                                               TF_Status* status);

// Converts DLPack (DLManagedTensor*) to an eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
                                                             TF_Status* status);

// Calls the destructor of DLManagedTensor; used in the destructor of a
// PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_DLPACK_H_
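These three exported symbols are the entire C surface of the feature; everything below is wiring. As a rough sketch of how they are reached from Python once the pybind11 bindings later in this diff are in place (assumes an eager-mode build that contains this change):

```python
# Sketch only: mirrors what tensorflow/python/dlpack/dlpack.py (added below) does.
from tensorflow.python import pywrap_tfe
from tensorflow.python.framework import constant_op

t = constant_op.constant([1.0, 2.0, 3.0])
capsule = pywrap_tfe.TFE_ToDlpackCapsule(t)      # wraps TFE_HandleToDLPack
t2 = pywrap_tfe.TFE_FromDlpackCapsule(capsule)   # wraps TFE_HandleFromDLPack
# If a capsule is never consumed, its PyCapsule destructor calls
# TFE_CallDLManagedTensorDeleter to release the DLManagedTensor.
```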

@@ -190,6 +190,7 @@ py_library(
         "//tensorflow/python/distribute:estimator_training",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/dlpack",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/eager:profiler",
@@ -8069,7 +8070,7 @@ tf_python_pybind_extension(
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:lib_headers_for_pybind",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/platform:platform",
+        "//tensorflow/core/platform",
     ] + if_static(
         extra_deps = [
             "//tensorflow/core:eager_service_proto_cc",
@@ -159,6 +159,10 @@ from tensorflow.python.debug.lib import check_numerics_callback
 from tensorflow.python.debug.lib import dumping_callback
 from tensorflow.python.ops import gen_debug_ops
 
+# DLPack
+from tensorflow.python.dlpack.dlpack import from_dlpack
+from tensorflow.python.dlpack.dlpack import to_dlpack
+
 # XLA JIT compiler APIs.
 from tensorflow.python.compiler.xla import jit
 from tensorflow.python.compiler.xla import xla

tensorflow/python/dlpack/BUILD (new file, 28 lines)
@@ -0,0 +1,28 @@
load("//tensorflow:tensorflow.bzl", "cuda_py_test")

package(
    default_visibility = ["//visibility:private"],
    licenses = ["notice"],  # Apache 2.0
)

py_library(
    name = "dlpack",
    srcs = ["dlpack.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        "//tensorflow/python:pywrap_tensorflow",
    ],
)

cuda_py_test(
    name = "dlpack_test",
    srcs = ["dlpack_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":dlpack",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:absltest",
        "@absl_py//absl/testing:parameterized",
    ],
)

tensorflow/python/dlpack/dlpack.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DLPack modules for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tfe
from tensorflow.python.util.tf_export import tf_export


@tf_export("experimental.dlpack.to_dlpack", v1=[])
def to_dlpack(tf_tensor):
  """Returns the DLPack capsule representing the tensor.

  This operation ensures the underlying data memory is ready when it returns.

  ```python
  a = tf.constant([1, 10])
  dlcapsule = tf.experimental.dlpack.to_dlpack(a)
  # dlcapsule represents the dlpack data structure
  ```

  Args:
    tf_tensor: TensorFlow eager tensor, to be converted to a DLPack capsule.

  Returns:
    A PyCapsule named "dltensor", which shares the underlying memory with
    other frameworks. This PyCapsule can be consumed only once.
  """
  return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)


@tf_export("experimental.dlpack.from_dlpack", v1=[])
def from_dlpack(dlcapsule):
  """Returns the TensorFlow eager tensor.

  The returned tensor uses the memory shared by DLPack capsules from other
  frameworks.

  ```python
  a = tf.experimental.dlpack.from_dlpack(dlcapsule)
  # `a` uses the memory shared by dlpack
  ```

  Args:
    dlcapsule: A PyCapsule named "dltensor".

  Returns:
    A TensorFlow eager tensor.
  """
  return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
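A minimal end-to-end sketch of the public API added here (assumes eager execution, as enabled in the test file that follows, and a `tf` build that includes this change):

```python
import numpy as np
import tensorflow as tf  # assumes a build containing this PR

x = tf.constant(np.arange(12).reshape(3, 4), dtype=tf.float32)
capsule = tf.experimental.dlpack.to_dlpack(x)    # PyCapsule named "dltensor"
y = tf.experimental.dlpack.from_dlpack(capsule)  # shares x's memory
print(y.shape, y.dtype)  # (3, 4) tf.float32

# The capsule is consumed by from_dlpack; passing it a second time raises,
# because it has been renamed to "used_dltensor" (see the pybind code below).
```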

tensorflow/python/dlpack/dlpack_test.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test

int_dtypes = [
    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
    np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
complex_dtypes = [np.complex64, np.complex128]
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]

testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]


def FormatShapeAndDtype(shape, dtype):
  return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))


def GetNamedTestParameters():
  result = []
  for dtype in dlpack_dtypes:
    for shape in testcase_shapes:
      result.append({
          "testcase_name": FormatShapeAndDtype(shape, dtype),
          "dtype": dtype,
          "shape": shape
      })
  return result


class DLPackTest(parameterized.TestCase, test.TestCase):

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    np.random.seed(42)
    np_array = np.random.randint(0, 10, shape)
    tf_tensor = constant_op.constant(np_array, dtype=dtype)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    del tf_tensor  # should still work
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)

  def testTensorsCanBeConsumedOnceOnly(self):
    np.random.seed(42)
    np_array = np.random.randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    del tf_tensor  # should still work
    _ = dlpack.from_dlpack(dlcapsule)

    def ConsumeDLPackTensor():
      dlpack.from_dlpack(dlcapsule)  # the capsule can be consumed only once

    self.assertRaisesRegex(Exception,
                           ".*a DLPack tensor may be consumed at most once.*",
                           ConsumeDLPackTensor)

  def testUnsupportedTypeToDLPack(self):

    def UnsupportedQint16():
      tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
      _ = dlpack.to_dlpack(tf_tensor)

    def UnsupportedComplex64():
      tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.complex64)
      _ = dlpack.to_dlpack(tf_tensor)

    self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
                           UnsupportedQint16)
    self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
                           UnsupportedComplex64)


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()

@@ -35,6 +35,7 @@ cc_library(
         "//tensorflow/c/eager:c_api",
         "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/c/eager:c_api_internal",
+        "//tensorflow/c/eager:dlpack",
         "//tensorflow/c/eager:tape",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -93,6 +94,7 @@ py_library(
         ":test",
         ":wrap_function",
         "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python/dlpack",
        "//tensorflow/python/eager/memory_tests:memory_test_util",
     ],
 )

@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/c/eager/dlpack.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/compiler/jit/flags.h"
@@ -1047,6 +1048,50 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
   m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
         py::return_value_policy::reference);
 
+  // DLPack functions
+  m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
+    PyObject* eager_tensor_pyobject_ptr = o.ptr();
+    TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+
+    py::capsule capsule(
+        dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
+          if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
+            void* dlm_rptr =
+                PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
+            if (dlm_rptr) {
+              tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
+              PyCapsule_SetDestructor(capsule, nullptr);
+            }
+          }
+        });
+    return capsule;
+  });
+
+  m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
+    tensorflow::Safe_TF_StatusPtr status =
+        tensorflow::make_safe(TF_NewStatus());
+    if (absl::string_view(pycapsule.name()) !=
+        tensorflow::kDlTensorCapsuleName) {
+      status->status = tensorflow::errors::InvalidArgument(
+          "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
+          "Note that a DLPack tensor may be consumed at most once.",
+          absl::string_view(pycapsule.name()));
+      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+    }
+    TFE_TensorHandle* thandle =
+        tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
+
+    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
+
+    PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
+    PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
+    return py::handle(EagerTensorFromHandle(thandle));
+  });
+
   // C API Enum
 
   py::enum_<TFE_ContextDevicePlacementPolicy>(
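The binding above renames a consumed capsule to "used_dltensor" and clears its destructor; that is what produces the "consumed at most once" error exercised in dlpack_test.py. An illustrative sketch of that behavior at the public API level (again assuming an eager-mode build with this change):

```python
import tensorflow as tf  # assumes a build containing this PR

t = tf.constant([1.0, 2.0, 3.0])
capsule = tf.experimental.dlpack.to_dlpack(t)

tf.experimental.dlpack.from_dlpack(capsule)      # first use succeeds
try:
    tf.experimental.dlpack.from_dlpack(capsule)  # capsule is now "used_dltensor"
except Exception as e:
    print(e)  # "... a DLPack tensor may be consumed at most once ..."
```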

@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
     "errors/__init__.py",
     "experimental/__init__.py",
     "experimental/tensorrt/__init__.py",
+    "experimental/dlpack/__init__.py",
     "feature_column/__init__.py",
     "io/gfile/__init__.py",
     "graph_util/__init__.py",

@@ -0,0 +1,11 @@
path: "tensorflow.experimental.dlpack"
tf_module {
  member_method {
    name: "from_dlpack"
    argspec: "args=[\'dlcapsule\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "to_dlpack"
    argspec: "args=[\'tf_tensor\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -1,5 +1,9 @@
 path: "tensorflow.experimental"
 tf_module {
+  member {
+    name: "dlpack"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "tensorrt"
     mtype: "<type \'module\'>"

@@ -159,6 +159,7 @@ filegroup(
         "@com_google_protobuf//:LICENSE",
         "@com_googlesource_code_re2//:LICENSE",
        "@curl//:COPYING",
+        "@dlpack//:LICENSE",
         "@double_conversion//:LICENSE",
         "@eigen_archive//:COPYING.MPL2",
         "@enum34_archive//:LICENSE",