Enable TPU POD for JAX/1VM by creating devices and local_devices topology.
PiperOrigin-RevId: 289761259 Change-Id: Icbdcca91fd37ea0a04ad16df82aede52e3281ed9
This commit is contained in:
parent
3dec91764c
commit
79cea35560
@ -34,14 +34,34 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
constexpr char kTpuPlatform[] = "tpu";
|
||||||
|
|
||||||
|
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
|
int core_on_chip)
|
||||||
|
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id),
|
||||||
|
coords_(coords),
|
||||||
|
core_on_chip_(core_on_chip) {}
|
||||||
|
|
||||||
std::string TpuDevice::DebugString() const {
|
std::string TpuDevice::DebugString() const {
|
||||||
return absl::StrCat("TPU_", id());
|
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
|
||||||
|
coords_[0], coords_[1], coords_[2], core_on_chip_);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
|
||||||
int id) {
|
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
|
||||||
CHECK_EQ(platform_name, "tpu");
|
std::vector<std::shared_ptr<Device>> devices;
|
||||||
return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
|
for (const auto& chip : system_info.tpu_chip()) {
|
||||||
|
auto& coord = chip.chip_coord();
|
||||||
|
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
|
||||||
|
int host_id = chip.host_id();
|
||||||
|
for (const auto& core : chip.core()) {
|
||||||
|
auto device = std::make_shared<TpuDevice>(
|
||||||
|
core.id(), host_id, coords_array, core.core_on_chip_index());
|
||||||
|
devices.push_back(device);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return devices;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||||
@ -49,7 +69,6 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
|||||||
tpu_driver::TpuDriverConfig driver_config;
|
tpu_driver::TpuDriverConfig driver_config;
|
||||||
driver_config.set_worker(worker);
|
driver_config.set_worker(worker);
|
||||||
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
|
||||||
|
|
||||||
if (!client_status.ok()) {
|
if (!client_status.ok()) {
|
||||||
return client_status.status();
|
return client_status.status();
|
||||||
}
|
}
|
||||||
@ -58,19 +77,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
|||||||
|
|
||||||
tpu_driver::SystemInfo system_info;
|
tpu_driver::SystemInfo system_info;
|
||||||
client->QuerySystemInfo(&system_info);
|
client->QuerySystemInfo(&system_info);
|
||||||
int num_cores =
|
|
||||||
system_info.tpu_chip_size() * system_info.tpu_chip(0).core_size();
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<Device>> devices;
|
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
|
||||||
CHECK_GE(num_cores, 1);
|
TpuDevice::GetTpuDevices(system_info));
|
||||||
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
|
|
||||||
devices.reserve(num_cores);
|
|
||||||
for (int i = 0; i < num_cores; ++i) {
|
|
||||||
devices.push_back(MakeDevice("tpu", i));
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_shared<PyTpuClient>("tpu", std::move(client),
|
return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
|
||||||
std::move(devices), /*host_id=*/0);
|
std::move(devices),
|
||||||
|
system_info.host_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
PyTpuClient::PyTpuClient(std::string platform_name,
|
PyTpuClient::PyTpuClient(std::string platform_name,
|
||||||
@ -81,18 +94,21 @@ PyTpuClient::PyTpuClient(std::string platform_name,
|
|||||||
driver_(std::move(driver)),
|
driver_(std::move(driver)),
|
||||||
devices_(std::move(devices)),
|
devices_(std::move(devices)),
|
||||||
host_id_(host_id) {
|
host_id_(host_id) {
|
||||||
local_devices_.resize(devices_.size());
|
|
||||||
for (const std::shared_ptr<Device>& device : devices_) {
|
for (const std::shared_ptr<Device>& device : devices_) {
|
||||||
CHECK(id_to_device_.insert({device->id(), device}).second)
|
CHECK(id_to_device_.insert({device->id(), device}).second)
|
||||||
<< "Duplicate device id: " << device->id();
|
<< "Duplicate device id: " << device->id();
|
||||||
|
|
||||||
if (device->id() != -1) {
|
if (device->host_id() == host_id_) {
|
||||||
int idx = device->id();
|
LOG(INFO) << "Detected local device, host-id: " << host_id_
|
||||||
CHECK(local_devices_[idx] == nullptr) << idx;
|
<< ". core-id: " << device->id();
|
||||||
CHECK_LT(idx, local_devices_.size());
|
local_devices_.push_back(device);
|
||||||
local_devices_[idx] = device;
|
} else {
|
||||||
|
VLOG(2) << "Other devices, id: " << device->id();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
CHECK_GE(local_devices_.size(), 1);
|
||||||
|
LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";
|
||||||
|
|
||||||
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
||||||
CHECK(local_devices_[idx] != nullptr) << idx;
|
CHECK(local_devices_[idx] != nullptr) << idx;
|
||||||
}
|
}
|
||||||
@ -217,8 +233,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
|
|||||||
std::shared_ptr<TpuSharedBuffer> child_device_buffer =
|
std::shared_ptr<TpuSharedBuffer> child_device_buffer =
|
||||||
child_buffer->DeviceBuffer();
|
child_buffer->DeviceBuffer();
|
||||||
// Merge all definition events from all children, so that anyone using this
|
// Merge all definition events from all children, so that anyone using this
|
||||||
// tuple must wait for all its children to finish receiving transfers.
|
// tuple must wait for all its children to finish receiving transfers. This
|
||||||
// This works recursively up a nested tuple tree as well.
|
// works recursively up a nested tuple tree as well.
|
||||||
for (std::shared_ptr<tpu_driver::Event> child_event :
|
for (std::shared_ptr<tpu_driver::Event> child_event :
|
||||||
child_device_buffer->wait_for_use) {
|
child_device_buffer->wait_for_use) {
|
||||||
child_events.push_back(std::move(child_event));
|
child_events.push_back(std::move(child_event));
|
||||||
|
@ -38,8 +38,21 @@ namespace xla {
|
|||||||
|
|
||||||
class TpuDevice : public Device {
|
class TpuDevice : public Device {
|
||||||
public:
|
public:
|
||||||
using Device::Device;
|
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
|
int core_on_chip);
|
||||||
|
|
||||||
|
const std::array<int, 3>& coords() const { return coords_; }
|
||||||
|
int core_on_chip() const { return core_on_chip_; }
|
||||||
|
|
||||||
std::string DebugString() const override;
|
std::string DebugString() const override;
|
||||||
|
|
||||||
|
static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
|
||||||
|
const tpu_driver::SystemInfo& system_info);
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::array<int, 3> coords_;
|
||||||
|
// Index of this core within its chip.
|
||||||
|
int core_on_chip_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Encapsulates the state of Python session with XLA.
|
// Encapsulates the state of Python session with XLA.
|
||||||
@ -50,7 +63,7 @@ class PyTpuClient {
|
|||||||
static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
|
static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
|
||||||
|
|
||||||
explicit PyTpuClient(std::string platform_name,
|
explicit PyTpuClient(std::string platform_name,
|
||||||
std::unique_ptr<tpu_driver::TpuDriver> client,
|
std::unique_ptr<tpu_driver::TpuDriver> driver,
|
||||||
std::vector<std::shared_ptr<Device>> devices,
|
std::vector<std::shared_ptr<Device>> devices,
|
||||||
int host_id);
|
int host_id);
|
||||||
virtual ~PyTpuClient() = default;
|
virtual ~PyTpuClient() = default;
|
||||||
|
@ -206,8 +206,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
|||||||
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
|
||||||
|
|
||||||
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
|
||||||
|
.def_property_readonly("coords", &TpuDevice::coords)
|
||||||
|
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
|
||||||
.def("__repr__", [](const TpuDevice& device) {
|
.def("__repr__", [](const TpuDevice& device) {
|
||||||
return absl::StrFormat("TpuDevice(id=%i)", device.id());
|
return absl::StrFormat(
|
||||||
|
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
|
||||||
|
device.id(), device.host_id(), device.coords()[0],
|
||||||
|
device.coords()[1], device.coords()[2], device.core_on_chip());
|
||||||
});
|
});
|
||||||
} // NOLINT(readability/fn_size)
|
} // NOLINT(readability/fn_size)
|
||||||
|
|
||||||
|
@ -19,15 +19,24 @@ package tpu_driver;
|
|||||||
|
|
||||||
enum MemoryRegion { HBM = 1; }
|
enum MemoryRegion { HBM = 1; }
|
||||||
|
|
||||||
|
message ChipCoordinate {
|
||||||
|
required int32 x = 1;
|
||||||
|
required int32 y = 2;
|
||||||
|
required int32 z = 3;
|
||||||
|
}
|
||||||
|
|
||||||
message TpuCoreInfo {
|
message TpuCoreInfo {
|
||||||
required int32 id = 1;
|
required int32 id = 1;
|
||||||
|
optional int32 core_on_chip_index = 2;
|
||||||
required int64 hbm_bytes_available = 100;
|
optional int32 core_on_host_index = 3;
|
||||||
required int64 hbm_bytes_allocatable = 101;
|
optional int64 hbm_bytes_available = 100;
|
||||||
|
optional int64 hbm_bytes_allocatable = 101;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TpuChipInfo {
|
message TpuChipInfo {
|
||||||
repeated TpuCoreInfo core = 1;
|
repeated TpuCoreInfo core = 1;
|
||||||
|
optional int32 host_id = 2;
|
||||||
|
optional ChipCoordinate chip_coord = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CpuInfo {
|
message CpuInfo {
|
||||||
@ -40,6 +49,11 @@ message CpuInfo {
|
|||||||
message SystemInfo {
|
message SystemInfo {
|
||||||
repeated TpuChipInfo tpu_chip = 1;
|
repeated TpuChipInfo tpu_chip = 1;
|
||||||
required CpuInfo cpu = 2;
|
required CpuInfo cpu = 2;
|
||||||
|
repeated TpuCoreInfo local_core = 3;
|
||||||
|
optional int32 host_id = 4;
|
||||||
|
optional int32 host_count = 5;
|
||||||
|
optional int32 chip_count = 6;
|
||||||
|
optional int32 core_count = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TpuDriverConfig {
|
message TpuDriverConfig {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user