This commit is contained in:
VoVAllen 2020-03-28 08:41:24 +00:00
parent 5ad4ed80bb
commit 8fa6423567
4 changed files with 10 additions and 12 deletions

View File

@ -289,7 +289,7 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
return static_cast<void*>(dlm_tensor); return static_cast<void*>(dlm_tensor);
} }
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) { TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, TFE_Context* ctx) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm); DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor; DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name = absl::optional<std::string> device_name =
@ -322,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
return nullptr; return nullptr;
} }
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data, ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status); total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
TFE_DeleteContext(ctx);
return handle; return handle;
} }

View File

@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
// Converts DLPack (DLManagedTensor*) to eager tensor handle. // Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status); TF_Status* status,
TFE_Context* ctx);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. // Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import pywrap_tfe from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -62,4 +63,4 @@ def from_dlpack(dlcapsule):
Returns: Returns:
A Tensorflow eager tensor A Tensorflow eager tensor
""" """
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule) return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle)

View File

@ -1074,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return capsule; return capsule;
}); });
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) { m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
const py::handle& context) {
tensorflow::Safe_TF_StatusPtr status = tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus()); tensorflow::make_safe(TF_NewStatus());
if (absl::string_view(pycapsule.name()) != if (absl::string_view(pycapsule.name()) !=
@ -1085,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
absl::string_view(pycapsule.name())); absl::string_view(pycapsule.name()));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
} }
TFE_TensorHandle* thandle =
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get()); TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
pycapsule, status.get(), tensorflow::InputTFE_Context(context));
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());