fix
This commit is contained in:
parent
5ad4ed80bb
commit
8fa6423567
|
@ -289,7 +289,7 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
|||
return static_cast<void*>(dlm_tensor);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, TFE_Context* ctx) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||
absl::optional<std::string> device_name =
|
||||
|
@ -322,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
|||
|
||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||
TF_Status* status);
|
||||
TF_Status* status,
|
||||
TFE_Context* ctx);
|
||||
|
||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
|
@ -62,4 +63,4 @@ def from_dlpack(dlcapsule):
|
|||
Returns:
|
||||
A Tensorflow eager tensor
|
||||
"""
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle)
|
||||
|
|
|
@ -1074,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||
return capsule;
|
||||
});
|
||||
|
||||
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
|
||||
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
|
||||
const py::handle& context) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
if (absl::string_view(pycapsule.name()) !=
|
||||
|
@ -1085,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||
absl::string_view(pycapsule.name()));
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
}
|
||||
TFE_TensorHandle* thandle =
|
||||
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
||||
|
||||
TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
|
||||
pycapsule, status.get(), tensorflow::InputTFE_Context(context));
|
||||
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
|
||||
|
|
Loading…
Reference in New Issue