Add fixes to counting total TPU chips and cores to conform with existing JAX usage

PiperOrigin-RevId: 331189252
Change-Id: I3496579588631d61703f35908a0daadec844b981
This commit is contained in:
Frank Chen 2020-09-11 11:35:32 -07:00 committed by TensorFlower Gardener
parent cbc4001ecd
commit 92a4daaa73

View File

@ -183,28 +183,54 @@ class PodTpuDriver : public TpuDriver {
CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie());
}
int cumulative_core_id = 0;
absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;
for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
SystemInfo driver_info;
drivers_[driver_num]->QuerySystemInfo(&driver_info);
for (const auto& tpu_chip : driver_info.tpu_chip()) {
*(pod_info_.add_tpu_chip()) = tpu_chip;
std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
tpu_chip.chip_coord().y(),
tpu_chip.chip_coord().z()};
// We only want to add chips that we have not seen before if we are in a
// TPU pod slice, or we are only seeing local cores (e.g. we are
// connected to individual TPUs or we are in a test environment).
if (!processed_chips.contains(coord) ||
driver_info.core_count() == driver_info.local_core_size()) {
*(pod_info_.add_tpu_chip()) = tpu_chip;
processed_chips.insert(coord);
}
}
int core_num = 0;
for (const auto& tpu_core : driver_info.local_core()) {
*(pod_info_.add_local_core()) = tpu_core;
core_to_driver_.push_back(drivers_[driver_num].get());
core_to_driver_id_.push_back(driver_num);
core_to_driver_core_.push_back(core_num++);
}
*(pod_info_.mutable_cpu()) = driver_info.cpu();
pod_info_.set_host_count(pod_info_.host_count() + 1);
pod_info_.set_chip_count(pod_info_.chip_count() +
driver_info.chip_count());
pod_info_.set_core_count(pod_info_.core_count() +
driver_info.core_count());
}
// Process all the unique chips that we have seen.
for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
for (auto& tpu_core : *tpu_chip.mutable_core()) {
int current_core = cumulative_core_id++;
core_to_driver_.push_back(drivers_[tpu_chip.host_id()].get());
core_to_driver_id_.push_back(tpu_chip.host_id());
core_to_driver_core_.push_back(tpu_core.id());
tpu_core.set_id(current_core);
tpu_core.set_core_on_host_index(current_core);
*(pod_info_.add_local_core()) = tpu_core;
}
// We are setting host_id to zero because we want this to look like one
// host with many cores from the perspective of tpu_client.cc.
tpu_chip.set_host_id(0);
}
pod_info_.set_chip_count(pod_info_.tpu_chip_size());
pod_info_.set_core_count(pod_info_.local_core_size());
// We want this to look like one host with many TPU chips/cores connected.
pod_info_.set_host_count(1);
pod_info_.set_host_id(0);
}