This commit is contained in:
VoVAllen 2020-03-28 08:41:24 +00:00
parent 5ad4ed80bb
commit 8fa6423567
4 changed files with 10 additions and 12 deletions

View File

@ -289,7 +289,7 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, TFE_Context* ctx) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
@ -322,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
return nullptr;
}
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
TFE_DeleteContext(ctx);
return handle;
}

View File

@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status);
TF_Status* status,
TFE_Context* ctx);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import tf_export
@ -62,4 +63,4 @@ def from_dlpack(dlcapsule):
Returns:
A Tensorflow eager tensor
"""
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle)

View File

@ -1074,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return capsule;
});
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
const py::handle& context) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (absl::string_view(pycapsule.name()) !=
@ -1085,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
absl::string_view(pycapsule.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
TFE_TensorHandle* thandle =
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
pycapsule, status.get(), tensorflow::InputTFE_Context(context));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());