[XLA:Python] Make gRPC TPU driver inherit from PjRtDevice instead of PjRtStreamExecutorDevice.

The TPU driver does not use StreamExecutor, so it should not inherit from the SE implementation.

PiperOrigin-RevId: 352596334
Change-Id: Ice3e5087430a57f958b9c6240cc0e3bd0cc65a5a
This commit is contained in:
Peter Hawkins 2021-01-19 10:15:46 -08:00 committed by TensorFlower Gardener
parent d7485ac5c3
commit bfc75faae7
3 changed files with 30 additions and 5 deletions

View File

@ -19,6 +19,7 @@ cc_library(
],
compatible_with = [],
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@ -26,7 +27,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:semaphore",
"//tensorflow/compiler/xla/python/tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",

View File

@ -37,8 +37,8 @@ namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
/*device_kind=*/"Cloud TPU", host_id),
: id_(id),
host_id_(host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}

View File

@ -24,7 +24,8 @@ limitations under the License.
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@ -39,7 +40,7 @@ namespace xla {
constexpr char kTpuPlatform[] = "tpu";
class TpuDevice : public PjRtStreamExecutorDevice {
class TpuDevice : public PjRtDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);
@ -52,8 +53,31 @@ class TpuDevice : public PjRtStreamExecutorDevice {
static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
GetTpuDevices(const tpu_driver::SystemInfo& system_info);
PjRtClient* client() const override { return nullptr; }
bool IsAddressable() const override { return false; }
int id() const override { return id_; }
int host_id() const override { return host_id_; }
int local_hardware_id() const override { return -1; }
const std::string& device_kind() const override { return device_kind_; }
Status TransferToInfeed(const LiteralSlice& literal) const override {
return Unimplemented("Infeed not yet implemented via this API");
}
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
return Unimplemented("Outfeed not yet implemented via this API");
}
private:
const int id_;
const int host_id_;
const std::array<int, 3> coords_;
const std::string device_kind_ = "Cloud TPU";
// Index of the core of the same chip.
int core_on_chip_;
};