Fix counting of total TPU chips and cores to conform with existing JAX usage

PiperOrigin-RevId: 331189252
Change-Id: I3496579588631d61703f35908a0daadec844b981
parent cbc4001ecd · commit 92a4daaa73
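The change below teaches PodTpuDriver's constructor to deduplicate the chips reported by the per-worker drivers, keyed by each chip's (x, y, z) coordinate, so that a pod slice visible to several workers is only counted once. The following is a minimal standalone sketch of that idea, assuming hypothetical simplified ChipInfo/WorkerInfo structs in place of the real SystemInfo proto and an available Abseil dependency; it is an illustration, not the driver code itself.

#include <iostream>
#include <tuple>
#include <vector>

#include "absl/container/flat_hash_set.h"

// Hypothetical, simplified stand-ins for the per-worker SystemInfo fields.
struct ChipInfo {
  int x, y, z;    // Chip coordinate within the slice.
  int num_cores;  // Cores on this chip.
};

struct WorkerInfo {
  std::vector<ChipInfo> chips;  // Chips this worker reports.
  int core_count;               // Cores in the whole system, as reported.
  int local_core_count;         // Cores attached to this worker only.
};

int main() {
  // Two workers in the same 2-chip slice; each reports every chip it can see.
  std::vector<WorkerInfo> workers = {
      {{{0, 0, 0, 2}, {1, 0, 0, 2}}, /*core_count=*/4, /*local_core_count=*/2},
      {{{0, 0, 0, 2}, {1, 0, 0, 2}}, /*core_count=*/4, /*local_core_count=*/2},
  };

  absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;
  std::vector<ChipInfo> pod_chips;

  for (const WorkerInfo& worker : workers) {
    for (const ChipInfo& chip : worker.chips) {
      std::tuple<int, int, int> coord{chip.x, chip.y, chip.z};
      // Add a chip only once per coordinate, unless every core the worker
      // sees is local (separate TPU hosts or a test environment), in which
      // case no cross-worker duplicates are possible.
      if (!processed_chips.contains(coord) ||
          worker.core_count == worker.local_core_count) {
        pod_chips.push_back(chip);
        processed_chips.insert(coord);
      }
    }
  }

  int total_cores = 0;
  for (const ChipInfo& chip : pod_chips) total_cores += chip.num_cores;
  std::cout << "chips: " << pod_chips.size()  // 2, not 4
            << ", cores: " << total_cores     // 4, not 8
            << std::endl;
}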
@@ -183,28 +183,54 @@ class PodTpuDriver : public TpuDriver {
           CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie());
     }
 
+    int cumulative_core_id = 0;
+    absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;
+
     for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
       SystemInfo driver_info;
       drivers_[driver_num]->QuerySystemInfo(&driver_info);
 
       for (const auto& tpu_chip : driver_info.tpu_chip()) {
-        *(pod_info_.add_tpu_chip()) = tpu_chip;
+        std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
+                                        tpu_chip.chip_coord().y(),
+                                        tpu_chip.chip_coord().z()};
+        // We only want to add chips that we have not seen before if we are in a
+        // TPU pod slice, or we are only seeing local cores (e.g. we are
+        // connected to individual TPUs or we are in a test environment).
+        if (!processed_chips.contains(coord) ||
+            driver_info.core_count() == driver_info.local_core_size()) {
+          *(pod_info_.add_tpu_chip()) = tpu_chip;
+          processed_chips.insert(coord);
+        }
       }
-
-      int core_num = 0;
-      for (const auto& tpu_core : driver_info.local_core()) {
-        *(pod_info_.add_local_core()) = tpu_core;
-        core_to_driver_.push_back(drivers_[driver_num].get());
-        core_to_driver_id_.push_back(driver_num);
-        core_to_driver_core_.push_back(core_num++);
-      }
       *(pod_info_.mutable_cpu()) = driver_info.cpu();
-      pod_info_.set_host_count(pod_info_.host_count() + 1);
-      pod_info_.set_chip_count(pod_info_.chip_count() +
-                               driver_info.chip_count());
-      pod_info_.set_core_count(pod_info_.core_count() +
-                               driver_info.core_count());
     }
+
+    // Process all the unique chips that we have seen.
+    for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
+      for (auto& tpu_core : *tpu_chip.mutable_core()) {
+        int current_core = cumulative_core_id++;
+
+        core_to_driver_.push_back(drivers_[tpu_chip.host_id()].get());
+        core_to_driver_id_.push_back(tpu_chip.host_id());
+        core_to_driver_core_.push_back(tpu_core.id());
+
+        tpu_core.set_id(current_core);
+        tpu_core.set_core_on_host_index(current_core);
+        *(pod_info_.add_local_core()) = tpu_core;
+      }
+
+      // We are setting host_id to zero because we want this to look like one
+      // host with many cores from the perspective of tpu_client.cc.
+      tpu_chip.set_host_id(0);
+    }
+
+    pod_info_.set_chip_count(pod_info_.tpu_chip_size());
+    pod_info_.set_core_count(pod_info_.local_core_size());
+
+    // We want this to look like one host with many TPU chips/cores connected.
+    pod_info_.set_host_count(1);
+    pod_info_.set_host_id(0);
   }
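After deduplication, the constructor walks the unique chips, gives every core a globally unique id, and records for each global id the owning worker driver and the driver-local core id, then reports host_count = 1 and host_id = 0 so the pod presents itself to tpu_client.cc as a single host with many chips and cores. The sketch below illustrates just that renumbering step under the same caveat: PodChip/PodCore are hypothetical simplified structs, not the real proto messages.

#include <cstdio>
#include <vector>

// Hypothetical, simplified stand-ins for the chip/core entries in SystemInfo.
struct PodCore {
  int id;                  // Driver-local core id; rewritten to a global id.
  int core_on_host_index;  // Index of the core on the single virtual host.
};

struct PodChip {
  int host_id;  // Worker that owns the chip; rewritten to 0 below.
  std::vector<PodCore> cores;
};

int main() {
  // Two deduplicated chips, each with two cores, owned by workers 0 and 1.
  std::vector<PodChip> chips = {{0, {{0, 0}, {1, 1}}},
                                {1, {{0, 0}, {1, 1}}}};

  std::vector<int> core_to_driver_id;    // Global core id -> worker index.
  std::vector<int> core_to_driver_core;  // Global core id -> local core id.

  int cumulative_core_id = 0;
  for (PodChip& chip : chips) {
    for (PodCore& core : chip.cores) {
      int current_core = cumulative_core_id++;
      core_to_driver_id.push_back(chip.host_id);
      core_to_driver_core.push_back(core.id);
      // Expose globally unique ids, as if all cores lived on one host.
      core.id = current_core;
      core.core_on_host_index = current_core;
    }
    // One virtual host from the client's point of view.
    chip.host_id = 0;
  }

  for (int i = 0; i < cumulative_core_id; ++i) {
    std::printf("core %d -> driver %d, local core %d\n", i,
                core_to_driver_id[i], core_to_driver_core[i]);
  }
}

With two chips of two cores each this prints global cores 0-3 mapped back to (worker, local core) pairs, which mirrors the core_to_driver_id_ / core_to_driver_core_ bookkeeping in the real driver (which additionally stores the driver pointer itself in core_to_driver_).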