From 4d291eba07ec612f2152fa1d0c6180d72c38c669 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 18 Feb 2020 17:14:31 +0000 Subject: [PATCH 01/20] dlpack --- tensorflow/c/eager/BUILD | 22 +++++++++++++++++ tensorflow/python/eager/BUILD | 1 + tensorflow/python/tfe_wrapper.cc | 41 ++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 5901ddb6182..566fef023ea 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -93,6 +93,7 @@ filegroup( "c_api_experimental.h", "c_api_internal.h", "tensor_handle_interface.h", + "dlpack.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -321,10 +322,31 @@ filegroup( srcs = [ "c_api.h", "c_api_experimental.h", + "dlpack.h", ], visibility = ["//tensorflow:__subpackages__"], ) + +cc_library( + name = "dlpack", + srcs = ["dlpack.cc"], + hdrs = ["dlpack.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/core:framework", + "@dlpack", + ], +) + + # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime # right now, remove this public rule when no longer needed (it should be # replaced by TF Lite) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 65d07846cea..0a792bb2747 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -36,6 +36,7 @@ cc_library( "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:tape", + "//tensorflow/c/eager:dlpack", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 160b817d937..1f9bb5c434e 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/dlpack.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/compiler/jit/flags.h" @@ -1033,6 +1034,46 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TF_NewBufferFromString", &TF_NewBufferFromString, py::return_value_policy::reference); + // DLPack functions + m.def("TFE_ToDlpackCapsule", [](py::handle& o) { + PyObject* eager_tensor_pyobject_ptr = o.ptr(); + TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); + + py::capsule capsule( + dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* obj) { + void* dlm_rptr = + PyCapsule_GetPointer(obj, tensorflow::kDlTensorCapsuleName); + if (dlm_rptr) { + tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); + } else { + // The tensor has been deleted. Clear any error from + // PyCapsule_GetPointer. 
+ PyErr_Clear(); + } + }); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return capsule; + }); + + m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) { + tensorflow::Safe_TF_StatusPtr status = + tensorflow::make_safe(TF_NewStatus()); + if (absl::string_view(pycapsule.name()) != + tensorflow::kDlTensorCapsuleName) { + status->status = tensorflow::errors::InvalidArgument( + "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " + "Note that a DLPack tensor may be consumed at most once.", + absl::string_view(pycapsule.name())); + } + TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( + static_cast(pycapsule), status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + return py::handle(EagerTensorFromHandle(thandle)); + }); + // C API Enum py::enum_( From 966d1e3155348ad9ceeb2e05a18ded05062b01a4 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 18 Feb 2020 17:15:00 +0000 Subject: [PATCH 02/20] dlpack --- tensorflow/python/tfe_wrapper.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 1f9bb5c434e..d9f603709ad 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1070,6 +1070,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { } TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( static_cast(pycapsule), status.get()); + + PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); + PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return py::handle(EagerTensorFromHandle(thandle)); }); From 8f51939a8d269f8d52e94e363a1bb54759378508 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 18 Feb 2020 17:15:13 +0000 Subject: [PATCH 03/20] dlpack --- tensorflow/python/tfe_wrapper.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index d9f603709ad..a7545a02904 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1043,11 +1043,12 @@ PYBIND11_MODULE(_pywrap_tfe, m) { void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); py::capsule capsule( - dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* obj) { + dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) { void* dlm_rptr = - PyCapsule_GetPointer(obj, tensorflow::kDlTensorCapsuleName); + PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); if (dlm_rptr) { tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); + PyCapsule_SetDestructor(capsule, nullptr); } else { // The tensor has been deleted. Clear any error from // PyCapsule_GetPointer. 
From 89c73caf126cdad3b6f3490658f5779f1fd9a05e Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 18 Feb 2020 17:56:29 +0000 Subject: [PATCH 04/20] dlpack --- tensorflow/c/eager/dlpack.cc | 342 +++++++++++++++++++++++++++++++++++ tensorflow/c/eager/dlpack.h | 26 +++ 2 files changed, 368 insertions(+) create mode 100644 tensorflow/c/eager/dlpack.cc create mode 100644 tensorflow/c/eager/dlpack.h diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc new file mode 100644 index 00000000000..6a927a09286 --- /dev/null +++ b/tensorflow/c/eager/dlpack.cc @@ -0,0 +1,342 @@ +#include "tensorflow/c/eager/dlpack.h" +#include "include/dlpack/dlpack.h" // TF:dlpack +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/casts.h" + +#include "tensorflow/core/framework/tensor_reference.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +using tensorflow::Tensor; +using tensorflow::TensorHandleInterface; + +namespace { + +struct TFDLMTensor { + TensorReference* handle; + DLManagedTensor tensor; +}; + +TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h, + TF_Status* status) { + if (h == nullptr || !h->handle->IsValid(&status->status)) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + tensorflow::TensorHandle* handle = + tensorflow::down_cast(h->handle.get()) + ->Handle(); + + if (handle->IsRemote()) { + status->status = tensorflow::errors::InvalidArgument( + "TFE_TensorHandleDevicePointer may not be called on a remote tensor " + "handle."); + return nullptr; + } + return handle; +} + +const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { + TensorHandle* handle = GetTensorHandleFromTFEHandle(h, status); + + if (handle->IsRemote()) { + status->status = tensorflow::errors::InvalidArgument( + "TFE_TensorHandleDevicePointer may not be called on a remote tensor " + "handle."); + return nullptr; + } + tensorflow::Device* device(absl::get(handle->device())); + if (device != nullptr) { + status->status = device->Sync(); + if (!status->status.ok()) { + return nullptr; + } + } + const tensorflow::Tensor* tensor; + status->status = handle->Tensor(&tensor); + if (!status->status.ok()) { + return nullptr; + } + return tensor; +}; + +void deleter(DLManagedTensor* arg) { + TFDLMTensor* owner = static_cast(arg->manager_ctx); + owner->handle->Unref(); + delete owner; +} + +DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { + DLDataType dtype; + dtype.lanes = 1; + dtype.bits = TF_DataTypeSize(data_type) * 8; + switch (data_type) { + case TF_DataType::TF_FLOAT: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case TF_DataType::TF_DOUBLE: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case TF_DataType::TF_INT32: + dtype.code = DLDataTypeCode::kDLInt; + break; + case TF_DataType::TF_UINT8: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case TF_DataType::TF_INT16: + dtype.code = DLDataTypeCode::kDLInt; + break; + case TF_DataType::TF_STRING: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case TF_DataType::TF_COMPLEX64: + status->status = tensorflow::errors::InvalidArgument( + "TF_COMPLEX64 is not supported by dlpack"); + break; + case TF_DataType::TF_INT64: + dtype.code = DLDataTypeCode::kDLInt; + break; + case TF_DataType::TF_BOOL: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case TF_DataType::TF_QINT8: + status->status = 
tensorflow::errors::InvalidArgument( + "TF_QINT8 is not supported by dlpack"); + break; + case TF_DataType::TF_QUINT8: + status->status = tensorflow::errors::InvalidArgument( + "TF_QUINT8 is not supported by dlpack"); + break; + case TF_DataType::TF_QINT32: + status->status = tensorflow::errors::InvalidArgument( + "TF_QINT32 is not supported by dlpack"); + break; + case TF_DataType::TF_BFLOAT16: + dtype.code = DLDataTypeCode::kDLBfloat; + break; + case TF_DataType::TF_QINT16: + status->status = tensorflow::errors::InvalidArgument( + "TF_QINT16 is not supported by dlpack"); + break; + case TF_DataType::TF_QUINT16: + status->status = tensorflow::errors::InvalidArgument( + "TF_QUINT16 is not supported by dlpack"); + break; + case TF_DataType::TF_COMPLEX128: + status->status = tensorflow::errors::InvalidArgument( + "TF_COMPLEX128 is not supported by dlpack"); + break; + case TF_DataType::TF_HALF: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case TF_DataType::TF_RESOURCE: + status->status = tensorflow::errors::InvalidArgument( + "TF_RESOURCE is not supported by dlpack"); + break; + case TF_DataType::TF_VARIANT: + status->status = tensorflow::errors::InvalidArgument( + "TF_VARIANT is not supported by dlpack"); + break; + case TF_DataType::TF_UINT32: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case TF_DataType::TF_UINT64: + dtype.code = DLDataTypeCode::kDLUInt; + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported TF_DataType is not supported by dlpack"); + break; + } + return dtype; +} + +DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { + DLContext ctx; + const char* device_name = h->handle->DeviceName(&status->status); + DeviceNameUtils::ParsedName parsed_name; + tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name), + &parsed_name); + std::string device_type = parsed_name.type; + int device_id = -1; + if (parsed_name.has_id) { + device_id = parsed_name.id; + } // Question? device_id?=-1 + + ctx.device_id = device_id; + if (device_type == "CPU") { + ctx.device_type = DLDeviceType::kDLCPU; + } else if (device_type == "GPU") { + ctx.device_type = DLDeviceType::kDLGPU; + } else { + status->status = tensorflow::errors::InvalidArgument( + "Unsupported Device Type for DLPack"); + } + + return ctx; +} + +DLManagedTensor* TFEHandleToTFDLMTensor(TFE_TensorHandle* h, + TF_Status* status) { + const Tensor* tensor = GetTensorFromHandle(h, status); + TF_DataType data_type = static_cast(tensor->dtype()); + TFDLMTensor* tfDLMTensor(new TFDLMTensor); + TensorReference* tensor_ref = + new TensorReference(*tensor); // This will call buf_->Ref() + tfDLMTensor->handle = tensor_ref; + tfDLMTensor->tensor.manager_ctx = tfDLMTensor; + tfDLMTensor->tensor.deleter = &deleter; + tfDLMTensor->tensor.dl_tensor.ctx = getDLContext(h, status); + int ndim = tensor->dims(); + tfDLMTensor->tensor.dl_tensor.ndim = ndim; + tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); + tfDLMTensor->tensor.dl_tensor.dtype = getDLDataType(data_type, status); + + int64_t* shape_arr = new int64_t[ndim]; + for (int i = 0; i < ndim; i++) { + shape_arr[i] = tensor->dim_size(i); + } + + tfDLMTensor->tensor.dl_tensor.shape = shape_arr; + + tfDLMTensor->tensor.dl_tensor.strides = + nullptr; // Whether this is null at all the time? + tfDLMTensor->tensor.dl_tensor.byte_offset = + 0; // Whether this is 0 at all the time? 
+ return &tfDLMTensor->tensor; +} + +std::string FromDLContext(const DLContext& ctx, TF_Status* status) { + switch (ctx.device_type) { + case DLDeviceType::kDLCPU: + return "CPU:0"; + case DLDeviceType::kDLGPU: + return absl::StrCat("GPU:", ctx.device_id); + default: + return ""; + }; +} +TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { + TF_DataType tf_dtype; + switch (dtype.code) { + case DLDataTypeCode::kDLUInt: + switch (dtype.bits) { + case 1: + tf_dtype = TF_DataType::TF_BOOL; + break; + case 8: + tf_dtype = TF_DataType::TF_UINT8; + break; + case 16: + tf_dtype = TF_DataType::TF_UINT16; + break; + case 32: + tf_dtype = TF_DataType::TF_UINT32; + break; + case 64: + tf_dtype = TF_DataType::TF_UINT64; + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported UInt bits", dtype.bits); + } + break; + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + tf_dtype = TF_DataType::TF_INT8; + break; + case 16: + tf_dtype = TF_DataType::TF_INT16; + break; + case 32: + tf_dtype = TF_DataType::TF_INT32; + break; + case 64: + tf_dtype = TF_DataType::TF_INT64; + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported Int bits", dtype.bits); + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + tf_dtype = TF_DataType::TF_HALF; + break; + case 32: + tf_dtype = TF_DataType::TF_FLOAT; + break; + case 64: + tf_dtype = TF_DataType::TF_DOUBLE; + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported Float bits", dtype.bits); + } + break; + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + tf_dtype = TF_DataType::TF_BFLOAT16; + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported BFloat bits", dtype.bits); + } + break; + default: + status->status = tensorflow::errors::InvalidArgument( + "Unsupported Type Codes", dtype.code); + } + + return tf_dtype; +} + +void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr){ + DLManagedTensor* dlmt = static_cast(dlmt_vptr); + dlmt->deleter(const_cast(dlmt)); +} + +} // namespace + +void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { + DLManagedTensor* dlMTensor = static_cast(dlm_ptr); + if (dlMTensor) { + dlMTensor->deleter(const_cast(dlMTensor)); + } +} + +void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { + DLManagedTensor* tfdlmtensor = TFEHandleToTFDLMTensor(h, status); + return static_cast(tfdlmtensor); +} + +TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + DLManagedTensor* dlmt = static_cast(dlm); + + std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status); + TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status); + int num_dims = dlmt->dl_tensor.ndim; + const int64_t* dims = dlmt->dl_tensor.shape; + void* data = dlmt->dl_tensor.data; + + size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8; + for (int i = 0; i < num_dims; i++) { + total_bytes *= dims[i]; + } + TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( + ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes, + &DeallocatorWrapperFunc, &dlmt, status); + + return handle; +}; + +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h new file mode 100644 index 00000000000..7993a6ef8e0 --- /dev/null +++ 
b/tensorflow/c/eager/dlpack.h @@ -0,0 +1,26 @@ +#ifndef TENSORFLOW_C_DLPACK_H_ +#define TENSORFLOW_C_DLPACK_H_ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/framework/tensor.h" + +#ifdef __cplusplus +extern "C" { +#endif + +namespace tensorflow { + +const char* const kDlTensorCapsuleName = "dltensor"; + +void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); + +TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); + +void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); +} // namespace tensorflow + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_DLPACK_H_ From 88d46f618496e294020b926d158dca59257dbb02 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 11:59:16 +0000 Subject: [PATCH 05/20] address comment --- tensorflow/c/eager/dlpack.cc | 63 ++++++++++++++------------------ tensorflow/c/eager/dlpack.h | 23 ++++++++---- tensorflow/python/tfe_wrapper.cc | 6 +-- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 6a927a09286..40186f39947 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -1,3 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + #include "tensorflow/c/eager/dlpack.h" #include "include/dlpack/dlpack.h" // TF:dlpack #include "tensorflow/c/eager/c_api_internal.h" @@ -10,18 +25,15 @@ namespace tensorflow { -using tensorflow::Tensor; -using tensorflow::TensorHandleInterface; - namespace { -struct TFDLMTensor { +struct TFDLManagedTensorCtx { TensorReference* handle; DLManagedTensor tensor; }; -TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h, - TF_Status* status) { + +const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( "The passed in handle is a nullptr"); @@ -37,25 +49,6 @@ TensorHandle* GetTensorHandleFromTFEHandle(TFE_TensorHandle* h, "handle."); return nullptr; } - return handle; -} - -const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { - TensorHandle* handle = GetTensorHandleFromTFEHandle(h, status); - - if (handle->IsRemote()) { - status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDevicePointer may not be called on a remote tensor " - "handle."); - return nullptr; - } - tensorflow::Device* device(absl::get(handle->device())); - if (device != nullptr) { - status->status = device->Sync(); - if (!status->status.ok()) { - return nullptr; - } - } const tensorflow::Tensor* tensor; status->status = handle->Tensor(&tensor); if (!status->status.ok()) { @@ -64,13 +57,13 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { return tensor; }; -void deleter(DLManagedTensor* arg) { - TFDLMTensor* owner = static_cast(arg->manager_ctx); +void DLManagedTensorDeleter(DLManagedTensor* arg) { + TFDLManagedTensorCtx* owner = static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner; } -DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { +DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { DLDataType dtype; dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; @@ -155,7 +148,7 @@ DLDataType getDLDataType(TF_DataType data_type, TF_Status* status) { return dtype; } -DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { +DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; @@ -180,21 +173,21 @@ DLContext getDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -DLManagedTensor* TFEHandleToTFDLMTensor(TFE_TensorHandle* h, +DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TFDLMTensor* tfDLMTensor(new TFDLMTensor); + TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx); TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() tfDLMTensor->handle = tensor_ref; tfDLMTensor->tensor.manager_ctx = tfDLMTensor; - tfDLMTensor->tensor.deleter = &deleter; - tfDLMTensor->tensor.dl_tensor.ctx = getDLContext(h, status); + tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter; + tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); tfDLMTensor->tensor.dl_tensor.ndim = ndim; tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - tfDLMTensor->tensor.dl_tensor.dtype = getDLDataType(data_type, status); + 
tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); int64_t* shape_arr = new int64_t[ndim]; for (int i = 0; i < ndim; i++) { @@ -313,7 +306,7 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { - DLManagedTensor* tfdlmtensor = TFEHandleToTFDLMTensor(h, status); + DLManagedTensor* tfdlmtensor = TFEHandleToTFDLManagedTensorCtx(h, status); return static_cast(tfdlmtensor); } diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 7993a6ef8e0..43c205e30c4 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -1,13 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + #ifndef TENSORFLOW_C_DLPACK_H_ #define TENSORFLOW_C_DLPACK_H_ #include "tensorflow/c/eager/c_api.h" #include "tensorflow/core/framework/tensor.h" -#ifdef __cplusplus -extern "C" { -#endif - namespace tensorflow { const char* const kDlTensorCapsuleName = "dltensor"; @@ -19,8 +31,5 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow -#ifdef __cplusplus -} /* end extern "C" */ -#endif #endif // TENSORFLOW_C_DLPACK_H_ diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index a7545a02904..870693c8190 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1069,9 +1069,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { "Note that a DLPack tensor may be consumed at most once.", absl::string_view(pycapsule.name())); } - TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( - static_cast(pycapsule), status.get()); - + TFE_TensorHandle* thandle = + tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); From 883dcc553a24c8e5c09c84a5255f20aa18c7f1d7 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 18:49:03 +0000 Subject: [PATCH 06/20] add python api and test --- tensorflow/c/eager/dlpack.cc | 6 +++ tensorflow/python/dlpack/BUILD | 22 ++++++++ tensorflow/python/dlpack/dlpack.py | 25 +++++++++ tensorflow/python/dlpack/dlpack_test.py | 68 +++++++++++++++++++++++++ tensorflow/python/tfe_wrapper.cc | 1 + 5 files changed, 122 insertions(+) create mode 100644 tensorflow/python/dlpack/BUILD create mode 100644 tensorflow/python/dlpack/dlpack.py create mode 100644 tensorflow/python/dlpack/dlpack_test.py diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 40186f39947..e0624ac4ca1 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -80,6 +80,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { case TF_DataType::TF_UINT8: dtype.code = DLDataTypeCode::kDLUInt; break; + case 
TF_DataType::TF_INT8: + dtype.code = DLDataTypeCode::kDLInt; + break; case TF_DataType::TF_INT16: dtype.code = DLDataTypeCode::kDLInt; break; @@ -119,6 +122,9 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { status->status = tensorflow::errors::InvalidArgument( "TF_QUINT16 is not supported by dlpack"); break; + case TF_DataType::TF_UINT16: + dtype.code = DLDataTypeCode::kDLUInt; + break; case TF_DataType::TF_COMPLEX128: status->status = tensorflow::errors::InvalidArgument( "TF_COMPLEX128 is not supported by dlpack"); diff --git a/tensorflow/python/dlpack/BUILD b/tensorflow/python/dlpack/BUILD new file mode 100644 index 00000000000..3c890ec8b8f --- /dev/null +++ b/tensorflow/python/dlpack/BUILD @@ -0,0 +1,22 @@ +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "dlpack", + srcs = ["dlpack.py"], + deps = [ + ], + srcs_version = "PY3", +) + +cuda_py_test( + name = "dlpack_test", + srcs = ["dlpack_test.py"], + python_version = "PY3", + deps = [ + ":dlpack", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ] +) \ No newline at end of file diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py new file mode 100644 index 00000000000..00be73f3670 --- /dev/null +++ b/tensorflow/python/dlpack/dlpack.py @@ -0,0 +1,25 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from tensorflow.python import pywrap_tfe +from tensorflow.python.util.tf_export import tf_export + +@tf_export("dlpack.to_dlpack") +def to_dlpack(tf_tensor): + return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + +@tf_export("dlpack.from_dlpack") +def from_dlpack(dlcapsule): + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) \ No newline at end of file diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py new file mode 100644 index 00000000000..8384dfeadea --- /dev/null +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -0,0 +1,68 @@ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.framework import dtypes +from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack + +from absl.testing import absltest +from absl.testing import parameterized + +import numpy as np + +int_dtypes = [ + np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, + np.uint64 +] +float_dtypes = [np.float16, np.float32, np.float64] +complex_dtypes = [np.complex64, np.complex128] +dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16] +standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + + +testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2) +] + + +def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) + + +class DLPackTest(parameterized.TestCase, test.TestCase): + + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + np.random.seed(42) + np_array = np.random.randint(0, 10, shape) + tf_tensor = constant_op.constant(np_array, dtype=dtype) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + self.assertAllClose(np_array, tf_tensor2) + + def testTensorsCanBeConsumedOnceOnly(self): + np.random.seed(42) + np_array = np.random.randint(0, 10, (2, 3, 4)) + tf_tensor = constant_op.constant(np_array, dtype=np.float32) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + + def ConsumeDLPackTensor(): + from_dlpack(dlcapsule) # Should can be consumed only once + self.assertRaisesRegex(Exception, + ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 870693c8190..0b801b4d51e 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1068,6 +1068,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) { "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". 
" "Note that a DLPack tensor may be consumed at most once.", absl::string_view(pycapsule.name())); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); } TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); From e61323b14ef7b10e2d4440c38d69cfe1b06fa58a Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 20:44:59 +0000 Subject: [PATCH 07/20] address comments --- tensorflow/c/eager/dlpack.cc | 156 ++++++++---------------- tensorflow/python/dlpack/dlpack_test.py | 19 ++- tensorflow/python/tfe_wrapper.cc | 21 ++-- 3 files changed, 80 insertions(+), 116 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index e0624ac4ca1..fa6fb77a2d7 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -32,7 +32,6 @@ struct TFDLManagedTensorCtx { DLManagedTensor tensor; }; - const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( @@ -45,8 +44,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (handle->IsRemote()) { status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDevicePointer may not be called on a remote tensor " - "handle."); + "DLPack doesn't support remote tensor"); return nullptr; } const tensorflow::Tensor* tensor; @@ -58,7 +56,8 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { }; void DLManagedTensorDeleter(DLManagedTensor* arg) { - TFDLManagedTensorCtx* owner = static_cast(arg->manager_ctx); + TFDLManagedTensorCtx* owner = + static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner; } @@ -68,103 +67,46 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; switch (data_type) { + case TF_DataType::TF_HALF: case TF_DataType::TF_FLOAT: - dtype.code = DLDataTypeCode::kDLFloat; - break; case TF_DataType::TF_DOUBLE: dtype.code = DLDataTypeCode::kDLFloat; break; - case TF_DataType::TF_INT32: - dtype.code = DLDataTypeCode::kDLInt; - break; - case TF_DataType::TF_UINT8: - dtype.code = DLDataTypeCode::kDLUInt; - break; case TF_DataType::TF_INT8: - dtype.code = DLDataTypeCode::kDLInt; - break; case TF_DataType::TF_INT16: - dtype.code = DLDataTypeCode::kDLInt; - break; - case TF_DataType::TF_STRING: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case TF_DataType::TF_COMPLEX64: - status->status = tensorflow::errors::InvalidArgument( - "TF_COMPLEX64 is not supported by dlpack"); - break; + case TF_DataType::TF_INT32: case TF_DataType::TF_INT64: dtype.code = DLDataTypeCode::kDLInt; break; case TF_DataType::TF_BOOL: + case TF_DataType::TF_UINT8: + case TF_DataType::TF_UINT16: + case TF_DataType::TF_UINT32: + case TF_DataType::TF_UINT64: dtype.code = DLDataTypeCode::kDLUInt; break; - case TF_DataType::TF_QINT8: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT8 is not supported by dlpack"); - break; - case TF_DataType::TF_QUINT8: - status->status = tensorflow::errors::InvalidArgument( - "TF_QUINT8 is not supported by dlpack"); - break; - case TF_DataType::TF_QINT32: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT32 is not supported by dlpack"); - break; case TF_DataType::TF_BFLOAT16: dtype.code = DLDataTypeCode::kDLBfloat; break; - case TF_DataType::TF_QINT16: - status->status = tensorflow::errors::InvalidArgument( - "TF_QINT16 is not supported by dlpack"); - 
break; - case TF_DataType::TF_QUINT16: - status->status = tensorflow::errors::InvalidArgument( - "TF_QUINT16 is not supported by dlpack"); - break; - case TF_DataType::TF_UINT16: - dtype.code = DLDataTypeCode::kDLUInt; - break; - case TF_DataType::TF_COMPLEX128: - status->status = tensorflow::errors::InvalidArgument( - "TF_COMPLEX128 is not supported by dlpack"); - break; - case TF_DataType::TF_HALF: - dtype.code = DLDataTypeCode::kDLFloat; - break; - case TF_DataType::TF_RESOURCE: - status->status = tensorflow::errors::InvalidArgument( - "TF_RESOURCE is not supported by dlpack"); - break; - case TF_DataType::TF_VARIANT: - status->status = tensorflow::errors::InvalidArgument( - "TF_VARIANT is not supported by dlpack"); - break; - case TF_DataType::TF_UINT32: - dtype.code = DLDataTypeCode::kDLUInt; - break; - case TF_DataType::TF_UINT64: - dtype.code = DLDataTypeCode::kDLUInt; - break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported TF_DataType is not supported by dlpack"); + DataType_Name(static_cast(data_type)), + " is not supported by dlpack"); break; } return dtype; } -DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { +DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; - tensorflow::DeviceNameUtils::ParseFullName(absl::string_view(device_name), - &parsed_name); + tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; int device_id = -1; if (parsed_name.has_id) { device_id = parsed_name.id; - } // Question? device_id?=-1 + } // Question: Is it possible that it doens't have id? ctx.device_id = device_id; if (device_type == "CPU") { @@ -173,53 +115,55 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { ctx.device_type = DLDeviceType::kDLGPU; } else { status->status = tensorflow::errors::InvalidArgument( - "Unsupported Device Type for DLPack"); + "Unsupported Device Type for dlpack"); } return ctx; } DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, - TF_Status* status) { + TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TFDLManagedTensorCtx* tfDLMTensor(new TFDLManagedTensorCtx); + TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx); TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() - tfDLMTensor->handle = tensor_ref; - tfDLMTensor->tensor.manager_ctx = tfDLMTensor; - tfDLMTensor->tensor.deleter = &DLManagedTensorDeleter; - tfDLMTensor->tensor.dl_tensor.ctx = GetDLContext(h, status); + tf_dlm_tensor_ctx->handle = tensor_ref; + tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx; + tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter; + tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); - tfDLMTensor->tensor.dl_tensor.ndim = ndim; - tfDLMTensor->tensor.dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - tfDLMTensor->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); + tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim; + tf_dlm_tensor_ctx->tensor.dl_tensor.data = + TFE_TensorHandleDevicePointer(h, status); + tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); int64_t* shape_arr = new int64_t[ndim]; for (int i = 0; i < ndim; i++) { shape_arr[i] = tensor->dim_size(i); } - 
tfDLMTensor->tensor.dl_tensor.shape = shape_arr; + tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr; - tfDLMTensor->tensor.dl_tensor.strides = - nullptr; // Whether this is null at all the time? - tfDLMTensor->tensor.dl_tensor.byte_offset = - 0; // Whether this is 0 at all the time? - return &tfDLMTensor->tensor; + tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr; + tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = + 0; // TF doesn't handle the strides and byte_offsets here + return &tf_dlm_tensor_ctx->tensor; } -std::string FromDLContext(const DLContext& ctx, TF_Status* status) { +absl::optional DeviceNameFromDlContext(const DLContext& ctx, + TF_Status* status) { switch (ctx.device_type) { case DLDeviceType::kDLCPU: return "CPU:0"; case DLDeviceType::kDLGPU: return absl::StrCat("GPU:", ctx.device_id); default: - return ""; + return absl::nullopt; }; } -TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { +TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, + TF_Status* status) { TF_DataType tf_dtype; switch (dtype.code) { case DLDataTypeCode::kDLUInt: @@ -241,7 +185,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported UInt bits", dtype.bits); + "Unsupported UInt bits: ", dtype.bits); } break; case DLDataTypeCode::kDLInt: @@ -260,7 +204,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Int bits", dtype.bits); + "Unsupported Int bits: ", dtype.bits); } break; case DLDataTypeCode::kDLFloat: @@ -276,7 +220,7 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Float bits", dtype.bits); + "Unsupported Float bits: ", dtype.bits); } break; case DLDataTypeCode::kDLBfloat: @@ -286,20 +230,20 @@ TF_DataType FromDLDataType(const DLDataType& dtype, TF_Status* status) { break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported BFloat bits", dtype.bits); + "Unsupported BFloat bits: ", dtype.bits); } break; default: status->status = tensorflow::errors::InvalidArgument( - "Unsupported Type Codes", dtype.code); + "Unsupported Type Codes: ", dtype.code); } return tf_dtype; } -void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr){ - DLManagedTensor* dlmt = static_cast(dlmt_vptr); - dlmt->deleter(const_cast(dlmt)); +void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { + DLManagedTensor* dlmt = static_cast(dlmt_vptr); + dlmt->deleter(const_cast(dlmt)); } } // namespace @@ -321,8 +265,14 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_Context* ctx = TFE_NewContext(opts, status); DLManagedTensor* dlmt = static_cast(dlm); - std::string device_name = FromDLContext(dlmt->dl_tensor.ctx, status); - TF_DataType dtype = FromDLDataType(dlmt->dl_tensor.dtype, status); + absl::optional device_name = + DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status); + if (!device_name.has_value()) { + status->status = + tensorflow::errors::InvalidArgument("Unsupported Device Type"); + return nullptr; + } + TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status); int num_dims = dlmt->dl_tensor.ndim; const int64_t* dims = dlmt->dl_tensor.shape; void* data = dlmt->dl_tensor.data; @@ -332,8 +282,8 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* 
status) { total_bytes *= dims[i]; } TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( - ctx, device_name.c_str(), dtype, dims, num_dims, data, total_bytes, - &DeallocatorWrapperFunc, &dlmt, status); + ctx, device_name.value().c_str(), dtype, dims, num_dims, data, + total_bytes, &DeallocatorWrapperFunc, &dlmt, status); return handle; }; diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8384dfeadea..8a4f1788446 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -16,7 +16,6 @@ int_dtypes = [ float_dtypes = [np.float16, np.float32, np.float64] complex_dtypes = [np.complex64, np.complex128] dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16] -standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] testcase_shapes = [ @@ -55,13 +54,29 @@ class DLPackTest(parameterized.TestCase, test.TestCase): dlcapsule = to_dlpack(tf_tensor) del tf_tensor # should still work tf_tensor2 = from_dlpack(dlcapsule) - + def ConsumeDLPackTensor(): from_dlpack(dlcapsule) # Should can be consumed only once self.assertRaisesRegex(Exception, ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) + def testUnsupportedType(self): + def case1(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.qint16) + dlcapsule = to_dlpack(tf_tensor) + + def case2(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.complex64) + dlcapsule = to_dlpack(tf_tensor) + + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case1) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case2) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 0b801b4d51e..4e837c765c4 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1041,21 +1041,19 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); py::capsule capsule( dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) { - void* dlm_rptr = - PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); - if (dlm_rptr) { - tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); - PyCapsule_SetDestructor(capsule, nullptr); - } else { - // The tensor has been deleted. Clear any error from - // PyCapsule_GetPointer. 
- PyErr_Clear(); + if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) { + void* dlm_rptr = + PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); + if (dlm_rptr) { + tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); + PyCapsule_SetDestructor(capsule, nullptr); + } } }); - tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return capsule; }); @@ -1072,10 +1070,11 @@ PYBIND11_MODULE(_pywrap_tfe, m) { } TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); - tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); return py::handle(EagerTensorFromHandle(thandle)); }); From 3f218b1cd85d79e854641535ac7b2853c63ec289 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 19 Feb 2020 20:49:08 +0000 Subject: [PATCH 08/20] fix --- tensorflow/c/eager/dlpack.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index fa6fb77a2d7..f78df720511 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -125,7 +125,8 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - TFDLManagedTensorCtx* tf_dlm_tensor_ctx(new TFDLManagedTensorCtx); + auto* tf_dlm_tensor_ctx = new TFDLManagedTensorCtx; + TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() tf_dlm_tensor_ctx->handle = tensor_ref; From 295a8b30d51c2fedff57a2e2f6d6479f7fdbbf67 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 09:30:28 +0000 Subject: [PATCH 09/20] fix android and windows building --- tensorflow/c/eager/BUILD | 1 + tensorflow/c/eager/dlpack.cc | 2 ++ tensorflow/c/eager/dlpack.h | 6 +++--- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 566fef023ea..224b36a170c 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -360,6 +360,7 @@ filegroup( exclude = [ "c_api_experimental.cc", "*test*", + "*dlpack*", ], ), visibility = ["//visibility:public"], diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index f78df720511..95213948859 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -59,6 +59,8 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) { TFDLManagedTensorCtx* owner = static_cast(arg->manager_ctx); owner->handle->Unref(); + delete owner->handle; + delete owner->tensor.dl_tensor.shape; delete owner; } diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 43c205e30c4..35dfb682114 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -24,11 +24,11 @@ namespace tensorflow { const char* const kDlTensorCapsuleName = "dltensor"; -void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); -TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); -void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); +TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow From 
cae654ea9989ddeafa6a02fa9add9d6cf806f848 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 09:33:20 +0000 Subject: [PATCH 10/20] space --- tensorflow/c/eager/dlpack.cc | 2 +- tensorflow/python/dlpack/BUILD | 2 +- tensorflow/python/dlpack/dlpack.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 95213948859..8c4c70bf453 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -291,4 +291,4 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { return handle; }; -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow/python/dlpack/BUILD b/tensorflow/python/dlpack/BUILD index 3c890ec8b8f..c5347c020e5 100644 --- a/tensorflow/python/dlpack/BUILD +++ b/tensorflow/python/dlpack/BUILD @@ -19,4 +19,4 @@ cuda_py_test( "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", ] -) \ No newline at end of file +) diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 00be73f3670..5e364fdc593 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -22,4 +22,4 @@ def to_dlpack(tf_tensor): @tf_export("dlpack.from_dlpack") def from_dlpack(dlcapsule): - return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) \ No newline at end of file + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) From 6ef713fec6100c59d1c1f16c825a841ea5c34625 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 12:03:11 +0000 Subject: [PATCH 11/20] fix sanity check --- tensorflow/c/eager/BUILD | 4 +- tensorflow/python/dlpack/BUILD | 10 ++- tensorflow/python/dlpack/dlpack.py | 11 ++- tensorflow/python/dlpack/dlpack_test.py | 100 ++++++++++++++---------- tensorflow/python/eager/BUILD | 3 +- tensorflow/tools/pip_package/BUILD | 1 + 6 files changed, 79 insertions(+), 50 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 224b36a170c..509a6205274 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -92,8 +92,8 @@ filegroup( srcs = [ "c_api_experimental.h", "c_api_internal.h", - "tensor_handle_interface.h", "dlpack.h", + "tensor_handle_interface.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -327,7 +327,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) - cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -346,7 +345,6 @@ cc_library( ], ) - # TODO(karllessard): only used by //tensorflow/core:mobile_srcs_only_runtime # right now, remove this public rule when no longer needed (it should be # replaced by TF Lite) diff --git a/tensorflow/python/dlpack/BUILD b/tensorflow/python/dlpack/BUILD index c5347c020e5..4e1b3c47070 100644 --- a/tensorflow/python/dlpack/BUILD +++ b/tensorflow/python/dlpack/BUILD @@ -4,9 +4,11 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "dlpack", srcs = ["dlpack.py"], - deps = [ - ], srcs_version = "PY3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + ], ) cuda_py_test( @@ -14,9 +16,9 @@ cuda_py_test( srcs = ["dlpack_test.py"], python_version = "PY3", deps = [ - ":dlpack", + ":dlpack", "//tensorflow/python/eager:test", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", - ] + ], ) diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 5e364fdc593..601dffad847 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ 
b/tensorflow/python/dlpack/dlpack.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""DLPack modules for Tensorflow""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function from tensorflow.python import pywrap_tfe from tensorflow.python.util.tf_export import tf_export + @tf_export("dlpack.to_dlpack") def to_dlpack(tf_tensor): - return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) + @tf_export("dlpack.from_dlpack") def from_dlpack(dlcapsule): - return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8a4f1788446..8b47c71dc6b 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -1,3 +1,23 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for DLPack functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -29,55 +49,55 @@ testcase_shapes = [ def FormatShapeAndDtype(shape, dtype): - return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) + return "_{}[{}]".format(str(dtype), ",".join(map(str, shape))) class DLPackTest(parameterized.TestCase, test.TestCase): - @parameterized.named_parameters({ - "testcase_name": FormatShapeAndDtype(shape, dtype), - "dtype": dtype, - "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) - def testRoundTrip(self, dtype, shape): - np.random.seed(42) - np_array = np.random.randint(0, 10, shape) - tf_tensor = constant_op.constant(np_array, dtype=dtype) - dlcapsule = to_dlpack(tf_tensor) - del tf_tensor # should still work - tf_tensor2 = from_dlpack(dlcapsule) - self.assertAllClose(np_array, tf_tensor2) + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape} for dtype in dlpack_dtypes for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + np.random.seed(42) + np_array = np.random.randint(0, 10, shape) + tf_tensor = constant_op.constant(np_array, dtype=dtype) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) + self.assertAllClose(np_array, tf_tensor2) - def testTensorsCanBeConsumedOnceOnly(self): - np.random.seed(42) - np_array = np.random.randint(0, 10, (2, 3, 4)) - tf_tensor = constant_op.constant(np_array, dtype=np.float32) - dlcapsule = to_dlpack(tf_tensor) - del tf_tensor # 
should still work - tf_tensor2 = from_dlpack(dlcapsule) + def testTensorsCanBeConsumedOnceOnly(self): + np.random.seed(42) + np_array = np.random.randint(0, 10, (2, 3, 4)) + tf_tensor = constant_op.constant(np_array, dtype=np.float32) + dlcapsule = to_dlpack(tf_tensor) + del tf_tensor # should still work + tf_tensor2 = from_dlpack(dlcapsule) - def ConsumeDLPackTensor(): - from_dlpack(dlcapsule) # Should can be consumed only once - self.assertRaisesRegex(Exception, - ".*a DLPack tensor may be consumed at most once.*", - ConsumeDLPackTensor) + def ConsumeDLPackTensor(): + from_dlpack(dlcapsule) # Should can be consumed only once + self.assertRaisesRegex(Exception, + ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) - def testUnsupportedType(self): - def case1(): - tf_tensor = constant_op.constant( - [[1, 4], [5, 2]], dtype=dtypes.qint16) - dlcapsule = to_dlpack(tf_tensor) + def testUnsupportedType(self): + def case1(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.qint16) + dlcapsule = to_dlpack(tf_tensor) - def case2(): - tf_tensor = constant_op.constant( - [[1, 4], [5, 2]], dtype=dtypes.complex64) - dlcapsule = to_dlpack(tf_tensor) + def case2(): + tf_tensor = constant_op.constant( + [[1, 4], [5, 2]], dtype=dtypes.complex64) + dlcapsule = to_dlpack(tf_tensor) - self.assertRaisesRegex( - Exception, ".* is not supported by dlpack", case1) - self.assertRaisesRegex( - Exception, ".* is not supported by dlpack", case2) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case1) + self.assertRaisesRegex( + Exception, ".* is not supported by dlpack", case2) if __name__ == '__main__': - ops.enable_eager_execution() - test.main() + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 0a792bb2747..d6e1e07b6e2 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -35,8 +35,8 @@ cc_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_internal", - "//tensorflow/c/eager:tape", "//tensorflow/c/eager:dlpack", + "//tensorflow/c/eager:tape", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -94,6 +94,7 @@ py_library( ":test", ":wrap_function", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/dlpack", "//tensorflow/python/eager/memory_tests:memory_test_util", ], ) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index f6e17a6e46c..0e2ba08d1a7 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -159,6 +159,7 @@ filegroup( "@com_google_protobuf//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", + "@dlpack//:LICENSE", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", "@enum34_archive//:LICENSE", From 48a353cdaf5697b843aa37299a49201d7f4541e8 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 14:14:51 +0000 Subject: [PATCH 12/20] fix leak --- tensorflow/c/eager/dlpack.cc | 23 ++++++++++++----------- tensorflow/python/dlpack/dlpack.py | 1 + 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 8c4c70bf453..f982e483bbc 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -29,6 +29,8 @@ namespace { struct TFDLManagedTensorCtx { TensorReference* handle; + std::vector shape; + std::vector strides; DLManagedTensor 
tensor; }; @@ -60,7 +62,6 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) { static_cast(arg->manager_ctx); owner->handle->Unref(); delete owner->handle; - delete owner->tensor.dl_tensor.shape; delete owner; } @@ -141,14 +142,17 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, TFE_TensorHandleDevicePointer(h, status); tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); - int64_t* shape_arr = new int64_t[ndim]; + std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; + std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; + shape_arr->resize(ndim); + stride_arr->resize(ndim); for (int i = 0; i < ndim; i++) { - shape_arr[i] = tensor->dim_size(i); + (*shape_arr)[i] = tensor->dim_size(i); + (*stride_arr)[i] = 1; } - tf_dlm_tensor_ctx->tensor.dl_tensor.shape = shape_arr; - - tf_dlm_tensor_ctx->tensor.dl_tensor.strides = nullptr; + tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast(shape_arr->data()); + tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast(stride_arr->data()); tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; @@ -171,9 +175,6 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { - case 1: - tf_dtype = TF_DataType::TF_BOOL; - break; case 8: tf_dtype = TF_DataType::TF_UINT8; break; @@ -253,8 +254,8 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { DLManagedTensor* dlMTensor = static_cast(dlm_ptr); - if (dlMTensor) { - dlMTensor->deleter(const_cast(dlMTensor)); + if (dlMTensor->deleter != nullptr) { + dlMTensor->deleter(dlMTensor); } } diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 601dffad847..7a04fca3933 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -22,6 +22,7 @@ from tensorflow.python import pywrap_tfe from tensorflow.python.util.tf_export import tf_export +# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? 
@tf_export("dlpack.to_dlpack") def to_dlpack(tf_tensor): return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) From 59d8c5b6c0d91532b081f342ebfffd7fb464d5df Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Fri, 21 Feb 2020 15:22:48 +0000 Subject: [PATCH 13/20] fix --- tensorflow/c/eager/dlpack.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index f982e483bbc..fbe64499d3f 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -30,7 +30,6 @@ namespace { struct TFDLManagedTensorCtx { TensorReference* handle; std::vector shape; - std::vector strides; DLManagedTensor tensor; }; @@ -145,14 +144,14 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; shape_arr->resize(ndim); - stride_arr->resize(ndim); for (int i = 0; i < ndim; i++) { - (*shape_arr)[i] = tensor->dim_size(i); - (*stride_arr)[i] = 1; + (*shape_arr)[i] = tensor->dim_size(i); } - tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast(shape_arr->data()); - tf_dlm_tensor_ctx->tensor.dl_tensor.strides = reinterpret_cast(stride_arr->data()); + tf_dlm_tensor_ctx->tensor.dl_tensor.shape = + reinterpret_cast(shape_arr->data()); + tf_dlm_tensor_ctx->tensor.dl_tensor.strides = + nullptr; // NULL indicates tensor is compact and row-majored. tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; From 90a55447a7f4bd9c7c215f82ddb9a0d539bc1bee Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Sat, 22 Feb 2020 13:01:04 +0000 Subject: [PATCH 14/20] fix copyright --- tensorflow/c/eager/dlpack.cc | 27 ++++++++++++++------------- tensorflow/c/eager/dlpack.h | 4 ++-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index fbe64499d3f..ce36a5f3a10 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -27,10 +27,12 @@ namespace tensorflow { namespace { -struct TFDLManagedTensorCtx { - TensorReference* handle; +struct TfDlManagedTensorCtx { + TensorReference* reference; std::vector shape; DLManagedTensor tensor; + + TfDlManagedTensorCtx() }; const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { @@ -57,10 +59,10 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { }; void DLManagedTensorDeleter(DLManagedTensor* arg) { - TFDLManagedTensorCtx* owner = - static_cast(arg->manager_ctx); - owner->handle->Unref(); - delete owner->handle; + TfDlManagedTensorCtx* owner = + static_cast(arg->manager_ctx); + owner->reference->Unref(); + delete owner->reference; delete owner; } @@ -123,15 +125,15 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, +DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - auto* tf_dlm_tensor_ctx = new TFDLManagedTensorCtx; + auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx; TensorReference* tensor_ref = new TensorReference(*tensor); // This will call buf_->Ref() - tf_dlm_tensor_ctx->handle = tensor_ref; + tf_dlm_tensor_ctx->reference = tensor_ref; tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx; tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter; tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status); @@ -142,7 +144,6 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; - std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; shape_arr->resize(ndim); for (int i = 0; i < ndim; i++) { (*shape_arr)[i] = tensor->dim_size(i); @@ -151,7 +152,7 @@ DLManagedTensor* TFEHandleToTFDLManagedTensorCtx(TFE_TensorHandle* h, tf_dlm_tensor_ctx->tensor.dl_tensor.shape = reinterpret_cast(shape_arr->data()); tf_dlm_tensor_ctx->tensor.dl_tensor.strides = - nullptr; // NULL indicates tensor is compact and row-majored. + nullptr; // nullptr indicates tensor is compact and row-majored. tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; @@ -259,7 +260,7 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { - DLManagedTensor* tfdlmtensor = TFEHandleToTFDLManagedTensorCtx(h, status); + DLManagedTensor* tfdlmtensor = TFEHandleToTfDlManagedTensorCtx(h, status); return static_cast(tfdlmtensor); } diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 35dfb682114..b563bc24495 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ limitations under the License. 
namespace tensorflow { -const char* const kDlTensorCapsuleName = "dltensor"; +TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor"; TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); From 7aa5009b359a0704ec23021187a687d2476361e5 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Sat, 22 Feb 2020 16:28:58 +0000 Subject: [PATCH 15/20] fix --- tensorflow/c/eager/dlpack.cc | 82 +++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index ce36a5f3a10..fdc439da4b5 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -27,12 +27,17 @@ namespace tensorflow { namespace { +// Managing context for the DLManagedTensor, will manage the lifetime of +// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the +// original framework of destruction, and this context will be deleted also. struct TfDlManagedTensorCtx { - TensorReference* reference; + TensorReference reference; std::vector shape; + std::vector strides; DLManagedTensor tensor; - TfDlManagedTensorCtx() + TfDlManagedTensorCtx(const TensorReference& ref) + : reference(ref), shape(), tensor() {} }; const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { @@ -61,8 +66,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { void DLManagedTensorDeleter(DLManagedTensor* arg) { TfDlManagedTensorCtx* owner = static_cast(arg->manager_ctx); - owner->reference->Unref(); - delete owner->reference; + owner->reference.Unref(); delete owner; } @@ -129,31 +133,41 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, TF_Status* status) { const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast(tensor->dtype()); - auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx; + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() - TensorReference* tensor_ref = - new TensorReference(*tensor); // This will call buf_->Ref() + auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); tf_dlm_tensor_ctx->reference = tensor_ref; - tf_dlm_tensor_ctx->tensor.manager_ctx = tf_dlm_tensor_ctx; - tf_dlm_tensor_ctx->tensor.deleter = &DLManagedTensorDeleter; - tf_dlm_tensor_ctx->tensor.dl_tensor.ctx = GetDLContext(h, status); + + DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; + dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; + dlm_tensor->deleter = &DLManagedTensorDeleter; + dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); int ndim = tensor->dims(); - tf_dlm_tensor_ctx->tensor.dl_tensor.ndim = ndim; - tf_dlm_tensor_ctx->tensor.dl_tensor.data = - TFE_TensorHandleDevicePointer(h, status); - tf_dlm_tensor_ctx->tensor.dl_tensor.dtype = GetDLDataType(data_type, status); + dlm_tensor->dl_tensor.ndim = ndim; + dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); + dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; + std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; shape_arr->resize(ndim); + stride_arr->resize(ndim, 1); for (int i = 0; i < ndim; i++) { (*shape_arr)[i] = tensor->dim_size(i); } + for (int i = ndim - 2; i >= 0; --i) { + (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; + } - tf_dlm_tensor_ctx->tensor.dl_tensor.shape = + dlm_tensor->dl_tensor.shape = reinterpret_cast(shape_arr->data()); - tf_dlm_tensor_ctx->tensor.dl_tensor.strides = - nullptr; // nullptr 
indicates tensor is compact and row-majored. - tf_dlm_tensor_ctx->tensor.dl_tensor.byte_offset = + // There are two ways to represent compact row-major data + // 1) nullptr indicates tensor is compact and row-majored. + // 2) fill in the strides array as the real case for compact row-major data + // Here we choose option 2, since some framework didn't handle the strides + // argument properly + dlm_tensor->dl_tensor.strides = + reinterpret_cast(stride_arr->data()); + dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; } @@ -250,6 +264,15 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { dlmt->deleter(const_cast(dlmt)); } +bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, + int ndim) { + for (int i = ndim - 2; i >= 0; --i) { + if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { + return false; + }; + } + return true; +} } // namespace void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { @@ -268,23 +291,32 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); DLManagedTensor* dlmt = static_cast(dlm); - + DLTensor* dl_tensor = &dlmt->dl_tensor; absl::optional device_name = - DeviceNameFromDlContext(dlmt->dl_tensor.ctx, status); + DeviceNameFromDlContext(dl_tensor->ctx, status); if (!device_name.has_value()) { status->status = tensorflow::errors::InvalidArgument("Unsupported Device Type"); return nullptr; } - TF_DataType dtype = TfDataTypeFormDlDataType(dlmt->dl_tensor.dtype, status); - int num_dims = dlmt->dl_tensor.ndim; - const int64_t* dims = dlmt->dl_tensor.shape; - void* data = dlmt->dl_tensor.data; + TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status); + int num_dims = dl_tensor->ndim; + const int64_t* dims = dl_tensor->shape; + void* data = dl_tensor->data; - size_t total_bytes = dlmt->dl_tensor.dtype.bits / 8; + size_t total_bytes = dl_tensor->dtype.bits / 8; for (int i = 0; i < num_dims; i++) { total_bytes *= dims[i]; } + + if ((dl_tensor->strides != nullptr) && + !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, + num_dims)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid strides array from DLPack"); + return nullptr; + } + TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( ctx, device_name.value().c_str(), dtype, dims, num_dims, data, total_bytes, &DeallocatorWrapperFunc, &dlmt, status); From 61da5aaff3852fc4a114cb5325a1feba8d87432e Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Mon, 24 Feb 2020 09:06:29 +0000 Subject: [PATCH 16/20] fix --- tensorflow/c/eager/dlpack.cc | 11 ++++++----- tensorflow/c/eager/dlpack.h | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index fdc439da4b5..794550e840a 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -158,15 +158,13 @@ DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; } - dlm_tensor->dl_tensor.shape = - reinterpret_cast(shape_arr->data()); + dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. 
// 2) fill in the strides array as the real case for compact row-major data // Here we choose option 2, since some framework didn't handle the strides // argument properly - dlm_tensor->dl_tensor.strides = - reinterpret_cast(stride_arr->data()); + dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return &tf_dlm_tensor_ctx->tensor; @@ -266,6 +264,9 @@ void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, int ndim) { + if (ndim >= 1 && stride_arr[ndim - 1] != 1) { + return false; + } for (int i = ndim - 2; i >= 0; --i) { if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { return false; @@ -309,7 +310,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { total_bytes *= dims[i]; } - if ((dl_tensor->strides != nullptr) && + if (dl_tensor->strides != nullptr && !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, num_dims)) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index b563bc24495..f656f4393f6 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -TF_CAPI_EXPORT extern const char* const kDlTensorCapsuleName = "dltensor"; +const char* const kDlTensorCapsuleName = "dltensor"; TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); From 7c3ac77ee1582f0aa8546ecb74de2b34235a3e3c Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Mon, 24 Feb 2020 11:21:23 +0000 Subject: [PATCH 17/20] fix --- tensorflow/c/eager/dlpack.cc | 90 +++++++++++++++++++----------------- tensorflow/c/eager/dlpack.h | 13 ++++-- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 794550e840a..0ec8321230c 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -40,6 +40,7 @@ struct TfDlManagedTensorCtx { : reference(ref), shape(), tensor() {} }; +// Get tensor from eager tensor handle const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( @@ -63,6 +64,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { return tensor; }; +// Deleter for DLManagedTensor void DLManagedTensorDeleter(DLManagedTensor* arg) { TfDlManagedTensorCtx* owner = static_cast(arg->manager_ctx); @@ -129,47 +131,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -DLManagedTensor* TFEHandleToTfDlManagedTensorCtx(TFE_TensorHandle* h, - TF_Status* status) { - const Tensor* tensor = GetTensorFromHandle(h, status); - TF_DataType data_type = static_cast(tensor->dtype()); - TensorReference tensor_ref(*tensor); // This will call buf_->Ref() - - auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); - tf_dlm_tensor_ctx->reference = tensor_ref; - - DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; - dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; - dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); - int ndim = tensor->dims(); - dlm_tensor->dl_tensor.ndim = ndim; - dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = 
GetDLDataType(data_type, status); - - std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; - std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; - shape_arr->resize(ndim); - stride_arr->resize(ndim, 1); - for (int i = 0; i < ndim; i++) { - (*shape_arr)[i] = tensor->dim_size(i); - } - for (int i = ndim - 2; i >= 0; --i) { - (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; - } - - dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; - // There are two ways to represent compact row-major data - // 1) nullptr indicates tensor is compact and row-majored. - // 2) fill in the strides array as the real case for compact row-major data - // Here we choose option 2, since some framework didn't handle the strides - // argument properly - dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; - dlm_tensor->dl_tensor.byte_offset = - 0; // TF doesn't handle the strides and byte_offsets here - return &tf_dlm_tensor_ctx->tensor; -} - +// Convert DLContext to TF device name absl::optional DeviceNameFromDlContext(const DLContext& ctx, TF_Status* status) { switch (ctx.device_type) { @@ -181,6 +143,8 @@ absl::optional DeviceNameFromDlContext(const DLContext& ctx, return absl::nullopt; }; } + +// Convert DLPack data type to TF_DATATYPE TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, TF_Status* status) { TF_DataType tf_dtype; @@ -257,11 +221,16 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, return tf_dtype; } +// Wrapper function to match the function signature +// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the +// DLManagedTensor void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { DLManagedTensor* dlmt = static_cast(dlmt_vptr); dlmt->deleter(const_cast(dlmt)); } +// Check whether the stride array matches the layout of compact, row-majored +// data bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, int ndim) { if (ndim >= 1 && stride_arr[ndim - 1] != 1) { @@ -284,8 +253,43 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { - DLManagedTensor* tfdlmtensor = TFEHandleToTfDlManagedTensorCtx(h, status); - return static_cast(tfdlmtensor); + const Tensor* tensor = GetTensorFromHandle(h, status); + TF_DataType data_type = static_cast(tensor->dtype()); + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() + + auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); + tf_dlm_tensor_ctx->reference = tensor_ref; + + DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; + dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; + dlm_tensor->deleter = &DLManagedTensorDeleter; + dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); + int ndim = tensor->dims(); + dlm_tensor->dl_tensor.ndim = ndim; + dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); + dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); + + std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; + std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; + shape_arr->resize(ndim); + stride_arr->resize(ndim, 1); + for (int i = 0; i < ndim; i++) { + (*shape_arr)[i] = tensor->dim_size(i); + } + for (int i = ndim - 2; i >= 0; --i) { + (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; + } + + dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; + // There are two ways to represent compact row-major data + // 1) nullptr indicates tensor is compact and row-majored. 
+ // 2) fill in the strides array as the real case for compact row-major data + // Here we choose option 2, since some framework didn't handle the strides + // argument properly + dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; + dlm_tensor->dl_tensor.byte_offset = + 0; // TF doesn't handle the strides and byte_offsets here + return static_cast(dlm_tensor); } TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index f656f4393f6..21ee37b78d8 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef TENSORFLOW_C_DLPACK_H_ #define TENSORFLOW_C_DLPACK_H_ @@ -22,14 +21,20 @@ limitations under the License. namespace tensorflow { +// PyCapsule name for DLPack Tensor const char* const kDlTensorCapsuleName = "dltensor"; -TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); +// Convert eager tensor handle to DLPack (DLManagedTensor*), and return the +// void* for further PyCapsule construction +TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, + TF_Status* status); -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); +// Convert DLPack (DLManagedTensor*) to eager tensor handle +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, + TF_Status* status); +// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow - #endif // TENSORFLOW_C_DLPACK_H_ From b90808b7b4568a7a58992248a8247405667c5c36 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Tue, 25 Feb 2020 16:45:19 +0000 Subject: [PATCH 18/20] fix --- tensorflow/c/eager/dlpack.cc | 38 ++++++++++++++++++------------------ tensorflow/c/eager/dlpack.h | 8 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 0ec8321230c..6276371bd68 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -36,11 +36,10 @@ struct TfDlManagedTensorCtx { std::vector strides; DLManagedTensor tensor; - TfDlManagedTensorCtx(const TensorReference& ref) - : reference(ref), shape(), tensor() {} + TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {} }; -// Get tensor from eager tensor handle +// Gets tensor from eager tensor handle. const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || !h->handle->IsValid(&status->status)) { status->status = tensorflow::errors::InvalidArgument( @@ -72,7 +71,8 @@ void DLManagedTensorDeleter(DLManagedTensor* arg) { delete owner; } -DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { +// Converts TF_DATAType to DLPack data type. +DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { DLDataType dtype; dtype.lanes = 1; dtype.bits = TF_DataTypeSize(data_type) * 8; @@ -107,16 +107,17 @@ DLDataType GetDLDataType(TF_DataType data_type, TF_Status* status) { return dtype; } -DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { +// Gets DLPack's DLContext from eager tensor handle. 
+DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; const char* device_name = h->handle->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; - int device_id = -1; + int device_id = 0; if (parsed_name.has_id) { device_id = parsed_name.id; - } // Question: Is it possible that it doens't have id? + } ctx.device_id = device_id; if (device_type == "CPU") { @@ -131,7 +132,7 @@ DLContext GetDLContext(TFE_TensorHandle* h, TF_Status* status) { return ctx; } -// Convert DLContext to TF device name +// Converts DLContext to TF device name. absl::optional DeviceNameFromDlContext(const DLContext& ctx, TF_Status* status) { switch (ctx.device_type) { @@ -144,7 +145,7 @@ absl::optional DeviceNameFromDlContext(const DLContext& ctx, }; } -// Convert DLPack data type to TF_DATATYPE +// Converts DLPack data type to TF_DATATYPE. TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, TF_Status* status) { TF_DataType tf_dtype; @@ -221,16 +222,15 @@ TF_DataType TfDataTypeFormDlDataType(const DLDataType& dtype, return tf_dtype; } -// Wrapper function to match the function signature -// TFE_NewTensorHandleFromDeviceMemory, calling the deleter of the -// DLManagedTensor +// Wraps the deleter function of DLManagedTensor to match the function signature +// TFE_NewTensorHandleFromDeviceMemory. void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { DLManagedTensor* dlmt = static_cast(dlmt_vptr); dlmt->deleter(const_cast(dlmt)); } -// Check whether the stride array matches the layout of compact, row-majored -// data +// Checks whether the stride array matches the layout of compact, row-majored +// data. bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, int ndim) { if (ndim >= 1 && stride_arr[ndim - 1] != 1) { @@ -263,11 +263,11 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDLContext(h, status); + dlm_tensor->dl_tensor.ctx = GetDlContext(h, status); int ndim = tensor->dims(); dlm_tensor->dl_tensor.ndim = ndim; dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = GetDLDataType(data_type, status); + dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status); std::vector* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector* stride_arr = &tf_dlm_tensor_ctx->strides; @@ -283,9 +283,9 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. - // 2) fill in the strides array as the real case for compact row-major data - // Here we choose option 2, since some framework didn't handle the strides - // argument properly + // 2) fill in the strides array as the real case for compact row-major data. + // Here we choose option 2, since some frameworks didn't handle the strides + // argument properly. 
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 21ee37b78d8..cf83b79b573 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -24,16 +24,16 @@ namespace tensorflow { // PyCapsule name for DLPack Tensor const char* const kDlTensorCapsuleName = "dltensor"; -// Convert eager tensor handle to DLPack (DLManagedTensor*), and return the -// void* for further PyCapsule construction +// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the +// void* for further PyCapsule construction. TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status); -// Convert DLPack (DLManagedTensor*) to eager tensor handle +// Converts DLPack (DLManagedTensor*) to eager tensor handle. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status); -// Call the destructor of DLManagedTensor, used in the destructor of PyCapsule +// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); } // namespace tensorflow From 29856e63547209acf37ababdf9ba03c26c4b13f4 Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 26 Feb 2020 14:48:30 +0000 Subject: [PATCH 19/20] fix python sympol export --- tensorflow/c/eager/dlpack.cc | 3 +++ tensorflow/python/BUILD | 1 + tensorflow/python/__init__.py | 3 +++ tensorflow/python/dlpack/dlpack.py | 6 +++--- tensorflow/python/dlpack/dlpack_test.py | 4 ++-- tensorflow/python/tools/api/generator/api_init_files.bzl | 1 + 6 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 6276371bd68..bb898ac1fed 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -305,6 +305,9 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { return nullptr; } TF_DataType dtype = TfDataTypeFormDlDataType(dl_tensor->dtype, status); + if (!status->status.ok()) { + return nullptr; + } int num_dims = dl_tensor->ndim; const int64_t* dims = dl_tensor->shape; void* data = dl_tensor->data; diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 63593f1a428..70ae3aa96f1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -190,6 +190,7 @@ py_library( "//tensorflow/python/distribute:estimator_training", "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/dlpack", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:monitoring", "//tensorflow/python/eager:profiler", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 6d88cb566ae..c5a4207b476 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -170,6 +170,9 @@ from tensorflow.python.debug.lib import check_numerics_callback from tensorflow.python.debug.lib import dumping_callback from tensorflow.python.ops import gen_debug_ops +# DLPack +from tensorflow.python.dlpack.dlpack import from_dlpack, to_dlpack + # XLA JIT compiler APIs. 
from tensorflow.python.compiler.xla import jit from tensorflow.python.compiler.xla import xla diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 7a04fca3933..5b278db36ba 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -1,4 +1,4 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,11 +23,11 @@ from tensorflow.python.util.tf_export import tf_export # tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? -@tf_export("dlpack.to_dlpack") +@tf_export("experimental.dlpack.to_dlpack", v1=[]) def to_dlpack(tf_tensor): return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor) -@tf_export("dlpack.from_dlpack") +@tf_export("experimental.dlpack.from_dlpack", v1=[]) def from_dlpack(dlcapsule): return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index 8b47c71dc6b..206c2b7d926 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,7 +81,7 @@ class DLPackTest(parameterized.TestCase, test.TestCase): ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) - def testUnsupportedType(self): + def testUnsupportedTypeToDLPack(self): def case1(): tf_tensor = constant_op.constant( [[1, 4], [5, 2]], dtype=dtypes.qint16) diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 3aab59e50aa..99981a5ce2e 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [ "errors/__init__.py", "experimental/__init__.py", "experimental/tensorrt/__init__.py", + "experimental/dlpack/__init__.py", "feature_column/__init__.py", "io/gfile/__init__.py", "graph_util/__init__.py", From dab3c26d038c980cc54a5754a6b69ce508e7d89b Mon Sep 17 00:00:00 2001 From: VoVAllen Date: Wed, 26 Feb 2020 15:16:27 +0000 Subject: [PATCH 20/20] add comment --- tensorflow/python/dlpack/dlpack.py | 31 +++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 5b278db36ba..360b1651ca8 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -22,12 +22,41 @@ from tensorflow.python import pywrap_tfe from tensorflow.python.util.tf_export import tf_export -# tf.dlpack.to_dlpack/from_dlpack doesn't work. How to fix? @tf_export("experimental.dlpack.to_dlpack", v1=[]) def to_dlpack(tf_tensor): + """Returns the dlpack capsule representing the tensor. This operation + ensures the underlying data memory is ready when returns. + + ```python + a = tf.tensor([1, 10]) + dlcapsule = tf.experimental.dlpack.to_dlpack(a) + # dlcapsule represents the dlpack data structure + ``` + + Args: + tf_tensor: Tensorflow eager tensor, to be converted to dlpack capsule. 
+
+  Returns:
+    A PyCapsule named `dltensor`, which shares the underlying memory with
+    other frameworks. This PyCapsule can be consumed at most once.
+  """
   return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)
 
 
 @tf_export("experimental.dlpack.from_dlpack", v1=[])
 def from_dlpack(dlcapsule):
+  """Returns a TensorFlow eager tensor that shares the memory backing the
+  DLPack capsule produced by another framework.
+
+  ```python
+  a = tf.experimental.dlpack.from_dlpack(dlcapsule)
+  # `a` uses the memory shared by the dlpack capsule
+  ```
+
+  Args:
+    dlcapsule: A PyCapsule named `dltensor`.
+
+  Returns:
+    A TensorFlow eager tensor.
+  """
   return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
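For reference, a minimal round-trip sketch of the Python API added by this series. This is illustrative only; it assumes a TensorFlow build with these patches applied and eager execution enabled:

```python
import tensorflow as tf

# Convert an eager tensor to a DLPack capsule and back again.
x = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32)
capsule = tf.experimental.dlpack.to_dlpack(x)    # PyCapsule named "dltensor"
y = tf.experimental.dlpack.from_dlpack(capsule)  # shares the underlying memory
print(y.numpy())
```

Once consumed, the capsule is renamed to "used_dltensor", so passing it to `from_dlpack` a second time raises an error, as exercised by `testTensorsCanBeConsumedOnceOnly` above.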