Merge pull request #41017 from VoVAllen/fix_dlpack_r2.3
Cherry-pick dlpack fix #40843 into r2.3
This commit is contained in:
commit
fc9d68a5e8
|
@ -221,8 +221,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype,
|
|||
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||
// TFE_NewTensorHandleFromDeviceMemory.
|
||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
||||
TFE_CallDLManagedTensorDeleter(dlmt_vptr);
|
||||
}
|
||||
|
||||
// Checks whether the stride array matches the layout of compact, row-majored
|
||||
|
@ -324,7 +323,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
|
|||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
total_bytes, &DeallocatorWrapperFunc, dlmt, status);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
|
|
@ -1169,7 +1169,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||
|
||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||
return py::handle(EagerTensorFromHandle(thandle));
|
||||
|
||||
PyObject* pyhandle = EagerTensorFromHandle(thandle);
|
||||
return tensorflow::PyoOrThrow(pyhandle);
|
||||
});
|
||||
|
||||
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
|
||||
|
|
Loading…
Reference in New Issue