diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 45048bd6efb..f6d6ee0710a 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -248,21 +248,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { } void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { + auto tf_dlm_context = GetDlContext(h, status); + if (!status->status.ok()) { + return nullptr; + } + + auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status); + if (!status->status.ok()) { + return nullptr; + } + const Tensor* tensor = GetTensorFromHandle(h, status); TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype()); - TensorReference tensor_ref(*tensor); // This will call buf_->Ref() + auto tf_dlm_type = GetDlDataType(data_type, status); + if (!status->status.ok()) { + return nullptr; + } + + TensorReference tensor_ref(*tensor); // This will call buf_->Ref() auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref); tf_dlm_tensor_ctx->reference = tensor_ref; DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; dlm_tensor->deleter = &DLManagedTensorDeleter; - dlm_tensor->dl_tensor.ctx = GetDlContext(h, status); + dlm_tensor->dl_tensor.ctx = tf_dlm_context; int ndim = tensor->dims(); dlm_tensor->dl_tensor.ndim = ndim; - dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status); - dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status); + dlm_tensor->dl_tensor.data = tf_dlm_data; + dlm_tensor->dl_tensor.dtype = tf_dlm_type; std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape; std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides; @@ -275,13 +290,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; } - dlm_tensor->dl_tensor.shape = &(*shape_arr)[0]; + dlm_tensor->dl_tensor.shape = shape_arr->data(); // There are two ways to represent compact row-major data // 1) nullptr indicates tensor is compact and row-majored. // 2) fill in the strides array as the real case for compact row-major data. // Here we choose option 2, since some frameworks didn't handle the strides // argument properly. - dlm_tensor->dl_tensor.strides = &(*stride_arr)[0]; + dlm_tensor->dl_tensor.strides = stride_arr->data(); + dlm_tensor->dl_tensor.byte_offset = 0; // TF doesn't handle the strides and byte_offsets here return static_cast<void*>(dlm_tensor); diff --git a/tensorflow/python/dlpack/dlpack_test.py b/tensorflow/python/dlpack/dlpack_test.py index af91da80512..df53220849c 100644 --- a/tensorflow/python/dlpack/dlpack_test.py +++ b/tensorflow/python/dlpack/dlpack_test.py @@ -20,9 +20,11 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np + from tensorflow.python.dlpack import dlpack from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase): self.assertRaisesRegex(Exception, ".* is not supported by dlpack", UnsupportedComplex64) + def testMustPassTensorArgumentToDLPack(self): + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "The argument to `to_dlpack` must be a TF tensor, not Python object"): + dlpack.to_dlpack([1]) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 88bb66f189b..3401020ae99 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1129,9 +1129,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // DLPack functions m.def("TFE_ToDlpackCapsule", [](py::handle& o) { PyObject* eager_tensor_pyobject_ptr = o.ptr(); - TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + + if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) { + status->status = tensorflow::errors::InvalidArgument( + "The argument to `to_dlpack` must be a TF tensor, not Python object"); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + } + + TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());