[PJRT] Expose TPU runtime version via new PjRtClient::platform_version string.
On Cloud TPU VMs, the returned platform version will look like: libtpu version 0.0.1 Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169 On all other platform, the platform version will be "<unknown>". [XLA:Python] Add xla_client.Client.platform_version PiperOrigin-RevId: 361215574 Change-Id: Ie65ce509caa550277972751cfb58027ad58c1d23
This commit is contained in:
parent
d4dffb1c29
commit
67ff36f392
tensorflow/compiler/xla
@ -173,6 +173,8 @@ class PjRtClient {
|
||||
// Returns a string that identifies the platform (CPU/GPU/TPU).
|
||||
virtual absl::string_view platform_name() const = 0;
|
||||
|
||||
virtual absl::string_view platform_version() const = 0;
|
||||
|
||||
// Return a device-specific default device assignment, e.g., GPU and TPU may
|
||||
// be different.
|
||||
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
|
@ -155,6 +155,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
|
||||
|
||||
PjRtPlatformId platform_id() const override { return platform_id_; }
|
||||
absl::string_view platform_name() const override { return platform_name_; }
|
||||
absl::string_view platform_version() const override { return "<unknown>"; }
|
||||
|
||||
// Most platforms expect device-to-device transfers to be enqueued on the
|
||||
// source d2d stream, but some platforms use the destination d2d stream. This
|
||||
|
@ -79,6 +79,10 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
||||
int task_id);
|
||||
|
||||
absl::string_view platform_version() const override {
|
||||
return platform_version_;
|
||||
}
|
||||
|
||||
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const override;
|
||||
|
||||
@ -86,6 +90,9 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
|
||||
|
||||
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||
const PjRtExecutable& executable) const override;
|
||||
|
||||
private:
|
||||
const std::string platform_version_;
|
||||
};
|
||||
|
||||
PjRtTpuClient::PjRtTpuClient(
|
||||
@ -95,7 +102,18 @@ PjRtTpuClient::PjRtTpuClient(
|
||||
/*allocator=*/nullptr,
|
||||
/*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr) {}
|
||||
/*gpu_run_options=*/nullptr),
|
||||
platform_version_([]() {
|
||||
// Example platform version string:
|
||||
// libtpu version 0.0.1
|
||||
// Built on Mar 4 2021 15:25:57 (1614900357) cl/360760169
|
||||
tf_tpu::TpuPlatformInterface* platform =
|
||||
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
|
||||
TpuRuntimeVersion version = platform->version();
|
||||
return absl::StrCat(
|
||||
"libtpu version ", absl::StrJoin(version.version, "."), "\n",
|
||||
absl::string_view(version.metadata, version.metadata_size));
|
||||
}()) {}
|
||||
|
||||
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
|
||||
int num_replicas, int num_partitions) const {
|
||||
|
@ -97,6 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
absl::string_view platform_name() const {
|
||||
return pjrt_client_->platform_name();
|
||||
}
|
||||
absl::string_view platform_version() const {
|
||||
return pjrt_client_->platform_version();
|
||||
}
|
||||
int addressable_device_count() const {
|
||||
return pjrt_client_->addressable_device_count();
|
||||
}
|
||||
|
@ -116,7 +116,8 @@ class PyTpuClient {
|
||||
return id_to_device_;
|
||||
}
|
||||
int task_id() const { return task_id_; }
|
||||
const std::string& platform_name() const { return platform_name_; }
|
||||
const absl::string_view platform_name() const { return platform_name_; }
|
||||
const absl::string_view platform_version() const { return "<unknown>"; }
|
||||
|
||||
StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
|
||||
return Unimplemented("ChooseCompactLayoutForShape not implemented.");
|
||||
@ -207,7 +208,9 @@ class PyTpuBuffer {
|
||||
|
||||
const Shape& on_host_shape() const { return on_host_shape_; }
|
||||
std::shared_ptr<PjRtDevice> device() const { return device_; }
|
||||
const std::string& platform_name() const { return client_->platform_name(); }
|
||||
const absl::string_view platform_name() const {
|
||||
return client_->platform_name();
|
||||
}
|
||||
std::shared_ptr<PyTpuClient> client() const { return client_; }
|
||||
|
||||
// Returns the buffer's value as a tuple DAG of Python arrays. If the value
|
||||
|
@ -28,6 +28,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
py::class_<PyTpuClient, std::shared_ptr<PyTpuClient>>(m, "TpuClient")
|
||||
.def_static("Get", &PyTpuClient::Get, py::arg("worker"))
|
||||
.def_property_readonly("platform", &PyTpuClient::platform_name)
|
||||
.def_property_readonly("platform_version", &PyTpuClient::platform_version)
|
||||
.def("device_count", &PyTpuClient::device_count)
|
||||
.def("local_device_count", &PyTpuClient::local_device_count)
|
||||
.def("devices", &PyTpuClient::devices)
|
||||
|
@ -200,6 +200,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
|
||||
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
|
||||
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
|
||||
.def_property_readonly("platform_version", &PyClient::platform_version)
|
||||
.def("device_count", &PyClient::device_count)
|
||||
.def("local_device_count", &PyClient::addressable_device_count)
|
||||
.def("devices", &PyClient::Devices)
|
||||
|
@ -2149,6 +2149,20 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
|
||||
tests.append(TracebackTest)
|
||||
|
||||
class ClientTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ClientTest, self).setUp()
|
||||
self.backend = xla_backend()
|
||||
|
||||
def testPlatformVersion(self):
|
||||
# Check doesn't crash
|
||||
version = self.backend.platform_version
|
||||
if self.backend.platform == "cpu":
|
||||
self.assertEqual(version, "<unknown>")
|
||||
|
||||
tests.append(ClientTest)
|
||||
|
||||
return tests
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user