Enable TPU POD for JAX/1VM by creating devices and local_devices topology.
PiperOrigin-RevId: 289761259 Change-Id: Icbdcca91fd37ea0a04ad16df82aede52e3281ed9
This commit is contained in:
parent
3dec91764c
commit
79cea35560
@ -34,14 +34,34 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
constexpr char kTpuPlatform[] = "tpu";
|
||||
|
||||
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||
int core_on_chip)
|
||||
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id),
|
||||
coords_(coords),
|
||||
core_on_chip_(core_on_chip) {}
|
||||
|
||||
std::string TpuDevice::DebugString() const {
|
||||
return absl::StrCat("TPU_", id());
|
||||
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
|
||||
coords_[0], coords_[1], coords_[2], core_on_chip_);
|
||||
}
|
||||
|
||||
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
||||
int id) {
|
||||
CHECK_EQ(platform_name, "tpu");
|
||||
return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
|
||||
// Builds the device list described by `system_info`.
//
// Every core of every chip in system_info.tpu_chip() yields one TpuDevice.
// The core supplies the device id and the core-on-chip index; the enclosing
// chip supplies the host id and the (x, y, z) coordinates. Devices are
// returned in chip order, then core order within each chip.
xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
  std::vector<std::shared_ptr<Device>> result;

  for (const auto& chip_info : system_info.tpu_chip()) {
    // Per-chip values shared by all cores on this chip.
    const auto& chip_coord = chip_info.chip_coord();
    const std::array<int, 3> chip_xyz = {chip_coord.x(), chip_coord.y(),
                                         chip_coord.z()};
    const int chip_host_id = chip_info.host_id();

    for (const auto& core_info : chip_info.core()) {
      result.push_back(std::make_shared<TpuDevice>(
          core_info.id(), chip_host_id, chip_xyz,
          core_info.core_on_chip_index()));
    }
  }

  return result;
}
|
||||
|
||||
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
@ -49,7 +69,6 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
tpu_driver::TpuDriverConfig driver_config;
|
||||
driver_config.set_worker(worker);
|
||||
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
||||
|
||||
if (!client_status.ok()) {
|
||||
return client_status.status();
|
||||
}
|
||||
@ -58,19 +77,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
|
||||
tpu_driver::SystemInfo system_info;
|
||||
client->QuerySystemInfo(&system_info);
|
||||
int num_cores =
|
||||
system_info.tpu_chip_size() * system_info.tpu_chip(0).core_size();
|
||||
|
||||
std::vector<std::shared_ptr<Device>> devices;
|
||||
CHECK_GE(num_cores, 1);
|
||||
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
|
||||
devices.reserve(num_cores);
|
||||
for (int i = 0; i < num_cores; ++i) {
|
||||
devices.push_back(MakeDevice("tpu", i));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
|
||||
TpuDevice::GetTpuDevices(system_info));
|
||||
|
||||
return std::make_shared<PyTpuClient>("tpu", std::move(client),
|
||||
std::move(devices), /*host_id=*/0);
|
||||
return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
|
||||
std::move(devices),
|
||||
system_info.host_id());
|
||||
}
|
||||
|
||||
PyTpuClient::PyTpuClient(std::string platform_name,
|
||||
@ -81,18 +94,21 @@ PyTpuClient::PyTpuClient(std::string platform_name,
|
||||
driver_(std::move(driver)),
|
||||
devices_(std::move(devices)),
|
||||
host_id_(host_id) {
|
||||
local_devices_.resize(devices_.size());
|
||||
for (const std::shared_ptr<Device>& device : devices_) {
|
||||
CHECK(id_to_device_.insert({device->id(), device}).second)
|
||||
<< "Duplicate device id: " << device->id();
|
||||
|
||||
if (device->id() != -1) {
|
||||
int idx = device->id();
|
||||
CHECK(local_devices_[idx] == nullptr) << idx;
|
||||
CHECK_LT(idx, local_devices_.size());
|
||||
local_devices_[idx] = device;
|
||||
if (device->host_id() == host_id_) {
|
||||
LOG(INFO) << "Detected local device, host-id: " << host_id_
|
||||
<< ". core-id: " << device->id();
|
||||
local_devices_.push_back(device);
|
||||
} else {
|
||||
VLOG(2) << "Other devices, id: " << device->id();
|
||||
}
|
||||
}
|
||||
CHECK_GE(local_devices_.size(), 1);
|
||||
LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";
|
||||
|
||||
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
||||
CHECK(local_devices_[idx] != nullptr) << idx;
|
||||
}
|
||||
@ -217,8 +233,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
|
||||
std::shared_ptr<TpuSharedBuffer> child_device_buffer =
|
||||
child_buffer->DeviceBuffer();
|
||||
// Merge all definition events from all children, so that anyone using this
|
||||
// tuple must wait for all its children to finish receiving transfers.
|
||||
// This works recursively up a nested tuple tree as well.
|
||||
// tuple must wait for all its children to finish receiving transfers. This
|
||||
// works recursively up a nested tuple tree as well.
|
||||
for (std::shared_ptr<tpu_driver::Event> child_event :
|
||||
child_device_buffer->wait_for_use) {
|
||||
child_events.push_back(std::move(child_event));
|
||||
|
@ -38,8 +38,21 @@ namespace xla {
|
||||
|
||||
class TpuDevice : public Device {
|
||||
public:
|
||||
using Device::Device;
|
||||
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||
int core_on_chip);
|
||||
|
||||
const std::array<int, 3>& coords() const { return coords_; }
|
||||
int core_on_chip() const { return core_on_chip_; }
|
||||
|
||||
std::string DebugString() const override;
|
||||
|
||||
static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
|
||||
const tpu_driver::SystemInfo& system_info);
|
||||
|
||||
private:
|
||||
const std::array<int, 3> coords_;
|
||||
// Index of the core of the same chip.
|
||||
int core_on_chip_;
|
||||
};
|
||||
|
||||
// Encapsulates the state of Python session with XLA.
|
||||
@ -50,7 +63,7 @@ class PyTpuClient {
|
||||
static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
|
||||
|
||||
explicit PyTpuClient(std::string platform_name,
|
||||
std::unique_ptr<tpu_driver::TpuDriver> client,
|
||||
std::unique_ptr<tpu_driver::TpuDriver> driver,
|
||||
std::vector<std::shared_ptr<Device>> devices,
|
||||
int host_id);
|
||||
virtual ~PyTpuClient() = default;
|
||||
|
@ -206,8 +206,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
||||
|
||||
// Python binding for TpuDevice: read-only `coords` and `core_on_chip`
// properties plus a __repr__ that prints every identifying field.
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
    .def_property_readonly("coords", &TpuDevice::coords)
    .def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
    .def("__repr__", [](const TpuDevice& device) {
      // Single return only: a leading id-only
      // `return absl::StrFormat("TpuDevice(id=%i)", ...)` would make this
      // detailed form unreachable.
      const auto& coords = device.coords();
      return absl::StrFormat(
          "TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
          device.id(), device.host_id(), coords[0], coords[1], coords[2],
          device.core_on_chip());
    });
|
||||
} // NOLINT(readability/fn_size)
|
||||
|
||||
|
@ -19,15 +19,24 @@ package tpu_driver;
|
||||
|
||||
enum MemoryRegion { HBM = 1; }
|
||||
|
||||
// Three-component integer coordinate attached to a TPU chip
// (see TpuChipInfo.chip_coord).
message ChipCoordinate {
  required int32 x = 1;
  required int32 y = 2;
  required int32 z = 3;
}
|
||||
|
||||
// Information about a single TPU core.
message TpuCoreInfo {
  required int32 id = 1;

  // NOTE(review): presumably the core's index within its chip / within its
  // host, going by the field names; confirm against the driver.
  optional int32 core_on_chip_index = 2;
  optional int32 core_on_host_index = 3;

  // Declared exactly once, as optional. A second `required` declaration of
  // the same names and field numbers (100, 101) is rejected by protoc.
  optional int64 hbm_bytes_available = 100;
  optional int64 hbm_bytes_allocatable = 101;
}
|
||||
|
||||
// Information about a single TPU chip: its cores, plus an optional host id
// and chip coordinate.
message TpuChipInfo {
  repeated TpuCoreInfo core = 1;

  optional int32 host_id = 2;
  optional ChipCoordinate chip_coord = 3;
}
|
||||
|
||||
message CpuInfo {
|
||||
@ -40,6 +49,11 @@ message CpuInfo {
|
||||
// System description: the list of TPU chips, CPU information, and optional
// host/chip/core bookkeeping fields.
message SystemInfo {
  repeated TpuChipInfo tpu_chip = 1;
  required CpuInfo cpu = 2;

  // NOTE(review): presumably the cores attached to the reporting host, going
  // by the field name; confirm against the driver implementation.
  repeated TpuCoreInfo local_core = 3;

  optional int32 host_id = 4;
  optional int32 host_count = 5;
  optional int32 chip_count = 6;
  optional int32 core_count = 7;
}
|
||||
|
||||
message TpuDriverConfig {
|
||||
|
Loading…
Reference in New Issue
Block a user