[XLA:GPU] [NFC] Clarify types in AllReduceCombiner

PiperOrigin-RevId: 361234109
Change-Id: I5a89b41f600d6e03c75366fb41e1b5fdbefe9f1a
This commit is contained in:
George Karpenkov 2021-03-05 15:09:45 -08:00 committed by TensorFlower Gardener
parent d605c2c9d3
commit de5bfc8bb6

View File

@ -134,8 +134,8 @@ struct GroupKey {
return replica_groups.size() < other.replica_groups.size();
}
for (int64 i = 0; i < replica_groups.size(); ++i) {
const auto& rg = replica_groups[i];
const auto& org = other.replica_groups[i];
const ReplicaGroup& rg = replica_groups[i];
const ReplicaGroup& org = other.replica_groups[i];
if (rg.replica_ids_size() != org.replica_ids_size()) {
return rg.replica_ids_size() < org.replica_ids_size();
}
@ -281,8 +281,9 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
bool changed = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
for (auto group : groups) {
TF_ASSIGN_OR_RETURN(InstructionGroups groups,
CreateComputationGroups(computation));
for (std::vector<std::vector<HloInstruction*>> group : groups) {
// Recompute reachability after every combine group because we can't
// maintain a cross group topological order to be able to rely on the
// transitive dependencies to detect cycles.
@ -296,7 +297,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
// dependencies.
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
group_map;
for (auto& instruction : group) {
for (std::vector<HloInstruction*>& instruction : group) {
group_map[instruction.front()] = &instruction;
}
@ -326,7 +327,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
// Instruction belongs to a different group.
continue;
}
const auto& instructions = *it->second;
const std::vector<HloInstruction*>& instructions = *it->second;
VLOG(1) << "Considering HLO " << instructions.front()->ToString()
<< " with current set size of " << current_size_in_bytes
@ -360,7 +361,8 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
// instruction that all the other ones depend on (such as one on the
// forward pass of a model) could disable this optimization entirely.
TF_RET_CHECK(!combine_sets.empty());
for (const auto& previous : combine_sets.back()) {
for (const std::vector<HloInstruction*>& previous :
combine_sets.back()) {
// The reachability information does not reflect the planned
// combination from combine_sets. We cannot just bring it up to date
// cheaply since HloReachabilityMap does not track reachability
@ -445,13 +447,14 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
<< " operands";
// Combine the collected sets of AllReduce instructions.
for (const auto& combine_set : combine_sets) {
for (const std::vector<std::vector<HloInstruction*>>& combine_set :
combine_sets) {
if (combine_set.size() >= 2) {
changed = true;
for (int64 i = 0; i < combine_set.front().size(); ++i) {
std::vector<HloInstruction*> to_combine;
to_combine.reserve(combine_set.size());
for (const auto& c : combine_set) {
for (const std::vector<HloInstruction*>& c : combine_set) {
to_combine.push_back(c[i]);
}
TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));