From bfc75faae77274ff9a0d542a3f271d5424f54a61 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 19 Jan 2021 10:15:46 -0800 Subject: [PATCH] [XLA:Python] Make gRPC TPU driver inherit from PjRtDevice instead of PjRtStreamExecutorDevice. The TPU driver does not use StreamExecutor, so it should not inherit from the SE implementation. PiperOrigin-RevId: 352596334 Change-Id: Ice3e5087430a57f958b9c6240cc0e3bd0cc65a5a --- .../xla/python/tpu_driver/client/BUILD | 3 +- .../python/tpu_driver/client/tpu_client.cc | 4 +-- .../xla/python/tpu_driver/client/tpu_client.h | 28 +++++++++++++++++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index ec05d933316..25c39d68960 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,6 +19,7 @@ cc_library( ], compatible_with = [], deps = [ + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -26,7 +27,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/compiler/xla/pjrt:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index a9aa218ca6f..4656e9dfa7a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,8 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr, - /*device_kind=*/"Cloud TPU", host_id), + : id_(id), + host_id_(host_id), coords_(coords), core_on_chip_(core_on_chip) {} diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index cc4e4471e8e..85211b01bbb 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,8 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -39,7 +40,7 @@ namespace xla { constexpr char kTpuPlatform[] = "tpu"; -class TpuDevice : public PjRtStreamExecutorDevice { +class TpuDevice : public PjRtDevice { public: TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip); @@ -52,8 +53,31 @@ class TpuDevice : public PjRtStreamExecutorDevice { static xla::StatusOr>> GetTpuDevices(const tpu_driver::SystemInfo& system_info); + PjRtClient* client() const override { return nullptr; } + + bool IsAddressable() const override { return false; } + + int id() const override { return id_; } + + int host_id() const override { return host_id_; } + + int local_hardware_id() const override { return -1; } + + const std::string& device_kind() const override { return device_kind_; } + + Status TransferToInfeed(const LiteralSlice& literal) const override { + return Unimplemented("Infeed not yet implemented via this API"); + } + + StatusOr TransferFromOutfeed(const Shape& shape) const override { + return Unimplemented("Outfeed not yet implemented via this API"); + } + private: + const int id_; + const int host_id_; const std::array coords_; + const std::string device_kind_ = "Cloud TPU"; // Index of the core of the same chip. int core_on_chip_; };