Merge pull request #36862 from VoVAllen:dlpack

PiperOrigin-RevId: 297728301
Change-Id: I22a74c21f3459189f3e36a94ad521cdedb9b761b
This commit is contained in:
TensorFlower Gardener 2020-02-27 17:38:16 -08:00
commit 9cd1a63a74
14 changed files with 662 additions and 1 deletions

View File

@ -95,6 +95,7 @@ filegroup(
srcs = [ srcs = [
"c_api_experimental.h", "c_api_experimental.h",
"c_api_internal.h", "c_api_internal.h",
"dlpack.h",
"operation_interface.h", "operation_interface.h",
"tensor_handle_interface.h", "tensor_handle_interface.h",
], ],
@ -328,10 +329,33 @@ filegroup(
srcs = [ srcs = [
"c_api.h", "c_api.h",
"c_api_experimental.h", "c_api_experimental.h",
"dlpack.h",
], ],
visibility = ["//tensorflow:__subpackages__"], visibility = ["//tensorflow:__subpackages__"],
) )
cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
hdrs = ["dlpack.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@dlpack",
],
)
# TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime
# right now, remove this public rule when no longer needed (it should be # right now, remove this public rule when no longer needed (it should be
# replaced by TF Lite) # replaced by TF Lite)
@ -345,6 +369,7 @@ filegroup(
exclude = [ exclude = [
"c_api_experimental.cc", "c_api_experimental.cc",
"*test*", "*test*",
"*dlpack*",
], ],
), ),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],

View File

@ -0,0 +1,334 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // TF:dlpack
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Managing context for the DLManagedTensor, will manage the lifetime of
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
// original framework of destruction, and this context will be deleted also.
struct TfDlManagedTensorCtx {
TensorReference reference;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || !h->handle->IsValid(&status->status)) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return tensor;
}
// Deleter for DLManagedTensor
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
owner->reference.Unref();
delete owner;
}
// Converts TF_DATAType to DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8;
switch (data_type) {
case TF_DataType::TF_HALF:
case TF_DataType::TF_FLOAT:
case TF_DataType::TF_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case TF_DataType::TF_INT8:
case TF_DataType::TF_INT16:
case TF_DataType::TF_INT32:
case TF_DataType::TF_INT64:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_BOOL:
case TF_DataType::TF_UINT8:
case TF_DataType::TF_UINT16:
case TF_DataType::TF_UINT32:
case TF_DataType::TF_UINT64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
DataType_Name(static_cast<DataType>(data_type)),
" is not supported by dlpack");
break;
}
return dtype;
}
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name = h->handle->DeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;
int device_id = 0;
if (parsed_name.has_id) {
device_id = parsed_name.id;
}
ctx.device_id = device_id;
if (device_type == "CPU") {
ctx.device_type = DLDeviceType::kDLCPU;
} else if (device_type == "GPU") {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Device Type for dlpack");
}
return ctx;
}
// Converts DLContext to TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return "CPU:0";
case DLDeviceType::kDLGPU:
return absl::StrCat("GPU:", ctx.device_id);
default:
return absl::nullopt;
}
}
// Converts DLPack data type to TF_DATATYPE.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_DataType* tf_dtype) {
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_UINT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_UINT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_UINT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_UINT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_INT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_INT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_INT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_INT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_HALF;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_FLOAT;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_DOUBLE;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
dtype.bits);
}
break;
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_BFLOAT16;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument(
"Unsupported BFloat bits: ", dtype.bits);
}
break;
default:
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
dtype.code);
}
}
// Wraps the deleter function of DLManagedTensor to match the function signature
// TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
}
// Checks whether the stride array matches the layout of compact, row-majored
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
}
}
return true;
}
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(dlMTensor);
}
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data.
// Here we choose option 2, since some frameworks didn't handle the strides
// argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
TF_DataType dtype;
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
if (!s.ok()) {
status->status = std::move(s);
return nullptr;
}
int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;
size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid strides array from DLPack");
return nullptr;
}
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
return handle;
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor";
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
// void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status);
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_DLPACK_H_

View File

@ -190,6 +190,7 @@ py_library(
"//tensorflow/python/distribute:estimator_training", "//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/dlpack",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:monitoring", "//tensorflow/python/eager:monitoring",
"//tensorflow/python/eager:profiler", "//tensorflow/python/eager:profiler",
@ -8069,7 +8070,7 @@ tf_python_pybind_extension(
"//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:lib_headers_for_pybind",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:platform", "//tensorflow/core/platform",
] + if_static( ] + if_static(
extra_deps = [ extra_deps = [
"//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:eager_service_proto_cc",

View File

@ -159,6 +159,10 @@ from tensorflow.python.debug.lib import check_numerics_callback
from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.debug.lib import dumping_callback
from tensorflow.python.ops import gen_debug_ops from tensorflow.python.ops import gen_debug_ops
# DLPack
from tensorflow.python.dlpack.dlpack import from_dlpack
from tensorflow.python.dlpack.dlpack import to_dlpack
# XLA JIT compiler APIs. # XLA JIT compiler APIs.
from tensorflow.python.compiler.xla import jit from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla from tensorflow.python.compiler.xla import xla

View File

@ -0,0 +1,28 @@
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "dlpack",
srcs = ["dlpack.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
],
)
cuda_py_test(
name = "dlpack_test",
srcs = ["dlpack_test.py"],
srcs_version = "PY2AND3",
deps = [
":dlpack",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -0,0 +1,65 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DLPack modules for Tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.util.tf_export import tf_export
@tf_export("experimental.dlpack.to_dlpack", v1=[])
def to_dlpack(tf_tensor):
"""Returns the dlpack capsule representing the tensor.
This operation ensures the underlying data memory is ready when returns.
```python
a = tf.tensor([1, 10])
dlcapsule = tf.experimental.dlpack.to_dlpack(a)
# dlcapsule represents the dlpack data structure
```
Args:
tf_tensor: Tensorflow eager tensor, to be converted to dlpack capsule.
Returns:
A PyCapsule named as dltensor, which shares the underlying memory to other
framework. This PyCapsule can be consumed only once.
"""
return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
@tf_export("experimental.dlpack.from_dlpack", v1=[])
def from_dlpack(dlcapsule):
"""Returns the Tensorflow eager tensor.
The returned tensor uses the memory shared by dlpack capsules from other
framework.
```python
a = tf.experimental.dlpack.from_dlpack(dlcapsule)
# `a` uses the memory shared by dlpack
```
Args:
dlcapsule: A PyCapsule named as dltensor
Returns:
A Tensorflow eager tensor
"""
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)

View File

@ -0,0 +1,101 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
int_dtypes = [
np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
complex_dtypes = [np.complex64, np.complex128]
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]
def FormatShapeAndDtype(shape, dtype):
return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))
def GetNamedTestParameters():
result = []
for dtype in dlpack_dtypes:
for shape in testcase_shapes:
result.append({
"testcase_name": FormatShapeAndDtype(shape, dtype),
"dtype": dtype,
"shape": shape
})
return result
class DLPackTest(parameterized.TestCase, test.TestCase):
@parameterized.named_parameters(GetNamedTestParameters())
def testRoundTrip(self, dtype, shape):
np.random.seed(42)
np_array = np.random.randint(0, 10, shape)
tf_tensor = constant_op.constant(np_array, dtype=dtype)
dlcapsule = dlpack.to_dlpack(tf_tensor)
del tf_tensor # should still work
tf_tensor2 = dlpack.from_dlpack(dlcapsule)
self.assertAllClose(np_array, tf_tensor2)
def testTensorsCanBeConsumedOnceOnly(self):
np.random.seed(42)
np_array = np.random.randint(0, 10, (2, 3, 4))
tf_tensor = constant_op.constant(np_array, dtype=np.float32)
dlcapsule = dlpack.to_dlpack(tf_tensor)
del tf_tensor # should still work
_ = dlpack.from_dlpack(dlcapsule)
def ConsumeDLPackTensor():
dlpack.from_dlpack(dlcapsule) # Should can be consumed only once
self.assertRaisesRegex(Exception,
".*a DLPack tensor may be consumed at most once.*",
ConsumeDLPackTensor)
def testUnsupportedTypeToDLPack(self):
def UnsupportedQint16():
tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
_ = dlpack.to_dlpack(tf_tensor)
def UnsupportedComplex64():
tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.complex64)
_ = dlpack.to_dlpack(tf_tensor)
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
UnsupportedQint16)
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
UnsupportedComplex64)
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File

@ -35,6 +35,7 @@ cc_library(
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:dlpack",
"//tensorflow/c/eager:tape", "//tensorflow/c/eager:tape",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -93,6 +94,7 @@ py_library(
":test", ":test",
":wrap_function", ":wrap_function",
"//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python/dlpack",
"//tensorflow/python/eager/memory_tests:memory_test_util", "//tensorflow/python/eager/memory_tests:memory_test_util",
], ],
) )

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/dlpack.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
@ -1047,6 +1048,50 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m.def("TF_NewBufferFromString", &TF_NewBufferFromString, m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
py::return_value_policy::reference); py::return_value_policy::reference);
// DLPack functions
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
PyObject* eager_tensor_pyobject_ptr = o.ptr();
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
py::capsule capsule(
dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
void* dlm_rptr =
PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
if (dlm_rptr) {
tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
PyCapsule_SetDestructor(capsule, nullptr);
}
}
});
return capsule;
});
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (absl::string_view(pycapsule.name()) !=
tensorflow::kDlTensorCapsuleName) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
"Note that a DLPack tensor may be consumed at most once.",
absl::string_view(pycapsule.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
TFE_TensorHandle* thandle =
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
return py::handle(EagerTensorFromHandle(thandle));
});
// C API Enum // C API Enum
py::enum_<TFE_ContextDevicePlacementPolicy>( py::enum_<TFE_ContextDevicePlacementPolicy>(

View File

@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
"errors/__init__.py", "errors/__init__.py",
"experimental/__init__.py", "experimental/__init__.py",
"experimental/tensorrt/__init__.py", "experimental/tensorrt/__init__.py",
"experimental/dlpack/__init__.py",
"feature_column/__init__.py", "feature_column/__init__.py",
"io/gfile/__init__.py", "io/gfile/__init__.py",
"graph_util/__init__.py", "graph_util/__init__.py",

View File

@ -0,0 +1,11 @@
path: "tensorflow.experimental.dlpack"
tf_module {
member_method {
name: "from_dlpack"
argspec: "args=[\'dlcapsule\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "to_dlpack"
argspec: "args=[\'tf_tensor\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.experimental" path: "tensorflow.experimental"
tf_module { tf_module {
member {
name: "dlpack"
mtype: "<type \'module\'>"
}
member { member {
name: "tensorrt" name: "tensorrt"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -159,6 +159,7 @@ filegroup(
"@com_google_protobuf//:LICENSE", "@com_google_protobuf//:LICENSE",
"@com_googlesource_code_re2//:LICENSE", "@com_googlesource_code_re2//:LICENSE",
"@curl//:COPYING", "@curl//:COPYING",
"@dlpack//:LICENSE",
"@double_conversion//:LICENSE", "@double_conversion//:LICENSE",
"@eigen_archive//:COPYING.MPL2", "@eigen_archive//:COPYING.MPL2",
"@enum34_archive//:LICENSE", "@enum34_archive//:LICENSE",