[tpu_driver] Fix segfault when calling TpuDevice.platform
PiperOrigin-RevId: 338787622 Change-Id: I2a0871f5954fa46132a6032e470469e7ac2ab945
This commit is contained in:
parent
62bc872eef
commit
29f1aa1beb
@ -207,6 +207,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
|||||||
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
py::class_<TpuDevice, PjRtDevice, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||||
.def_property_readonly("coords", &TpuDevice::coords)
|
.def_property_readonly("coords", &TpuDevice::coords)
|
||||||
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
|
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
|
||||||
|
// TODO(skye): this is a horrible hack because falling back to
|
||||||
|
// PjRtDevice::platform_name() segfaults, due to TpuDevice::client_ being
|
||||||
|
// uninitialized. This can be removed when PyTpuClient subclasses
|
||||||
|
// PjRtClient and can be used to set TpuDevice::client_.
|
||||||
|
.def_property_readonly(
|
||||||
|
"platform",
|
||||||
|
[](const TpuDevice& device) -> std::string { return kTpuPlatform; })
|
||||||
.def("__repr__", [](const TpuDevice& device) {
|
.def("__repr__", [](const TpuDevice& device) {
|
||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
|
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
|
||||||
|
@ -1773,6 +1773,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
|||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
self._ExecuteAndCompareClose(c, expected=[expected])
|
self._ExecuteAndCompareClose(c, expected=[expected])
|
||||||
|
|
||||||
|
class DeviceTest(ComputationTest):
|
||||||
|
|
||||||
|
def testPlatform(self):
|
||||||
|
for device in self.backend.local_devices():
|
||||||
|
self.assertEqual(device.platform, self.backend.platform)
|
||||||
|
|
||||||
|
tests.append(DeviceTest)
|
||||||
|
|
||||||
class ErrorTest(ComputationTest):
|
class ErrorTest(ComputationTest):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user