diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index dc3b25c47cb..3ad7b1bea7b 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -289,7 +289,7 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { return static_cast(dlm_tensor); } -TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { +TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, TFE_Context* ctx) { DLManagedTensor* dlmt = static_cast(dlm); DLTensor* dl_tensor = &dlmt->dl_tensor; absl::optional device_name = @@ -322,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { return nullptr; } - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - TFE_DeleteContextOptions(opts); - TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( ctx, device_name.value().c_str(), dtype, dims, num_dims, data, total_bytes, &DeallocatorWrapperFunc, &dlmt, status); - TFE_DeleteContext(ctx); - return handle; } diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 4177af1a6e7..8c85dee62f7 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, // Converts DLPack (DLManagedTensor*) to eager tensor handle. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, - TF_Status* status); + TF_Status* status, + TFE_Context* ctx); // Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); diff --git a/tensorflow/python/dlpack/dlpack.py b/tensorflow/python/dlpack/dlpack.py index 47bf5b35a8b..8bad390ef21 100644 --- a/tensorflow/python/dlpack/dlpack.py +++ b/tensorflow/python/dlpack/dlpack.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tfe +from tensorflow.python.eager import context from tensorflow.python.util.tf_export import tf_export @@ -62,4 +63,4 @@ def from_dlpack(dlcapsule): Returns: A Tensorflow eager tensor """ - return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) + return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 1823858dcab..ee43e6e1d43 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -1074,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) { return capsule; }); - m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) { + m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule, + const py::handle& context) { tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); if (absl::string_view(pycapsule.name()) != @@ -1085,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { absl::string_view(pycapsule.name())); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); } - TFE_TensorHandle* thandle = - tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); + + TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( + pycapsule, status.get(), tensorflow::InputTFE_Context(context)); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());