[tpu_driver] Fix segfault when calling TpuDevice.platform

PiperOrigin-RevId: 338787622
Change-Id: I2a0871f5954fa46132a6032e470469e7ac2ab945
This commit is contained in:
Skye Wanderman-Milne 2020-10-23 20:04:31 -07:00 committed by TensorFlower Gardener
parent 62bc872eef
commit 29f1aa1beb
2 changed files with 15 additions and 0 deletions

View File

@ -207,6 +207,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice") py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords) .def_property_readonly("coords", &TpuDevice::coords)
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip) .def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
// TODO(skye): this is a horrible hack because falling back to
// PjRtDevice::platform_name() segfaults, due to TpuDevice::client_ being
// uninitialized. This can be removed when PyTpuClient subclasses
// PjRtClient and can be used to set TpuDevice::client_.
.def_property_readonly(
"platform",
[](const TpuDevice& device) -> std::string { return kTpuPlatform; })
.def("__repr__", [](const TpuDevice& device) { .def("__repr__", [](const TpuDevice& device) {
return absl::StrFormat( return absl::StrFormat(
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)", "TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",

View File

@ -1773,6 +1773,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
dtype=np.int32) dtype=np.int32)
self._ExecuteAndCompareClose(c, expected=[expected]) self._ExecuteAndCompareClose(c, expected=[expected])
class DeviceTest(ComputationTest):
def testPlatform(self):
for device in self.backend.local_devices():
self.assertEqual(device.platform, self.backend.platform)
tests.append(DeviceTest)
class ErrorTest(ComputationTest): class ErrorTest(ComputationTest):
def setUp(self): def setUp(self):