From 92a4daaa73b8a4e109ffba0d2904efee7fa8e5b2 Mon Sep 17 00:00:00 2001
From: Frank Chen
Date: Fri, 11 Sep 2020 11:35:32 -0700
Subject: [PATCH] Add fixes to counting total TPU chips and cores to conform
 with existing JAX usage

PiperOrigin-RevId: 331189252
Change-Id: I3496579588631d61703f35908a0daadec844b981
---
 .../xla/python/tpu_driver/pod_tpu_driver.cc  | 52 ++++++++++++++-----
 1 file changed, 39 insertions(+), 13 deletions(-)

diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
index 114014e4e13..ac54df39895 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc
@@ -183,28 +183,54 @@ class PodTpuDriver : public TpuDriver {
           CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie());
     }
 
+    int cumulative_core_id = 0;
+    absl::flat_hash_set<std::tuple<int32_t, int32_t, int32_t>> processed_chips;
+
     for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
       SystemInfo driver_info;
       drivers_[driver_num]->QuerySystemInfo(&driver_info);
 
       for (const auto& tpu_chip : driver_info.tpu_chip()) {
-        *(pod_info_.add_tpu_chip()) = tpu_chip;
+        std::tuple<int32_t, int32_t, int32_t> coord{tpu_chip.chip_coord().x(),
+                                                    tpu_chip.chip_coord().y(),
+                                                    tpu_chip.chip_coord().z()};
+        // We only want to add chips that we have not seen before if we are in a
+        // TPU pod slice, or we are only seeing local cores (e.g. we are
+        // connected to individual TPUs or we are in a test environment).
+        if (!processed_chips.contains(coord) ||
+            driver_info.core_count() == driver_info.local_core_size()) {
+          *(pod_info_.add_tpu_chip()) = tpu_chip;
+          processed_chips.insert(coord);
+        }
       }
 
-      int core_num = 0;
-      for (const auto& tpu_core : driver_info.local_core()) {
-        *(pod_info_.add_local_core()) = tpu_core;
-        core_to_driver_.push_back(drivers_[driver_num].get());
-        core_to_driver_id_.push_back(driver_num);
-        core_to_driver_core_.push_back(core_num++);
-      }
 
       *(pod_info_.mutable_cpu()) = driver_info.cpu();
-      pod_info_.set_host_count(pod_info_.host_count() + 1);
-      pod_info_.set_chip_count(pod_info_.chip_count() +
-                               driver_info.chip_count());
-      pod_info_.set_core_count(pod_info_.core_count() +
-                               driver_info.core_count());
     }
+
+    // Process all the unique chips that we have seen.
+    for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
+      for (auto& tpu_core : *tpu_chip.mutable_core()) {
+        int current_core = cumulative_core_id++;
+
+        core_to_driver_.push_back(drivers_[tpu_chip.host_id()].get());
+        core_to_driver_id_.push_back(tpu_chip.host_id());
+        core_to_driver_core_.push_back(tpu_core.id());
+
+        tpu_core.set_id(current_core);
+        tpu_core.set_core_on_host_index(current_core);
+        *(pod_info_.add_local_core()) = tpu_core;
+      }
+
+      // We are setting host_id to zero because we want this to look like one
+      // host with many cores from the perspective of tpu_client.cc.
+      tpu_chip.set_host_id(0);
+    }
+
+    pod_info_.set_chip_count(pod_info_.tpu_chip_size());
+    pod_info_.set_core_count(pod_info_.local_core_size());
+
+    // We want this to look like one host with many TPU chips/cores connected.
+    pod_info_.set_host_count(1);
     pod_info_.set_host_id(0);
   }
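
Note (standalone sketch, not part of the patch): the first hunk de-duplicates
chips by their (x, y, z) pod coordinate, since in a pod slice the same chip can
show up in more than one worker's SystemInfo, so appending every worker's chip
list would count each chip once per worker. The fallback condition
core_count() == local_core_size() preserves the old append-everything behavior
when each worker only sees its own cores (standalone TPUs or tests). The
program below illustrates the same condition in isolation; FakeChip and
FakeSystemInfo are hypothetical stand-ins for the real protos, and std::set
(with count(), which unlike contains() does not require C++20) stands in for
absl::flat_hash_set.

    #include <iostream>
    #include <set>
    #include <tuple>
    #include <vector>

    struct FakeChip {        // stand-in for the TpuChip proto
      int x, y, z;           // chip_coord()
      int host_id;
    };

    struct FakeSystemInfo {  // stand-in for one worker's SystemInfo
      std::vector<FakeChip> tpu_chip;
      int core_count;        // total cores visible to this worker
      int local_core_size;   // cores physically local to this worker
    };

    int main() {
      // Two workers in a pod slice both report chip (0,0,0); worker 1 also
      // reports chip (1,0,0). core_count > local_core_size marks a pod slice.
      std::vector<FakeSystemInfo> workers = {
          {{{0, 0, 0, /*host_id=*/0}}, /*core_count=*/8, /*local_core_size=*/4},
          {{{0, 0, 0, /*host_id=*/1}, {1, 0, 0, /*host_id=*/1}}, 8, 4},
      };

      std::set<std::tuple<int, int, int>> processed_chips;
      std::vector<FakeChip> pod_chips;  // plays the role of pod_info_.tpu_chip()

      for (const auto& info : workers) {
        for (const auto& chip : info.tpu_chip) {
          std::tuple<int, int, int> coord{chip.x, chip.y, chip.z};
          // Same condition as the patch: keep a chip if its coordinate is new,
          // or if this worker only sees local cores (single-TPU / test setup).
          if (processed_chips.count(coord) == 0 ||
              info.core_count == info.local_core_size) {
            pod_chips.push_back(chip);
            processed_chips.insert(coord);
          }
        }
      }

      std::cout << "unique chips: " << pod_chips.size() << "\n";  // prints 2
      return 0;
    }

With the old code the example above would report 3 chips, double-counting
chip (0,0,0); with the patched condition it reports 2.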
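
Note (standalone sketch, not part of the patch): the second hunk renumbers the
cores of every unique chip into one flat global ID space, recording for each
global core ID which driver owns it (core_to_driver_id_) and what the core's
driver-local ID was (core_to_driver_core_), then zeroes host_id so that
tpu_client.cc sees a single host with many cores. A minimal sketch of that
bookkeeping, again with hypothetical stand-in types for the protos:

    #include <iostream>
    #include <vector>

    struct FakeCore {
      int id;       // driver-local core id, rewritten to a global id below
    };

    struct FakeChip {
      int host_id;  // index of the worker/driver that owns the chip
      std::vector<FakeCore> core;
    };

    int main() {
      // Two chips on two hosts, two cores each, both locally numbered 0 and 1.
      std::vector<FakeChip> chips = {{0, {{0}, {1}}}, {1, {{0}, {1}}}};

      std::vector<int> core_to_driver_id;    // global core id -> driver index
      std::vector<int> core_to_driver_core;  // global core id -> local core id

      int cumulative_core_id = 0;
      for (auto& chip : chips) {
        for (auto& core : chip.core) {
          int current_core = cumulative_core_id++;
          // Remember how to route work for this global core id...
          core_to_driver_id.push_back(chip.host_id);
          core_to_driver_core.push_back(core.id);
          // ...then renumber the core into the flat global space.
          core.id = current_core;
        }
        // As in the patch: present everything as a single host.
        chip.host_id = 0;
      }

      for (int i = 0; i < cumulative_core_id; ++i) {
        std::cout << "core " << i << " -> driver " << core_to_driver_id[i]
                  << ", local core " << core_to_driver_core[i] << "\n";
      }
      return 0;
    }

This prints global cores 0..3 mapping to driver 0 cores 0..1 and driver 1
cores 0..1, which is the translation the lookup tables in the patch exist to
perform when work later arrives addressed by global core ID.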