[PJRT] Expose TPU runtime version via new PjRtClient::platform_version string.

On Cloud TPU VMs, the returned platform version will look like:

libtpu version 0.0.1
Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169

On all other platforms, the platform version will be "<unknown>".

[XLA:Python] Add xla_client.Client.platform_version

PiperOrigin-RevId: 361215574
Change-Id: Ie65ce509caa550277972751cfb58027ad58c1d23
This commit is contained in:
Skye Wanderman-Milne 2021-03-05 13:39:05 -08:00 committed by TensorFlower Gardener
parent d4dffb1c29
commit 67ff36f392
8 changed files with 46 additions and 3 deletions

View File

@ -173,6 +173,8 @@ class PjRtClient {
// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual absl::string_view platform_name() const = 0;
// Returns a string describing the version of the platform runtime (e.g. the
// libtpu version and build information on TPU). Implementations that have no
// version to report return "<unknown>".
virtual absl::string_view platform_version() const = 0;
// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(

View File

@ -155,6 +155,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
PjRtPlatformId platform_id() const override { return platform_id_; }
absl::string_view platform_name() const override { return platform_name_; }
// The generic StreamExecutor-backed client has no runtime version to report,
// so it returns a fixed placeholder string.
absl::string_view platform_version() const override {
  return absl::string_view("<unknown>");
}
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This

View File

@ -79,6 +79,10 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int task_id);
// Returns a view of the version string stored in platform_version_.
absl::string_view platform_version() const override {
  const std::string& version = platform_version_;
  return version;
}
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@ -86,6 +90,9 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override;
private:
// Version string initialized in the constructor and returned by
// platform_version().
const std::string platform_version_;
};
PjRtTpuClient::PjRtTpuClient(
@ -95,7 +102,18 @@ PjRtTpuClient::PjRtTpuClient(
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {}
/*gpu_run_options=*/nullptr),
platform_version_([]() {
// Example platform version string:
// libtpu version 0.0.1
// Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169
tf_tpu::TpuPlatformInterface* platform =
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
TpuRuntimeVersion version = platform->version();
return absl::StrCat(
"libtpu version ", absl::StrJoin(version.version, "."), "\n",
absl::string_view(version.metadata, version.metadata_size));
}()) {}
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {

View File

@ -97,6 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
// Forwards to the wrapped PJRT client's platform name.
absl::string_view platform_name() const {
  auto& client = *pjrt_client_;
  return client.platform_name();
}
// Forwards to the wrapped PJRT client's platform version string.
absl::string_view platform_version() const {
  auto& client = *pjrt_client_;
  return client.platform_version();
}
// Forwards to the wrapped PJRT client's addressable device count.
int addressable_device_count() const {
  auto& client = *pjrt_client_;
  return client.addressable_device_count();
}

View File

@ -116,7 +116,8 @@ class PyTpuClient {
return id_to_device_;
}
int task_id() const { return task_id_; }
const std::string& platform_name() const { return platform_name_; }
// Returns a view of this client's platform name. The view is returned by
// value, so the return type carries no top-level `const` qualifier.
absl::string_view platform_name() const { return platform_name_; }
// Returns the platform version string; this client always reports the
// placeholder "<unknown>". Returned by value, so the return type carries no
// top-level `const` qualifier.
absl::string_view platform_version() const { return "<unknown>"; }
StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
return Unimplemented("ChooseCompactLayoutForShape not implemented.");
@ -207,7 +208,9 @@ class PyTpuBuffer {
const Shape& on_host_shape() const { return on_host_shape_; }
std::shared_ptr<PjRtDevice> device() const { return device_; }
const std::string& platform_name() const { return client_->platform_name(); }
// Returns the platform name reported by the client that owns this buffer.
// Returned by value, so the return type carries no top-level `const`
// qualifier.
absl::string_view platform_name() const { return client_->platform_name(); }
std::shared_ptr<PyTpuClient> client() const { return client_; }
// Returns the buffer's value as a tuple DAG of Python arrays. If the value

View File

@ -28,6 +28,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::class_<PyTpuClient, std::shared_ptr<PyTpuClient>>(m, "TpuClient")
.def_static("Get", &PyTpuClient::Get, py::arg("worker"))
.def_property_readonly("platform", &PyTpuClient::platform_name)
.def_property_readonly("platform_version", &PyTpuClient::platform_version)
.def("device_count", &PyTpuClient::device_count)
.def("local_device_count", &PyTpuClient::local_device_count)
.def("devices", &PyTpuClient::devices)

View File

@ -200,6 +200,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def_property_readonly("platform_version", &PyClient::platform_version)
.def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices)

View File

@ -2149,6 +2149,20 @@ def TestFactory(xla_backend, cloud_tpu=False):
tests.append(TracebackTest)
class ClientTest(parameterized.TestCase):
  """Tests for attributes exposed directly on the backend client."""

  def setUp(self):
    super().setUp()
    self.backend = xla_backend()

  def testPlatformVersion(self):
    # Reading the property must not crash on any backend; the exact value is
    # only checked on CPU, where it is the "<unknown>" placeholder.
    platform_version = self.backend.platform_version
    if self.backend.platform == "cpu":
      self.assertEqual(platform_version, "<unknown>")
tests.append(ClientTest)
return tests