[XLA:GPU] [NFC] Clarify types in AllReduceCombiner
PiperOrigin-RevId: 361234109 Change-Id: I5a89b41f600d6e03c75366fb41e1b5fdbefe9f1a
This commit is contained in:
parent
d605c2c9d3
commit
de5bfc8bb6
@ -134,8 +134,8 @@ struct GroupKey {
|
||||
return replica_groups.size() < other.replica_groups.size();
|
||||
}
|
||||
for (int64 i = 0; i < replica_groups.size(); ++i) {
|
||||
const auto& rg = replica_groups[i];
|
||||
const auto& org = other.replica_groups[i];
|
||||
const ReplicaGroup& rg = replica_groups[i];
|
||||
const ReplicaGroup& org = other.replica_groups[i];
|
||||
if (rg.replica_ids_size() != org.replica_ids_size()) {
|
||||
return rg.replica_ids_size() < org.replica_ids_size();
|
||||
}
|
||||
@ -281,8 +281,9 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
|
||||
|
||||
bool changed = false;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
|
||||
for (auto group : groups) {
|
||||
TF_ASSIGN_OR_RETURN(InstructionGroups groups,
|
||||
CreateComputationGroups(computation));
|
||||
for (std::vector<std::vector<HloInstruction*>> group : groups) {
|
||||
// Recompute reachability after every combine group because we can't
|
||||
// maintain a cross group topological order to be able to rely on the
|
||||
// transitive dependencies to detect cycles.
|
||||
@ -296,7 +297,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
|
||||
// dependencies.
|
||||
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
|
||||
group_map;
|
||||
for (auto& instruction : group) {
|
||||
for (std::vector<HloInstruction*>& instruction : group) {
|
||||
group_map[instruction.front()] = &instruction;
|
||||
}
|
||||
|
||||
@ -326,7 +327,7 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
|
||||
// Instruction belongs to a different group.
|
||||
continue;
|
||||
}
|
||||
const auto& instructions = *it->second;
|
||||
const std::vector<HloInstruction*>& instructions = *it->second;
|
||||
|
||||
VLOG(1) << "Considering HLO " << instructions.front()->ToString()
|
||||
<< " with current set size of " << current_size_in_bytes
|
||||
@ -360,7 +361,8 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
|
||||
// instruction that all the other ones depend on (such as one on the
|
||||
// forward pass of a model) could disable this optimization entirely.
|
||||
TF_RET_CHECK(!combine_sets.empty());
|
||||
for (const auto& previous : combine_sets.back()) {
|
||||
for (const std::vector<HloInstruction*>& previous :
|
||||
combine_sets.back()) {
|
||||
// The reachability information does not reflect the planned
|
||||
// combination from combine_sets. We cannot just bring it up to date
|
||||
// cheaply since HloReachabilityMap does not track reachability
|
||||
@ -445,13 +447,14 @@ StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
|
||||
<< " operands";
|
||||
|
||||
// Combine the collected sets of AllReduce instructions.
|
||||
for (const auto& combine_set : combine_sets) {
|
||||
for (const std::vector<std::vector<HloInstruction*>>& combine_set :
|
||||
combine_sets) {
|
||||
if (combine_set.size() >= 2) {
|
||||
changed = true;
|
||||
for (int64 i = 0; i < combine_set.front().size(); ++i) {
|
||||
std::vector<HloInstruction*> to_combine;
|
||||
to_combine.reserve(combine_set.size());
|
||||
for (const auto& c : combine_set) {
|
||||
for (const std::vector<HloInstruction*>& c : combine_set) {
|
||||
to_combine.push_back(c[i]);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user