From de5bfc8bb62ee0b73f0d84e9f7f37b83475f4371 Mon Sep 17 00:00:00 2001 From: George Karpenkov <cheshire@google.com> Date: Fri, 5 Mar 2021 15:09:45 -0800 Subject: [PATCH] [XLA:GPU] [NFC] Clarify types in AllReduceCombiner PiperOrigin-RevId: 361234109 Change-Id: I5a89b41f600d6e03c75366fb41e1b5fdbefe9f1a --- .../xla/service/all_reduce_combiner.cc | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner.cc b/tensorflow/compiler/xla/service/all_reduce_combiner.cc index 5fb4935a4b1..01363ced8a1 100644 --- a/tensorflow/compiler/xla/service/all_reduce_combiner.cc +++ b/tensorflow/compiler/xla/service/all_reduce_combiner.cc @@ -134,8 +134,8 @@ struct GroupKey { return replica_groups.size() < other.replica_groups.size(); } for (int64 i = 0; i < replica_groups.size(); ++i) { - const auto& rg = replica_groups[i]; - const auto& org = other.replica_groups[i]; + const ReplicaGroup& rg = replica_groups[i]; + const ReplicaGroup& org = other.replica_groups[i]; if (rg.replica_ids_size() != org.replica_ids_size()) { return rg.replica_ids_size() < org.replica_ids_size(); } @@ -281,8 +281,9 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) { bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation)); - for (auto group : groups) { + TF_ASSIGN_OR_RETURN(InstructionGroups groups, + CreateComputationGroups(computation)); + for (std::vector<std::vector<HloInstruction*>> group : groups) { // Recompute reachability after every combine group because we can't // maintain a cross group topolgical order to be able to rely on the // transitive dependencies to detect cycles. @@ -296,7 +297,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) { // dependencies. 
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*> group_map; - for (auto& instruction : group) { + for (std::vector<HloInstruction*>& instruction : group) { group_map[instruction.front()] = &instruction; } @@ -326,7 +327,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) { // Instruction belongs to a different group. continue; } - const auto& instructions = *it->second; + const std::vector<HloInstruction*>& instructions = *it->second; VLOG(1) << "Considering HLO " << instructions.front()->ToString() << " with current set size of " << current_size_in_bytes @@ -360,7 +361,8 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) { // instruction that all the other ones depend on (such as one on the // forward pass of a model) could disable this optimization entirely. TF_RET_CHECK(!combine_sets.empty()); - for (const auto& previous : combine_sets.back()) { + for (const std::vector<HloInstruction*>& previous : + combine_sets.back()) { // The reachability information does not reflect the planned // combination from combine_sets. We cannot just bring it up to date // cheaply since HloReachabilityMap does not track reachability @@ -445,13 +447,14 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) { << " operands"; // Combine the collected sets of AllReduce instructions. - for (const auto& combine_set : combine_sets) { + for (const std::vector<std::vector<HloInstruction*>>& combine_set : + combine_sets) { if (combine_set.size() >= 2) { changed = true; for (int64 i = 0; i < combine_set.front().size(); ++i) { std::vector<HloInstruction*> to_combine; to_combine.reserve(combine_set.size()); - for (const auto& c : combine_set) { + for (const std::vector<HloInstruction*>& c : combine_set) { to_combine.push_back(c[i]); } TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));