[XLA:Python] Make gRPC TPU driver inherit from PjRtDevice instead of PjRtStreamExecutorDevice.
The TPU driver does not use StreamExecutor, so it should not inherit from the SE implementation.

PiperOrigin-RevId: 352596334
Change-Id: Ice3e5087430a57f958b9c6240cc0e3bd0cc65a5a
parent d7485ac5c3
commit bfc75faae7
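For orientation before the diff: prior to this change, TpuDevice picked up accessors such as id(), host_id(), and device_kind() from the PjRtStreamExecutorDevice base class; after it, TpuDevice overrides the PjRtDevice virtual interface directly and keeps that state itself. A minimal self-contained sketch of that pattern follows; the Device and GrpcTpuDevice names and the reduced set of virtual methods are illustrative stand-ins, not the real XLA classes.

#include <array>
#include <string>

// Simplified stand-in for the PjRtDevice-style virtual surface a device
// implements when it does not go through StreamExecutor.
class Device {
 public:
  virtual ~Device() = default;
  virtual int id() const = 0;
  virtual int host_id() const = 0;
  virtual const std::string& device_kind() const = 0;
};

// The device carries its own identity instead of forwarding it to a
// StreamExecutor-backed base class, mirroring the TpuDevice change below.
class GrpcTpuDevice : public Device {
 public:
  GrpcTpuDevice(int id, int host_id, const std::array<int, 3>& coords)
      : id_(id), host_id_(host_id), coords_(coords) {}

  int id() const override { return id_; }
  int host_id() const override { return host_id_; }
  const std::string& device_kind() const override { return device_kind_; }

 private:
  const int id_;
  const int host_id_;
  const std::array<int, 3> coords_;
  const std::string device_kind_ = "Cloud TPU";
};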
@@ -19,6 +19,7 @@ cc_library(
     ],
     compatible_with = [],
     deps = [
+        "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
@@ -26,7 +27,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:executable_build_options",
-        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_client",
         "//tensorflow/compiler/xla/pjrt:semaphore",
         "//tensorflow/compiler/xla/python/tpu_driver",
         "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
@@ -37,8 +37,8 @@ namespace xla {
 
 TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
                      int core_on_chip)
-    : xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
-                                    /*device_kind=*/"Cloud TPU", host_id),
+    : id_(id),
+      host_id_(host_id),
       coords_(coords),
       core_on_chip_(core_on_chip) {}
 
@@ -24,7 +24,8 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -39,7 +40,7 @@ namespace xla {
 
 constexpr char kTpuPlatform[] = "tpu";
 
-class TpuDevice : public PjRtStreamExecutorDevice {
+class TpuDevice : public PjRtDevice {
  public:
   TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
             int core_on_chip);
@@ -52,8 +53,31 @@ class TpuDevice : public PjRtStreamExecutorDevice {
   static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
   GetTpuDevices(const tpu_driver::SystemInfo& system_info);
 
+  PjRtClient* client() const override { return nullptr; }
+
+  bool IsAddressable() const override { return false; }
+
+  int id() const override { return id_; }
+
+  int host_id() const override { return host_id_; }
+
+  int local_hardware_id() const override { return -1; }
+
+  const std::string& device_kind() const override { return device_kind_; }
+
+  Status TransferToInfeed(const LiteralSlice& literal) const override {
+    return Unimplemented("Infeed not yet implemented via this API");
+  }
+
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
+    return Unimplemented("Outfeed not yet implemented via this API");
+  }
+
  private:
+  const int id_;
+  const int host_id_;
   const std::array<int, 3> coords_;
+  const std::string device_kind_ = "Cloud TPU";
   // Index of the core of the same chip.
   int core_on_chip_;
 };
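With the header above, a caller sees the device's identity values coming straight from the constructor arguments rather than from a StreamExecutor base. The sketch below is a hypothetical sanity check, not code from this commit; it assumes the new tpu_client.h declaration of xla::TpuDevice and the CHECK/CHECK_EQ macros from TensorFlow's logging header are available.

#include <memory>

// Hypothetical usage of the overrides added in this change.
void SanityCheckTpuDevice() {
  auto device = std::make_shared<xla::TpuDevice>(
      /*id=*/3, /*host_id=*/0, /*coords=*/{{1, 1, 0}}, /*core_on_chip=*/1);
  CHECK_EQ(device->id(), 3);
  CHECK_EQ(device->host_id(), 0);
  CHECK_EQ(device->device_kind(), "Cloud TPU");
  CHECK_EQ(device->local_hardware_id(), -1);  // No local hardware mapping.
  CHECK(!device->IsAddressable());            // gRPC driver devices report non-addressable.
  CHECK(device->client() == nullptr);         // No PjRtClient is attached by this change.
}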