[tpu_driver] Fix segfault when calling TpuDevice.platform
PiperOrigin-RevId: 338787622 Change-Id: I2a0871f5954fa46132a6032e470469e7ac2ab945
This commit is contained in:
parent
62bc872eef
commit
29f1aa1beb
@ -207,6 +207,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||
.def_property_readonly("coords", &TpuDevice::coords)
|
||||
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
|
||||
// TODO(skye): this is a horrible hack because falling back to
|
||||
// PjRtDevice::platform_name() segfaults, due to TpuDevice::client_ being
|
||||
// uninitialized. This can be removed when PyTpuClient subclasses
|
||||
// PjRtClient and can be used to set TpuDevice::client_.
|
||||
.def_property_readonly(
|
||||
"platform",
|
||||
[](const TpuDevice& device) -> std::string { return kTpuPlatform; })
|
||||
.def("__repr__", [](const TpuDevice& device) {
|
||||
return absl::StrFormat(
|
||||
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
|
||||
|
@ -1773,6 +1773,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
dtype=np.int32)
|
||||
self._ExecuteAndCompareClose(c, expected=[expected])
|
||||
|
||||
class DeviceTest(ComputationTest):
|
||||
|
||||
def testPlatform(self):
|
||||
for device in self.backend.local_devices():
|
||||
self.assertEqual(device.platform, self.backend.platform)
|
||||
|
||||
tests.append(DeviceTest)
|
||||
|
||||
class ErrorTest(ComputationTest):
|
||||
|
||||
def setUp(self):
|
||||
|
Loading…
Reference in New Issue
Block a user