Merge pull request #36862 from VoVAllen:dlpack
PiperOrigin-RevId: 297728301 Change-Id: I22a74c21f3459189f3e36a94ad521cdedb9b761b
This commit is contained in:
commit
9cd1a63a74
@ -95,6 +95,7 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
@ -328,10 +329,33 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"dlpack.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.cc"],
|
||||
hdrs = ["dlpack.h"],
|
||||
copts = [
|
||||
"-fexceptions",
|
||||
"-fno-strict-aliasing",
|
||||
],
|
||||
features = ["-use_header_modules"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@dlpack",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
|
||||
# right now, remove this public rule when no longer needed (it should be
|
||||
# replaced by TF Lite)
|
||||
@ -345,6 +369,7 @@ filegroup(
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
),
|
||||
visibility = ["//visibility:public"],
|
||||
|
334
tensorflow/c/eager/dlpack.cc
Normal file
334
tensorflow/c/eager/dlpack.cc
Normal file
@ -0,0 +1,334 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // TF:dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Managing context for the DLManagedTensor, will manage the lifetime of
|
||||
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
|
||||
// original framework of destruction, and this context will be deleted also.
|
||||
struct TfDlManagedTensorCtx {
|
||||
TensorReference reference;
|
||||
std::vector<int64_t> shape;
|
||||
std::vector<int64_t> strides;
|
||||
DLManagedTensor tensor;
|
||||
|
||||
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
|
||||
};
|
||||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Deleter for DLManagedTensor
|
||||
void DLManagedTensorDeleter(DLManagedTensor* arg) {
|
||||
TfDlManagedTensorCtx* owner =
|
||||
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
|
||||
owner->reference.Unref();
|
||||
delete owner;
|
||||
}
|
||||
|
||||
// Converts TF_DATAType to DLPack data type.
|
||||
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||
DLDataType dtype;
|
||||
dtype.lanes = 1;
|
||||
dtype.bits = TF_DataTypeSize(data_type) * 8;
|
||||
switch (data_type) {
|
||||
case TF_DataType::TF_HALF:
|
||||
case TF_DataType::TF_FLOAT:
|
||||
case TF_DataType::TF_DOUBLE:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
break;
|
||||
case TF_DataType::TF_INT8:
|
||||
case TF_DataType::TF_INT16:
|
||||
case TF_DataType::TF_INT32:
|
||||
case TF_DataType::TF_INT64:
|
||||
dtype.code = DLDataTypeCode::kDLInt;
|
||||
break;
|
||||
case TF_DataType::TF_BOOL:
|
||||
case TF_DataType::TF_UINT8:
|
||||
case TF_DataType::TF_UINT16:
|
||||
case TF_DataType::TF_UINT32:
|
||||
case TF_DataType::TF_UINT64:
|
||||
dtype.code = DLDataTypeCode::kDLUInt;
|
||||
break;
|
||||
case TF_DataType::TF_BFLOAT16:
|
||||
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||
break;
|
||||
default:
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
DataType_Name(static_cast<DataType>(data_type)),
|
||||
" is not supported by dlpack");
|
||||
break;
|
||||
}
|
||||
return dtype;
|
||||
}
|
||||
|
||||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
int device_id = 0;
|
||||
if (parsed_name.has_id) {
|
||||
device_id = parsed_name.id;
|
||||
}
|
||||
|
||||
ctx.device_id = device_id;
|
||||
if (device_type == "CPU") {
|
||||
ctx.device_type = DLDeviceType::kDLCPU;
|
||||
} else if (device_type == "GPU") {
|
||||
ctx.device_type = DLDeviceType::kDLGPU;
|
||||
} else {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported Device Type for dlpack");
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Converts DLContext to TF device name.
|
||||
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
|
||||
TF_Status* status) {
|
||||
switch (ctx.device_type) {
|
||||
case DLDeviceType::kDLCPU:
|
||||
return "CPU:0";
|
||||
case DLDeviceType::kDLGPU:
|
||||
return absl::StrCat("GPU:", ctx.device_id);
|
||||
default:
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Converts DLPack data type to TF_DATATYPE.
|
||||
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
|
||||
TF_DataType* tf_dtype) {
|
||||
switch (dtype.code) {
|
||||
case DLDataTypeCode::kDLUInt:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
*tf_dtype = TF_DataType::TF_UINT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_UINT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_UINT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_UINT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
return Status::OK();
|
||||
case DLDataTypeCode::kDLInt:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
*tf_dtype = TF_DataType::TF_INT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_INT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_INT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_INT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
return Status::OK();
|
||||
case DLDataTypeCode::kDLFloat:
|
||||
switch (dtype.bits) {
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_HALF;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*tf_dtype = TF_DataType::TF_FLOAT;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*tf_dtype = TF_DataType::TF_DOUBLE;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
|
||||
dtype.bits);
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLBfloat:
|
||||
switch (dtype.bits) {
|
||||
case 16:
|
||||
*tf_dtype = TF_DataType::TF_BFLOAT16;
|
||||
return Status::OK();
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument(
|
||||
"Unsupported BFloat bits: ", dtype.bits);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
|
||||
dtype.code);
|
||||
}
|
||||
}
|
||||
|
||||
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||
// TFE_NewTensorHandleFromDeviceMemory.
|
||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||
}
|
||||
|
||||
// Checks whether the stride array matches the layout of compact, row-majored
|
||||
// data.
|
||||
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
|
||||
int ndim) {
|
||||
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
|
||||
return false;
|
||||
}
|
||||
for (int i = ndim - 2; i >= 0; --i) {
|
||||
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
|
||||
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
|
||||
if (dlMTensor->deleter != nullptr) {
|
||||
dlMTensor->deleter(dlMTensor);
|
||||
}
|
||||
}
|
||||
|
||||
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const Tensor* tensor = GetTensorFromHandle(h, status);
|
||||
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
|
||||
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
|
||||
|
||||
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
|
||||
tf_dlm_tensor_ctx->reference = tensor_ref;
|
||||
|
||||
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
|
||||
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
|
||||
dlm_tensor->deleter = &DLManagedTensorDeleter;
|
||||
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
|
||||
int ndim = tensor->dims();
|
||||
dlm_tensor->dl_tensor.ndim = ndim;
|
||||
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
|
||||
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
|
||||
|
||||
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
|
||||
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
|
||||
shape_arr->resize(ndim);
|
||||
stride_arr->resize(ndim, 1);
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
(*shape_arr)[i] = tensor->dim_size(i);
|
||||
}
|
||||
for (int i = ndim - 2; i >= 0; --i) {
|
||||
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
|
||||
}
|
||||
|
||||
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
|
||||
// There are two ways to represent compact row-major data
|
||||
// 1) nullptr indicates tensor is compact and row-majored.
|
||||
// 2) fill in the strides array as the real case for compact row-major data.
|
||||
// Here we choose option 2, since some frameworks didn't handle the strides
|
||||
// argument properly.
|
||||
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
|
||||
dlm_tensor->dl_tensor.byte_offset =
|
||||
0; // TF doesn't handle the strides and byte_offsets here
|
||||
return static_cast<void*>(dlm_tensor);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||
absl::optional<std::string> device_name =
|
||||
DeviceNameFromDlContext(dl_tensor->ctx, status);
|
||||
if (!device_name.has_value()) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unsupported Device Type");
|
||||
return nullptr;
|
||||
}
|
||||
TF_DataType dtype;
|
||||
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
|
||||
if (!s.ok()) {
|
||||
status->status = std::move(s);
|
||||
return nullptr;
|
||||
}
|
||||
int num_dims = dl_tensor->ndim;
|
||||
const int64_t* dims = dl_tensor->shape;
|
||||
void* data = dl_tensor->data;
|
||||
|
||||
size_t total_bytes = dl_tensor->dtype.bits / 8;
|
||||
for (int i = 0; i < num_dims; i++) {
|
||||
total_bytes *= dims[i];
|
||||
}
|
||||
|
||||
if (dl_tensor->strides != nullptr &&
|
||||
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
|
||||
num_dims)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Invalid strides array from DLPack");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
39
tensorflow/c/eager/dlpack.h
Normal file
39
tensorflow/c/eager/dlpack.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
|
||||
#define TENSORFLOW_C_EAGER_DLPACK_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// PyCapsule name for DLPack Tensor
|
||||
const char* const kDlTensorCapsuleName = "dltensor";
|
||||
|
||||
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
|
||||
// void* for further PyCapsule construction.
|
||||
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||
TF_Status* status);
|
||||
|
||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_DLPACK_H_
|
@ -190,6 +190,7 @@ py_library(
|
||||
"//tensorflow/python/distribute:estimator_training",
|
||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/dlpack",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:monitoring",
|
||||
"//tensorflow/python/eager:profiler",
|
||||
@ -8069,7 +8070,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:platform",
|
||||
"//tensorflow/core/platform",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core:eager_service_proto_cc",
|
||||
|
@ -159,6 +159,10 @@ from tensorflow.python.debug.lib import check_numerics_callback
|
||||
from tensorflow.python.debug.lib import dumping_callback
|
||||
from tensorflow.python.ops import gen_debug_ops
|
||||
|
||||
# DLPack
|
||||
from tensorflow.python.dlpack.dlpack import from_dlpack
|
||||
from tensorflow.python.dlpack.dlpack import to_dlpack
|
||||
|
||||
# XLA JIT compiler APIs.
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
|
28
tensorflow/python/dlpack/BUILD
Normal file
28
tensorflow/python/dlpack/BUILD
Normal file
@ -0,0 +1,28 @@
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dlpack",
|
||||
srcs = ["dlpack.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "dlpack_test",
|
||||
srcs = ["dlpack_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dlpack",
|
||||
"//tensorflow/python/eager:test",
|
||||
"@absl_py//absl/testing:absltest",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
65
tensorflow/python/dlpack/dlpack.py
Normal file
65
tensorflow/python/dlpack/dlpack.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""DLPack modules for Tensorflow."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export("experimental.dlpack.to_dlpack", v1=[])
|
||||
def to_dlpack(tf_tensor):
|
||||
"""Returns the dlpack capsule representing the tensor.
|
||||
|
||||
This operation ensures the underlying data memory is ready when returns.
|
||||
|
||||
```python
|
||||
a = tf.tensor([1, 10])
|
||||
dlcapsule = tf.experimental.dlpack.to_dlpack(a)
|
||||
# dlcapsule represents the dlpack data structure
|
||||
```
|
||||
|
||||
Args:
|
||||
tf_tensor: Tensorflow eager tensor, to be converted to dlpack capsule.
|
||||
|
||||
Returns:
|
||||
A PyCapsule named as dltensor, which shares the underlying memory to other
|
||||
framework. This PyCapsule can be consumed only once.
|
||||
"""
|
||||
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
|
||||
|
||||
|
||||
@tf_export("experimental.dlpack.from_dlpack", v1=[])
|
||||
def from_dlpack(dlcapsule):
|
||||
"""Returns the Tensorflow eager tensor.
|
||||
|
||||
The returned tensor uses the memory shared by dlpack capsules from other
|
||||
framework.
|
||||
|
||||
```python
|
||||
a = tf.experimental.dlpack.from_dlpack(dlcapsule)
|
||||
# `a` uses the memory shared by dlpack
|
||||
```
|
||||
|
||||
Args:
|
||||
dlcapsule: A PyCapsule named as dltensor
|
||||
|
||||
Returns:
|
||||
A Tensorflow eager tensor
|
||||
"""
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
101
tensorflow/python/dlpack/dlpack_test.py
Normal file
101
tensorflow/python/dlpack/dlpack_test.py
Normal file
@ -0,0 +1,101 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for DLPack functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.dlpack import dlpack
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
int_dtypes = [
|
||||
np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
|
||||
np.uint64
|
||||
]
|
||||
float_dtypes = [np.float16, np.float32, np.float64]
|
||||
complex_dtypes = [np.complex64, np.complex128]
|
||||
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
|
||||
|
||||
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]
|
||||
|
||||
|
||||
def FormatShapeAndDtype(shape, dtype):
|
||||
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
|
||||
|
||||
|
||||
def GetNamedTestParameters():
|
||||
result = []
|
||||
for dtype in dlpack_dtypes:
|
||||
for shape in testcase_shapes:
|
||||
result.append({
|
||||
"testcase_name": FormatShapeAndDtype(shape, dtype),
|
||||
"dtype": dtype,
|
||||
"shape": shape
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
class DLPackTest(parameterized.TestCase, test.TestCase):
|
||||
|
||||
@parameterized.named_parameters(GetNamedTestParameters())
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, shape)
|
||||
tf_tensor = constant_op.constant(np_array, dtype=dtype)
|
||||
dlcapsule = dlpack.to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
tf_tensor2 = dlpack.from_dlpack(dlcapsule)
|
||||
self.assertAllClose(np_array, tf_tensor2)
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
np.random.seed(42)
|
||||
np_array = np.random.randint(0, 10, (2, 3, 4))
|
||||
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
|
||||
dlcapsule = dlpack.to_dlpack(tf_tensor)
|
||||
del tf_tensor # should still work
|
||||
_ = dlpack.from_dlpack(dlcapsule)
|
||||
|
||||
def ConsumeDLPackTensor():
|
||||
dlpack.from_dlpack(dlcapsule) # Should can be consumed only once
|
||||
|
||||
self.assertRaisesRegex(Exception,
|
||||
".*a DLPack tensor may be consumed at most once.*",
|
||||
ConsumeDLPackTensor)
|
||||
|
||||
def testUnsupportedTypeToDLPack(self):
|
||||
|
||||
def UnsupportedQint16():
|
||||
tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
|
||||
_ = dlpack.to_dlpack(tf_tensor)
|
||||
|
||||
def UnsupportedComplex64():
|
||||
tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.complex64)
|
||||
_ = dlpack.to_dlpack(tf_tensor)
|
||||
|
||||
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
|
||||
UnsupportedQint16)
|
||||
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
|
||||
UnsupportedComplex64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
@ -35,6 +35,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:dlpack",
|
||||
"//tensorflow/c/eager:tape",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -93,6 +94,7 @@ py_library(
|
||||
":test",
|
||||
":wrap_function",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python/dlpack",
|
||||
"//tensorflow/python/eager/memory_tests:memory_test_util",
|
||||
],
|
||||
)
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
@ -1047,6 +1048,50 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
|
||||
py::return_value_policy::reference);
|
||||
|
||||
// DLPack functions
|
||||
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
|
||||
PyObject* eager_tensor_pyobject_ptr = o.ptr();
|
||||
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
py::capsule capsule(
|
||||
dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
|
||||
if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
|
||||
void* dlm_rptr =
|
||||
PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
|
||||
if (dlm_rptr) {
|
||||
tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
|
||||
PyCapsule_SetDestructor(capsule, nullptr);
|
||||
}
|
||||
}
|
||||
});
|
||||
return capsule;
|
||||
});
|
||||
|
||||
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
if (absl::string_view(pycapsule.name()) !=
|
||||
tensorflow::kDlTensorCapsuleName) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
|
||||
"Note that a DLPack tensor may be consumed at most once.",
|
||||
absl::string_view(pycapsule.name()));
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
}
|
||||
TFE_TensorHandle* thandle =
|
||||
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
||||
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||
return py::handle(EagerTensorFromHandle(thandle));
|
||||
});
|
||||
|
||||
// C API Enum
|
||||
|
||||
py::enum_<TFE_ContextDevicePlacementPolicy>(
|
||||
|
@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"errors/__init__.py",
|
||||
"experimental/__init__.py",
|
||||
"experimental/tensorrt/__init__.py",
|
||||
"experimental/dlpack/__init__.py",
|
||||
"feature_column/__init__.py",
|
||||
"io/gfile/__init__.py",
|
||||
"graph_util/__init__.py",
|
||||
|
@ -0,0 +1,11 @@
|
||||
path: "tensorflow.experimental.dlpack"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "from_dlpack"
|
||||
argspec: "args=[\'dlcapsule\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "to_dlpack"
|
||||
argspec: "args=[\'tf_tensor\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "dlpack"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "tensorrt"
|
||||
mtype: "<type \'module\'>"
|
||||
|
@ -159,6 +159,7 @@ filegroup(
|
||||
"@com_google_protobuf//:LICENSE",
|
||||
"@com_googlesource_code_re2//:LICENSE",
|
||||
"@curl//:COPYING",
|
||||
"@dlpack//:LICENSE",
|
||||
"@double_conversion//:LICENSE",
|
||||
"@eigen_archive//:COPYING.MPL2",
|
||||
"@enum34_archive//:LICENSE",
|
||||
|
Loading…
Reference in New Issue
Block a user