fix
This commit is contained in:
parent
5ad4ed80bb
commit
8fa6423567
|
@ -289,7 +289,7 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||||
return static_cast<void*>(dlm_tensor);
|
return static_cast<void*>(dlm_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, TFE_Context* ctx) {
|
||||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||||
absl::optional<std::string> device_name =
|
absl::optional<std::string> device_name =
|
||||||
|
@ -322,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
|
|
||||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||||
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
|
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||||
|
|
||||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||||
TF_Status* status);
|
TF_Status* status,
|
||||||
|
TFE_Context* ctx);
|
||||||
|
|
||||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tfe
|
from tensorflow.python import pywrap_tfe
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,4 +63,4 @@ def from_dlpack(dlcapsule):
|
||||||
Returns:
|
Returns:
|
||||||
A Tensorflow eager tensor
|
A Tensorflow eager tensor
|
||||||
"""
|
"""
|
||||||
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule)
|
return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, context.context()._handle)
|
||||||
|
|
|
@ -1074,7 +1074,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||||
return capsule;
|
return capsule;
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule) {
|
m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
|
||||||
|
const py::handle& context) {
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
if (absl::string_view(pycapsule.name()) !=
|
if (absl::string_view(pycapsule.name()) !=
|
||||||
|
@ -1085,8 +1086,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||||
absl::string_view(pycapsule.name()));
|
absl::string_view(pycapsule.name()));
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
}
|
}
|
||||||
TFE_TensorHandle* thandle =
|
|
||||||
tensorflow::TFE_HandleFromDLPack(pycapsule, status.get());
|
TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
|
||||||
|
pycapsule, status.get(), tensorflow::InputTFE_Context(context));
|
||||||
|
|
||||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue