Cherry pick dlpack fix into r2.3
This commit is contained in:
parent
99fea8da0d
commit
8cdfc53a63
|
@ -324,7 +324,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
|
||||||
|
|
||||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
total_bytes, &DeallocatorWrapperFunc, dlmt, status);
|
||||||
|
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1169,7 +1169,13 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||||
|
|
||||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||||
|
<<<<<<< HEAD
|
||||||
return py::handle(EagerTensorFromHandle(thandle));
|
return py::handle(EagerTensorFromHandle(thandle));
|
||||||
|
=======
|
||||||
|
|
||||||
|
PyObject* pyhandle = EagerTensorFromHandle(thandle);
|
||||||
|
return tensorflow::PyoOrThrow(pyhandle);
|
||||||
|
>>>>>>> ce9b1295b5... fix
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
|
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
|
||||||
|
|
Loading…
Reference in New Issue