Call InitializeNumpyAPIForType() to make bfloat16 happy.

PiperOrigin-RevId: 296313339
Change-Id: If4f61d7de9b94860107f91e39fedc05d2807538d
This commit is contained in:
Henry Tan 2020-02-20 15:37:49 -08:00 committed by TensorFlower Gardener
parent 0c4f32c44e
commit e722dc53dc

View File

@ -25,6 +25,11 @@ namespace xla {
namespace py = pybind11;
PYBIND11_MODULE(tpu_client_extension, m) {
// Initializes the NumPy API for the use of the types module.
if (!InitializeNumpyAPIForTypes()) {
throw std::runtime_error("Unable to initialize Numpy API");
}
py::class_<PyTpuClient, std::shared_ptr<PyTpuClient>>(m, "TpuClient")
.def_static("Get", &PyTpuClient::Get, py::arg("worker"))
.def("device_count", &PyTpuClient::device_count)