Call InitializeNumpyAPIForType() to make bfloat16 happy.
PiperOrigin-RevId: 296313339 Change-Id: If4f61d7de9b94860107f91e39fedc05d2807538d
This commit is contained in:
parent
0c4f32c44e
commit
e722dc53dc
@ -25,6 +25,11 @@ namespace xla {
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
// Initializes the NumPy API for the use of the types module.
|
||||
if (!InitializeNumpyAPIForTypes()) {
|
||||
throw std::runtime_error("Unable to initialize Numpy API");
|
||||
}
|
||||
|
||||
py::class_<PyTpuClient, std::shared_ptr<PyTpuClient>>(m, "TpuClient")
|
||||
.def_static("Get", &PyTpuClient::Get, py::arg("worker"))
|
||||
.def("device_count", &PyTpuClient::device_count)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user