Cherry pick dlpack fix into r2.3

This commit is contained in:
VoVAllen 2020-06-26 17:15:28 +00:00
parent 99fea8da0d
commit 8cdfc53a63
2 changed files with 7 additions and 1 deletions

View File

@ -324,7 +324,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data, ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status); total_bytes, &DeallocatorWrapperFunc, dlmt, status);
return handle; return handle;
} }

View File

@ -1169,7 +1169,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
<<<<<<< HEAD
return py::handle(EagerTensorFromHandle(thandle)); return py::handle(EagerTensorFromHandle(thandle));
=======
PyObject* pyhandle = EagerTensorFromHandle(thandle);
return tensorflow::PyoOrThrow(pyhandle);
>>>>>>> ce9b1295b5... fix
}); });
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context, m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,