Change all_reduce_id to use channel_id

PiperOrigin-RevId: 255315385

commit 852061b75b
parent 6ae9600988
@@ -2068,7 +2068,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
 }

 if (channel_id.has_value()) {
-instr.set_all_reduce_id(channel_id->handle());
+instr.set_channel_id(channel_id->handle());
 }

 AddCalledComputation(computation, &instr);
@@ -319,7 +319,7 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
 auto maybe_pair = MatchesArCrsPattern(instruction);
 if (maybe_pair) {
 auto pair = *maybe_pair;
-int64 ar_id = *(instruction->all_reduce_id());
+int64 ar_id = *(instruction->channel_id());
 if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
 continue;
 }
@@ -365,9 +365,10 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {

 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
 for (auto it : all_reduce_map_) {
-auto all_reduce_id = it.first;
-VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking ar_id: "
-<< all_reduce_id << "\n";
+auto channel_id = it.first;
+VLOG(2)
+<< "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
+<< channel_id << "\n";
 auto pairs_vec = it.second;
 CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
 auto instr_0 = pairs_vec[0].ar;
@@ -378,9 +379,10 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
 absl::flat_hash_map<int64, int64> visited_pairs;
 while (true) {
 if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
-all_reduce_map_.erase(all_reduce_id);
-VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: "
-<< all_reduce_id << "\n";
+all_reduce_map_.erase(channel_id);
+VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
+"channel id: "
+<< channel_id << "\n";
 break;
 }
 if (next_0->IsCrossReplicaAllReduce()) {
@@ -402,7 +404,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
 for (auto pair : pairs_vec) {
 auto all_reduce = pair.ar;
 auto parent_computation = all_reduce->parent();
-auto all_reduce_id = all_reduce->all_reduce_id();
+auto channel_id = all_reduce->channel_id();
 auto prev = all_reduce->mutable_operand(0);
 auto next = all_reduce->users()[0];
 TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
@@ -447,7 +449,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
 next = next->users()[0];
 }
 // The AllReduce and the CRS are combined to an all-core AllReduce.
-next->set_all_reduce_id(all_reduce_id);
+next->set_channel_id(channel_id);
 }
 }
 return true;
@@ -36,7 +36,7 @@ namespace xla {
 //
 // The steps are:
 // 1) Find CMARs followed by simple ops followed by CRARs.
-// 2) Group CMARs by all_reduce_id. They must all be rewritten.
+// 2) Group CMARs by channel_id. They must all be rewritten.
 // 3) Prove that the CMAR patterns in each core produce the same result.
 // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
 // other operand by the number of spatial partitions.
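As an aside on step 2 above, a minimal sketch of what grouping CMARs by channel_id looks like after this rename (the container name and free function below are illustrative assumptions, not code from this change; the real logic lives in ArCrsCombiner::GroupAllReducesById, shown in the hunks above):

// Sketch only: key cross-module all-reduces (CMARs) by their channel_id.
absl::flat_hash_map<int64, std::vector<HloInstruction*>> GroupByChannelId(
    const std::vector<HloInstruction*>& cross_module_all_reduces) {
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> groups;
  for (HloInstruction* instr : cross_module_all_reduces) {
    // A CMAR is an AllReduce whose channel_id is set (formerly all_reduce_id).
    groups[*instr->channel_id()].push_back(instr);
  }
  return groups;
}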
@@ -104,7 +104,7 @@ class ArCrsCombiner : public HloModulePass {
 }
 pieces.push_back(instruction->name());
 pieces.push_back(")[id:");
-pieces.push_back(std::to_string(*(ar->all_reduce_id())));
+pieces.push_back(std::to_string(*(ar->channel_id())));
 pieces.push_back(",dist:");
 pieces.push_back(std::to_string(distance));
 pieces.push_back("]");
@@ -414,7 +414,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = bf16[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=0}
 %convert.1 = f32[]
@@ -429,7 +429,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=1}
 %convert.2 = f32[]
@@ -486,7 +486,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
 %all-reduce.ar.1 = f32[2,1]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.1,
 sharding={maximal device=0}
 %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
@@ -499,7 +499,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
 %all-reduce.ar.2 = f32[2,1]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.1,
 sharding={maximal device=1}
 %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
@@ -549,7 +549,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.f32,
 sharding={maximal device=0}
 %multiply.1 = f32[]
@@ -564,7 +564,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.f32,
 sharding={maximal device=1}
 %multiply.2 = f32[]
@@ -624,7 +624,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=0}
 %convert.1 = f32[]
@@ -642,7 +642,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=1}
 %convert.2 = f32[]
@@ -709,7 +709,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=0}
 %convert.1 = f32[]
@@ -727,7 +727,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=1}
 %convert.2 = f32[]
@@ -772,7 +772,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.1,
 sharding={maximal device=0}
 %all-reduce.1 = f32[]
@@ -787,7 +787,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.1,
 sharding={maximal device=1}
 %all-reduce.2 = f32[]
@@ -840,7 +840,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=0}
 %add.11 = f32[]
@@ -858,7 +858,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=0}
 %add.21 = f32[]
@@ -919,7 +919,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.f32,
 sharding={maximal device=0}
 %sub.1 = f32[]
@@ -934,7 +934,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.f32,
 sharding={maximal device=1}
 %sub.2 = f32[]
@@ -991,7 +991,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar11 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=0}
 %add11 = f32[]
@@ -1000,7 +1000,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar12 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=2,
+channel_id=2,
 to_apply=%sum,
 sharding={maximal device=0}
 %add12 = f32[]
@@ -1015,7 +1015,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar21 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=1}
 %add21 = f32[]
@@ -1024,7 +1024,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar22 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=2,
+channel_id=2,
 to_apply=%sum,
 sharding={maximal device=1}
 %add22 = f32[]
@@ -1083,13 +1083,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar11 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=0}
 %ar12 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=2,
+channel_id=2,
 to_apply=%sum,
 sharding={maximal device=0}
 %add11 = f32[]
@@ -1107,13 +1107,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
 %ar21 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum,
 sharding={maximal device=1}
 %ar22 = f32[]
 all-reduce(%p),
 replica_groups={{0},{1}},
-all_reduce_id=2,
+channel_id=2,
 to_apply=%sum,
 sharding={maximal device=1}
 %add21 = f32[]
@@ -1182,7 +1182,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 %all-reduce.ar.1 = bf16[]
 all-reduce(%p),
 replica_groups={{0}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=0}
 %convert.1 = f32[]
@@ -1197,7 +1197,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 %all-reduce.ar.2 = bf16[]
 all-reduce(%constant.bf16),
 replica_groups={{0}},
-all_reduce_id=1,
+channel_id=1,
 to_apply=%sum.bf16,
 sharding={maximal device=1}
 %convert.2 = f32[]
@@ -1276,9 +1276,9 @@ HloModule foobar

 ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
 %p = bf16[] parameter(0)
-%all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
+%all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
 to_apply=%sum.f32, sharding={maximal device=0}
-%all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
+%all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
 to_apply=%sum.f32, sharding={maximal device=1}
 %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
 to_apply=%sum.f32, sharding={maximal device=0}
@@ -239,7 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
 HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
 ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
 /*replica_groups=*/{},
-/*all_reduce_id=*/absl::nullopt));
+/*channel_id=*/absl::nullopt));
 HloInstruction* gte_a = builder.AddInstruction(
 HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
 HloInstruction* gte_b = builder.AddInstruction(
@@ -257,7 +257,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
 HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
 ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
 /*replica_groups=*/{},
-/*all_reduce_id=*/absl::nullopt));
+/*channel_id=*/absl::nullopt));
 HloInstruction* gte = builder.AddInstruction(
 HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

@@ -211,7 +211,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
 HloInstruction* all_reduce =
 builder.AddInstruction(HloInstruction::CreateAllReduce(
 ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
-/*replica_groups=*/{}, /*all_reduce_id=*/1));
+/*replica_groups=*/{}, /*channel_id=*/1));
 HloInstruction* gte0 = builder.AddInstruction(
 HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
 HloInstruction* gte1 = builder.AddInstruction(
@@ -177,13 +177,13 @@ class NcclComm {
 // * Only ops with the same opcode can communicate with each other. At the
 // moment we only support kAllReduce, so we don't check for this explicitly.
 //
-// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()),
-// only ops with the same value for all_reduce_id() can communicate with each
+// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
+// only ops with the same value for channel_id() can communicate with each
 // other.
 //
 // * For cross-replica (i.e. same-module) all-reduces (i.e.
-// !all_reduce_id().has_value()), only ops from the same module (as identified
-// by its unique_id()) can communicate with each other.
+// !channel_id().has_value()), only ops from the same module (as
+// identified by its unique_id()) can communicate with each other.
 //
 struct RendezvousKey {
 enum AllReduceKind {
@@ -196,8 +196,8 @@ struct RendezvousKey {
 const HloAllReduceInstruction* instr)
 : run_id(run_id), participating_replicas(participating_replicas) {
 std::tie(all_reduce_kind, op_id) =
-instr->all_reduce_id().has_value()
-? std::make_pair(kCrossModule, instr->all_reduce_id().value())
+instr->channel_id().has_value()
+? std::make_pair(kCrossModule, instr->channel_id().value())
 : std::make_pair(
 kCrossReplica,
 static_cast<int64>(instr->GetModule()->unique_id()));
@@ -126,8 +126,9 @@ message HloInstructionProto {
 // Only present for kBatchNormTraining.
 int64 feature_index = 25;

-// Represents a unique identifier for each Send/Recv instruction pair.
-// Only present for kSend or kRecv.
+// Represents a unique identifier for each Send/Recv instruction pair or
+// optionally for collective instructions (AllReduce, CollectivePermute,
+// AllToAll). Non-positive channel_id is equivalent to no channel id.
 int64 channel_id = 26;

 // The string representation of the infeed configuration.
@@ -174,7 +175,9 @@ message HloInstructionProto {

 // Cross replica op fields.
 repeated ReplicaGroup replica_groups = 49;
-int64 all_reduce_id = 45;
+// Deprecated, but keeping it for backward compatibility. Use channel_id.
+// Non-positive all_reduce_id is equivalent to no all_reduce_id.
+int64 all_reduce_id = 45 [deprecated = true];

 // Whether this Send/Recv instruction transfers data to/from the host. Only
 // present for Send and Recv instructions and their SendDone and RecvDone
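For illustration, a minimal sketch of how a reader can honor this deprecation when consuming the proto (the helper name ResolveChannelId is an assumption; the actual handling is in HloInstruction::CreateFromProto, further down in this diff):

// Sketch only: prefer the new channel_id field, fall back to the deprecated
// all_reduce_id, and treat non-positive values as "no channel id".
absl::optional<int64> ResolveChannelId(const HloInstructionProto& proto) {
  if (proto.channel_id() > 0) return proto.channel_id();
  if (proto.all_reduce_id() > 0) return proto.all_reduce_id();  // deprecated field
  return absl::nullopt;
}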
@@ -373,7 +373,7 @@ void HloComputation::ComputeInstructionPostOrder(
 case HloOpcode::kRecvDone:
 return inst->channel_id();
 case HloOpcode::kAllReduce:
-return inst->all_reduce_id();
+return inst->channel_id();
 default:
 return absl::nullopt;
 }
@@ -428,13 +428,10 @@ HloComputation::ComputeChannelDependencies() const {
 switch (instruction->opcode()) {
 case HloOpcode::kSend:
 case HloOpcode::kRecvDone:
-channel_dependency_group[instruction->channel_id()].push_back(
-instruction.get());
-break;
 case HloOpcode::kAllReduce: {
-auto all_reduce_id = instruction->all_reduce_id();
-if (all_reduce_id) {
-channel_dependency_group[all_reduce_id.value()].push_back(
+auto channel_id = instruction->channel_id();
+if (channel_id) {
+channel_dependency_group[channel_id.value()].push_back(
 instruction.get());
 }
 break;
@@ -688,10 +688,10 @@ add {
 ENTRY entry {
 param = f32[128] parameter(0), sharding={maximal device=0}
 crs0 = f32[128] all-reduce(param),
-replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+replica_groups={{0}}, channel_id=1, to_apply=add,
 sharding={maximal device=0}
 crs1 = f32[128] all-reduce(param),
-replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+replica_groups={{0}}, channel_id=1, to_apply=add,
 sharding={maximal device=1}
 add = f32[128] add(crs0, crs0), sharding={maximal device=0}
 ROOT t = (f32[128], f32[128]) tuple(add, crs1)
@@ -383,16 +383,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
 << "AllReduce should have 1 called computation but sees "
 << proto.called_computation_ids_size();
-absl::optional<int64> all_reduce_id;
+TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
+<< "AllReduce cannot have both channel_id() and all_reduce_id()";
+absl::optional<int64> channel_id;
+if (proto.channel_id() > 0) {
+channel_id = proto.channel_id();
+}
 if (proto.all_reduce_id() > 0) {
-all_reduce_id = proto.all_reduce_id();
+channel_id = proto.all_reduce_id();
 }
 instruction = CreateAllReduce(
 shape, all_operands(), computations(0),
 /*replica_groups=*/
 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
 proto.replica_groups().end()),
-/*all_reduce_id=*/all_reduce_id);
+/*channel_id=*/channel_id);
 break;
 }
 case HloOpcode::kAllToAll: {
@@ -860,9 +865,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 const Shape& shape, absl::Span<HloInstruction* const> operands,
 HloComputation* reduce_computation,
 const std::vector<ReplicaGroup>& replica_groups,
-const absl::optional<int64>& all_reduce_id) {
+const absl::optional<int64>& channel_id) {
 return absl::make_unique<HloAllReduceInstruction>(
-shape, operands, reduce_computation, replica_groups, all_reduce_id);
+shape, operands, reduce_computation, replica_groups, channel_id);
 }

 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@@ -1279,7 +1284,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
 case HloOpcode::kTrace:
 return true;
 case HloOpcode::kAllReduce:
-return all_reduce_id().has_value();
+return channel_id().has_value();
 case HloOpcode::kCustomCall:
 return Cast<HloCustomCallInstruction>(this)
 ->custom_call_has_side_effect();
@@ -2232,11 +2237,11 @@ bool HloInstruction::IsElementwiseImpl(
 }

 bool HloInstruction::IsCrossModuleAllReduce() const {
-return opcode() == HloOpcode::kAllReduce && all_reduce_id();
+return opcode() == HloOpcode::kAllReduce && channel_id();
 }

 bool HloInstruction::IsCrossReplicaAllReduce() const {
-return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
+return opcode() == HloOpcode::kAllReduce && !channel_id();
 }

 string HloInstruction::ToStringWithCanonicalNameMap(
@@ -3332,10 +3337,6 @@ const std::vector<int64>& HloInstruction::fft_length() const {
 return Cast<HloFftInstruction>(this)->fft_length();
 }

-int64 HloInstruction::channel_id() const {
-return Cast<HloSendRecvInstruction>(this)->channel_id();
-}
-
 int64 HloInstruction::concatenate_dimension() const {
 return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
 }
@@ -3535,13 +3536,12 @@ HloInstruction::source_target_pairs() const {
 return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
 }

-absl::optional<int64> HloInstruction::all_reduce_id() const {
-return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
+absl::optional<int64> HloInstruction::channel_id() const {
+return Cast<HloChannelInstruction>(this)->channel_id();
 }

-void HloInstruction::set_all_reduce_id(
-const absl::optional<int64>& all_reduce_id) {
-return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
+void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
+return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
 }

 const ConvolutionDimensionNumbers&
@@ -497,14 +497,14 @@ class HloInstruction {
 // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
 // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
 //
-// `all_reduce_id`: for Allreduce nodes from different modules, if they have
-// the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
-// not be applied cross modules.
+// `channel_id`: for Allreduce nodes from different modules, if
+// they have the same channel_id, they will be 'Allreduce'd. If
+// empty, Allreduce will not be applied cross modules.
 static std::unique_ptr<HloInstruction> CreateAllReduce(
 const Shape& shape, absl::Span<HloInstruction* const> operands,
 HloComputation* reduce_computation,
 const std::vector<ReplicaGroup>& replica_groups,
-const absl::optional<int64>& all_reduce_id);
+const absl::optional<int64>& channel_id);

 // An all-to-all op takes N array operands of the same shape and scatters them
 // to N replicas. Each replica gathers the results into a tuple.
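A usage sketch of the renamed parameter, modeled on the test updates elsewhere in this change (the builder, shapes, operands, and reduction computation names are assumptions):

// Sketch only: a positive channel_id makes the AllReduce cross-module;
// absl::nullopt keeps it a purely cross-replica AllReduce.
HloInstruction* ar = builder.AddInstruction(HloInstruction::CreateAllReduce(
    f32_shape, {operand_a, operand_b}, sum_computation,
    /*replica_groups=*/{},
    /*channel_id=*/1));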
@ -952,7 +952,7 @@ class HloInstruction {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Two AllReduces are Identical if they have the same all_reduce_id.
|
// Two AllReduces are Identical if they have the same channel_id.
|
||||||
// Their operands don't have to be Identical.
|
// Their operands don't have to be Identical.
|
||||||
if (!IsCrossModuleAllReduce()) {
|
if (!IsCrossModuleAllReduce()) {
|
||||||
// Use an explicit loop rather than ContainerEquals, because copying
|
// Use an explicit loop rather than ContainerEquals, because copying
|
||||||
@ -1428,8 +1428,9 @@ class HloInstruction {
|
|||||||
// Delegates to HloFftInstruction::fft_length.
|
// Delegates to HloFftInstruction::fft_length.
|
||||||
const std::vector<int64>& fft_length() const;
|
const std::vector<int64>& fft_length() const;
|
||||||
|
|
||||||
// Delegates to HloSendRecvInstruction::channel_id.
|
// Delegates to HloChannelInstruction::channel_id.
|
||||||
int64 channel_id() const;
|
absl::optional<int64> channel_id() const;
|
||||||
|
void set_channel_id(const absl::optional<int64>& channel_id);
|
||||||
|
|
||||||
// Returns the dimension sizes or numbers associated with this instruction.
|
// Returns the dimension sizes or numbers associated with this instruction.
|
||||||
virtual const std::vector<int64>& dimensions() const {
|
virtual const std::vector<int64>& dimensions() const {
|
||||||
@ -1571,10 +1572,6 @@ class HloInstruction {
|
|||||||
// Delegates to HloCollectivePermuteInstruction::source_target_pairs.
|
// Delegates to HloCollectivePermuteInstruction::source_target_pairs.
|
||||||
const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
|
const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
|
||||||
|
|
||||||
// Delegates to HloAllReduceInstruction::all_reduce_id.
|
|
||||||
absl::optional<int64> all_reduce_id() const;
|
|
||||||
void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
|
|
||||||
|
|
||||||
// Returns data on the window in a windowed operation such as
|
// Returns data on the window in a windowed operation such as
|
||||||
// convolution.
|
// convolution.
|
||||||
virtual const Window& window() const {
|
virtual const Window& window() const {
|
||||||
|
@ -361,25 +361,60 @@ HloCholeskyInstruction::CloneWithNewOperandsImpl(
|
|||||||
cholesky_options());
|
cholesky_options());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloChannelInstruction::HloChannelInstruction(
|
||||||
|
HloOpcode opcode, const Shape& shape,
|
||||||
|
const absl::optional<int64>& channel_id)
|
||||||
|
: HloInstruction(opcode, shape), channel_id_(channel_id) {}
|
||||||
|
|
||||||
|
void HloChannelInstruction::set_channel_id(
|
||||||
|
const absl::optional<int64>& channel_id) {
|
||||||
|
channel_id_ = channel_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
HloInstructionProto HloChannelInstruction::ToProto() const {
|
||||||
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
|
if (channel_id_) {
|
||||||
|
CHECK_GT(channel_id_.value(), 0)
|
||||||
|
<< "Non-positive channel id is equivalent to no channel id";
|
||||||
|
proto.set_channel_id(*channel_id_);
|
||||||
|
}
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
|
||||||
|
const HloPrintOptions& /*options*/) const {
|
||||||
|
std::vector<string> result;
|
||||||
|
if (channel_id_) {
|
||||||
|
result.push_back(StrCat("channel_id=", *channel_id_));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloChannelInstruction::IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
/*eq_computations*/) const {
|
||||||
|
const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
|
||||||
|
return channel_id() == casted_other.channel_id();
|
||||||
|
}
|
||||||
|
|
||||||
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
|
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
int64 channel_id,
|
int64 channel_id,
|
||||||
bool is_host_transfer)
|
bool is_host_transfer)
|
||||||
: HloInstruction(opcode, shape),
|
: HloChannelInstruction(opcode, shape, channel_id),
|
||||||
channel_id_(channel_id),
|
|
||||||
is_host_transfer_(is_host_transfer) {}
|
is_host_transfer_(is_host_transfer) {}
|
||||||
|
|
||||||
HloInstructionProto HloSendRecvInstruction::ToProto() const {
|
HloInstructionProto HloSendRecvInstruction::ToProto() const {
|
||||||
HloInstructionProto proto = HloInstruction::ToProto();
|
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||||
proto.set_channel_id(channel_id_);
|
|
||||||
proto.set_is_host_transfer(is_host_transfer_);
|
proto.set_is_host_transfer(is_host_transfer_);
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
|
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& options) const {
|
const HloPrintOptions& options) const {
|
||||||
std::vector<string> attrs;
|
std::vector<string> attrs =
|
||||||
attrs.push_back(StrCat("channel_id=", channel_id_));
|
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||||
if (is_host_transfer()) {
|
if (is_host_transfer()) {
|
||||||
attrs.push_back("is_host_transfer=true");
|
attrs.push_back("is_host_transfer=true");
|
||||||
}
|
}
|
||||||
@ -413,13 +448,13 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloCloneContext* context) const {
|
HloCloneContext* context) const {
|
||||||
CHECK_EQ(new_operands.size(), 2);
|
CHECK_EQ(new_operands.size(), 2);
|
||||||
return absl::make_unique<HloSendInstruction>(
|
return absl::make_unique<HloSendInstruction>(
|
||||||
new_operands[0], new_operands[1], channel_id(), is_host_transfer());
|
new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
|
||||||
}
|
}
|
||||||
|
|
||||||
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
|
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
|
||||||
bool is_host_transfer)
|
bool is_host_transfer)
|
||||||
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
|
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
|
||||||
CHECK_NOTNULL(operand)->channel_id(),
|
CHECK_NOTNULL(operand)->channel_id().value(),
|
||||||
is_host_transfer) {
|
is_host_transfer) {
|
||||||
AppendOperand(operand);
|
AppendOperand(operand);
|
||||||
}
|
}
|
||||||
@ -450,7 +485,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloCloneContext* context) const {
|
HloCloneContext* context) const {
|
||||||
CHECK_EQ(new_operands.size(), 1);
|
CHECK_EQ(new_operands.size(), 1);
|
||||||
return absl::make_unique<HloRecvInstruction>(
|
return absl::make_unique<HloRecvInstruction>(
|
||||||
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
|
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
|
||||||
is_host_transfer());
|
is_host_transfer());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -461,7 +496,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
|
|||||||
ShapeUtil::MakeTupleShape(
|
ShapeUtil::MakeTupleShape(
|
||||||
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
|
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
|
||||||
ShapeUtil::MakeTokenShape()}),
|
ShapeUtil::MakeTokenShape()}),
|
||||||
CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
|
CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
|
||||||
AppendOperand(operand);
|
AppendOperand(operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -477,32 +512,39 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloCollectiveInstruction::HloCollectiveInstruction(
|
HloCollectiveInstruction::HloCollectiveInstruction(
|
||||||
HloOpcode opcode, const Shape& shape,
|
HloOpcode opcode, const Shape& shape,
|
||||||
absl::Span<HloInstruction* const> operands,
|
absl::Span<HloInstruction* const> operands,
|
||||||
const std::vector<ReplicaGroup>& replica_groups)
|
const std::vector<ReplicaGroup>& replica_groups,
|
||||||
: HloInstruction(opcode, shape), replica_groups_(replica_groups) {
|
const absl::optional<int64>& channel_id)
|
||||||
|
: HloChannelInstruction(opcode, shape, channel_id),
|
||||||
|
replica_groups_({replica_groups.begin(), replica_groups.end()}) {
|
||||||
for (auto operand : operands) {
|
for (auto operand : operands) {
|
||||||
AppendOperand(operand);
|
AppendOperand(operand);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HloInstructionProto HloCollectiveInstruction::ToProto() const {
|
HloInstructionProto HloCollectiveInstruction::ToProto() const {
|
||||||
HloInstructionProto proto = HloInstruction::ToProto();
|
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||||
*proto.mutable_replica_groups() = {replica_groups_.begin(),
|
*proto.mutable_replica_groups() = {replica_groups_.begin(),
|
||||||
replica_groups_.end()};
|
replica_groups_.end()};
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
|
std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& /*options*/) const {
|
const HloPrintOptions& options) const {
|
||||||
return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))};
|
std::vector<string> result =
|
||||||
|
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||||
|
result.push_back(
|
||||||
|
StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloCollectiveInstruction::IdenticalSlowPath(
|
bool HloCollectiveInstruction::IdenticalSlowPath(
|
||||||
const HloInstruction& other,
|
const HloInstruction& other,
|
||||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
/*eq_computations*/) const {
|
eq_computations) const {
|
||||||
const auto& casted_other =
|
const auto& casted_other =
|
||||||
static_cast<const HloCollectiveInstruction&>(other);
|
static_cast<const HloCollectiveInstruction&>(other);
|
||||||
return absl::c_equal(replica_groups(), casted_other.replica_groups(),
|
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||||
|
absl::c_equal(replica_groups(), casted_other.replica_groups(),
|
||||||
[](const ReplicaGroup& a, const ReplicaGroup& b) {
|
[](const ReplicaGroup& a, const ReplicaGroup& b) {
|
||||||
return absl::c_equal(a.replica_ids(), b.replica_ids());
|
return absl::c_equal(a.replica_ids(), b.replica_ids());
|
||||||
});
|
});
|
||||||
@ -512,44 +554,19 @@ HloAllReduceInstruction::HloAllReduceInstruction(
|
|||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
HloComputation* reduce_computation,
|
HloComputation* reduce_computation,
|
||||||
const std::vector<ReplicaGroup>& replica_groups,
|
const std::vector<ReplicaGroup>& replica_groups,
|
||||||
const absl::optional<int64>& all_reduce_id)
|
const absl::optional<int64>& channel_id)
|
||||||
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
|
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
|
||||||
replica_groups),
|
replica_groups, channel_id) {
|
||||||
all_reduce_id_(all_reduce_id) {
|
|
||||||
AppendComputation(reduce_computation);
|
AppendComputation(reduce_computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloAllReduceInstruction::set_all_reduce_id(
|
|
||||||
const absl::optional<int64>& all_reduce_id) {
|
|
||||||
all_reduce_id_ = all_reduce_id;
|
|
||||||
}
|
|
||||||
|
|
||||||
HloInstructionProto HloAllReduceInstruction::ToProto() const {
|
|
||||||
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
|
|
||||||
// Proto3 is so sad.
|
|
||||||
if (all_reduce_id_) {
|
|
||||||
proto.set_all_reduce_id(*all_reduce_id_);
|
|
||||||
}
|
|
||||||
return proto;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HloAllReduceInstruction::IsNoop() const {
|
bool HloAllReduceInstruction::IsNoop() const {
|
||||||
for (auto replica_group : replica_groups()) {
|
for (auto replica_group : replica_groups()) {
|
||||||
if (replica_group.replica_ids().size() != 1) {
|
if (replica_group.replica_ids().size() != 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return !all_reduce_id();
|
return !channel_id();
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
|
|
||||||
const HloPrintOptions& options) const {
|
|
||||||
std::vector<string> result =
|
|
||||||
HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
|
|
||||||
if (all_reduce_id_) {
|
|
||||||
result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloAllReduceInstruction::IdenticalSlowPath(
|
bool HloAllReduceInstruction::IdenticalSlowPath(
|
||||||
@ -558,8 +575,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
|
|||||||
eq_computations) const {
|
eq_computations) const {
|
||||||
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
|
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
|
||||||
return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
|
return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||||
eq_computations(to_apply(), casted_other.to_apply()) &&
|
eq_computations(to_apply(), casted_other.to_apply());
|
||||||
all_reduce_id() == casted_other.all_reduce_id();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<HloInstruction>
|
std::unique_ptr<HloInstruction>
|
||||||
@ -567,14 +583,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
|
|||||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||||
HloCloneContext* /*context*/) const {
|
HloCloneContext* /*context*/) const {
|
||||||
return absl::make_unique<HloAllReduceInstruction>(
|
return absl::make_unique<HloAllReduceInstruction>(
|
||||||
shape, new_operands, to_apply(), replica_groups(), all_reduce_id());
|
shape, new_operands, to_apply(), replica_groups(), channel_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
HloAllToAllInstruction::HloAllToAllInstruction(
|
HloAllToAllInstruction::HloAllToAllInstruction(
|
||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
const std::vector<ReplicaGroup>& replica_groups)
|
const std::vector<ReplicaGroup>& replica_groups)
|
||||||
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
|
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
|
||||||
replica_groups) {}
|
replica_groups, absl::nullopt) {}
|
||||||
|
|
||||||
std::unique_ptr<HloInstruction>
|
std::unique_ptr<HloInstruction>
|
||||||
HloAllToAllInstruction::CloneWithNewOperandsImpl(
|
HloAllToAllInstruction::CloneWithNewOperandsImpl(
|
||||||
@ -587,13 +603,14 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
|
|||||||
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
|
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
|
||||||
const Shape& shape, HloInstruction* operand,
|
const Shape& shape, HloInstruction* operand,
|
||||||
const std::vector<std::pair<int64, int64>>& source_target_pairs)
|
const std::vector<std::pair<int64, int64>>& source_target_pairs)
|
||||||
: HloInstruction(HloOpcode::kCollectivePermute, shape),
|
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape,
|
||||||
|
absl::nullopt),
|
||||||
source_target_pairs_(source_target_pairs) {
|
source_target_pairs_(source_target_pairs) {
|
||||||
AppendOperand(operand);
|
AppendOperand(operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
|
HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
|
||||||
HloInstructionProto proto = HloInstruction::ToProto();
|
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||||
for (const auto& pair : source_target_pairs()) {
|
for (const auto& pair : source_target_pairs()) {
|
||||||
auto* proto_pair = proto.add_source_target_pairs();
|
auto* proto_pair = proto.add_source_target_pairs();
|
||||||
proto_pair->set_source(pair.first);
|
proto_pair->set_source(pair.first);
|
||||||
@ -604,8 +621,9 @@ HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
|
|||||||
|
|
||||||
std::vector<string>
|
std::vector<string>
|
||||||
HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
|
HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& /*options*/) const {
|
const HloPrintOptions& options) const {
|
||||||
std::vector<string> result;
|
std::vector<string> result =
|
||||||
|
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||||
std::vector<string> strs;
|
std::vector<string> strs;
|
||||||
for (const auto& pair : source_target_pairs()) {
|
for (const auto& pair : source_target_pairs()) {
|
||||||
strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
|
strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
|
||||||
@ -617,10 +635,11 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
|
|||||||
bool HloCollectivePermuteInstruction::IdenticalSlowPath(
|
bool HloCollectivePermuteInstruction::IdenticalSlowPath(
|
||||||
const HloInstruction& other,
|
const HloInstruction& other,
|
||||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
/*eq_computations*/) const {
|
eq_computations) const {
|
||||||
const auto& casted_other =
|
const auto& casted_other =
|
||||||
static_cast<const HloCollectivePermuteInstruction&>(other);
|
static_cast<const HloCollectivePermuteInstruction&>(other);
|
||||||
return absl::c_equal(source_target_pairs(),
|
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||||
|
absl::c_equal(source_target_pairs(),
|
||||||
casted_other.source_target_pairs(),
|
casted_other.source_target_pairs(),
|
||||||
[](const std::pair<int64, int64>& a,
|
[](const std::pair<int64, int64>& a,
|
||||||
const std::pair<int64, int64>& b) { return a == b; });
|
const std::pair<int64, int64>& b) { return a == b; });
|
||||||
|
@ -206,13 +206,37 @@ class HloCholeskyInstruction : public HloInstruction {
|
|||||||
CholeskyOptions cholesky_options_;
|
CholeskyOptions cholesky_options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class HloSendRecvInstruction : public HloInstruction {
|
// Class that represents instructions that synchronize and transfer data between
|
||||||
|
// partitioned devices. Send/Recv and collective instructions (AllReduce,
|
||||||
|
// AllToAll, CollectivePermute) belong to this instruction type. A group of
|
||||||
|
// instructions (of the same opcode) with the same channel_id communicate during
|
||||||
|
// execution.
|
||||||
|
class HloChannelInstruction : public HloInstruction {
|
||||||
public:
|
public:
|
||||||
// Returns the channel id associated with the instruction. The id is
|
// Returns the channel id associated with the instruction. The id is
|
||||||
// shared between each Send/Recv pair and is globally unique to identify each
|
// shared between each Send/Recv pair or a group of collective instructions
|
||||||
// channel.
|
// and is globally unique to identify each channel.
|
||||||
int64 channel_id() const { return channel_id_; }
|
absl::optional<int64> channel_id() const { return channel_id_; }
|
||||||
|
void set_channel_id(const absl::optional<int64>& channel_id);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
|
||||||
|
const absl::optional<int64>& channel_id);
|
||||||
|
|
||||||
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
|
std::vector<string> ExtraAttributesToStringImpl(
|
||||||
|
const HloPrintOptions& options) const override;
|
||||||
|
bool IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const override;
|
||||||
|
|
||||||
|
absl::optional<int64> channel_id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloSendRecvInstruction : public HloChannelInstruction {
|
||||||
|
public:
|
||||||
// Returns whether this send/recv instruction sends data to/from the host.
|
// Returns whether this send/recv instruction sends data to/from the host.
|
||||||
bool is_host_transfer() const { return is_host_transfer_; }
|
bool is_host_transfer() const { return is_host_transfer_; }
|
||||||
|
|
||||||
@ -230,9 +254,6 @@ class HloSendRecvInstruction : public HloInstruction {
const HloInstruction& other,
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
eq_computations) const override;
// Represents a unique identifier for each Send/Recv instruction pair.
int64 channel_id_;

// Whether this send/recv instruction sends data to/from the host.
// Whether this send/recv instruction sends data to/from the host.
bool is_host_transfer_;
bool is_host_transfer_;
};
};
@ -285,7 +306,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
HloCloneContext* context) const override;
HloCloneContext* context) const override;
};
};

class HloCollectiveInstruction : public HloInstruction {
class HloCollectiveInstruction : public HloChannelInstruction {
public:
public:
const std::vector<ReplicaGroup>& replica_groups() const {
const std::vector<ReplicaGroup>& replica_groups() const {
return replica_groups_;
return replica_groups_;
@ -295,7 +316,8 @@ class HloCollectiveInstruction : public HloInstruction {
explicit HloCollectiveInstruction(
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& channel_id);

HloInstructionProto ToProto() const override;
HloInstructionProto ToProto() const override;

@ -315,21 +337,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& all_reduce_id);
const absl::optional<int64>& channel_id);

absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);

// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;

// Returns true if the AllReduce does no communication, so it's equivalent
// Returns true if the AllReduce does no communication, so it's equivalent
// to a mem copy.
// to a mem copy.
bool IsNoop() const;
bool IsNoop() const;

private:
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
bool IdenticalSlowPath(
const HloInstruction& other,
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
const std::function<bool(const HloComputation*, const HloComputation*)>&
@ -339,11 +353,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
HloCloneContext* context) const override;

// For Allreduce nodes from different modules, if they have the same
// all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross modules.
absl::optional<int64> all_reduce_id_;
};
};

class HloAllToAllInstruction : public HloCollectiveInstruction {
class HloAllToAllInstruction : public HloCollectiveInstruction {
@ -359,7 +368,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
HloCloneContext* context) const override;
HloCloneContext* context) const override;
};
};

class HloCollectivePermuteInstruction : public HloInstruction {
class HloCollectivePermuteInstruction : public HloChannelInstruction {
public:
public:
explicit HloCollectivePermuteInstruction(
explicit HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand,
const Shape& shape, HloInstruction* operand,
@ -155,11 +155,11 @@ ENTRY entry {
%p = f32[1000, 1000] parameter(0)
%p = f32[1000, 1000] parameter(0)
%token.0 = token[] after-all()
%token.0 = token[] after-all()
%send = (f32[1000, 1000], token[]) send(%p, %token.0),
%send = (f32[1000, 1000], token[]) send(%p, %token.0),
channel_id=0, is_host_transfer=true
channel_id=1, is_host_transfer=true
%n1 = f32[1000, 1000] negate(%p)
%n1 = f32[1000, 1000] negate(%p)
%n2 = f32[1000, 1000] negate(%n1)
%n2 = f32[1000, 1000] negate(%n1)
%n3 = f32[1000, 1000] negate(%n2)
%n3 = f32[1000, 1000] negate(%n2)
%send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true
%send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
}
}
)";
)";

@ -83,7 +83,7 @@ Status HloModuleGroupMetadata::Build() {
if (IsChannelInstruction(hlo)) {
if (IsChannelInstruction(hlo)) {
peers.push_back(PeerComputation(hlo));
peers.push_back(PeerComputation(hlo));
} else if (hlo->IsCrossModuleAllReduce()) {
} else if (hlo->IsCrossModuleAllReduce()) {
for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) {
if (instr == hlo) {
if (instr == hlo) {
continue;
continue;
}
}
@ -235,7 +235,7 @@ bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
HloComputation* HloModuleGroupMetadata::PeerComputation(
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
CHECK(IsChannelInstruction(instruction));
const Channel& channel = GetChannel(instruction->channel_id());
const Channel& channel = GetChannel(*instruction->channel_id());
switch (instruction->opcode()) {
switch (instruction->opcode()) {
case HloOpcode::kSend:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSendDone:
@ -249,8 +249,8 @@ HloComputation* HloModuleGroupMetadata::PeerComputation(
}
}

const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
int64 all_reduce_id) const {
int64 channel_id) const {
auto it = all_reduce_map_.find(all_reduce_id);
auto it = all_reduce_map_.find(channel_id);
CHECK(it != all_reduce_map_.end());
CHECK(it != all_reduce_map_.end());
return it->second;
return it->second;
}
}
@ -330,14 +330,14 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TrackedInstruction(hlo, ComputationKind::kCallFunction);
TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
}

// Group cross module all-reduce instructions by the all_reduce id.
// Group cross module all-reduce instructions by the channel id.
if (hlo->IsCrossModuleAllReduce()) {
if (hlo->IsCrossModuleAllReduce()) {
TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) ==
channel_id_map_.end())
channel_id_map_.end())
<< "all_reduce_id " << *hlo->all_reduce_id()
<< "channel_id " << *hlo->channel_id()
<< " is already used by a send/recv instruction";
<< " is already used by a send/recv instruction";
all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
all_reduce_map_[*hlo->channel_id()].push_back(hlo);
max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
return Status::OK();
return Status::OK();
}
}

@ -345,41 +345,41 @@ Status HloModuleGroupMetadata::RecordInstructions() {
return Status::OK();
return Status::OK();
}
}

TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) ==
all_reduce_map_.end())
all_reduce_map_.end())
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is already used by an all-reduce instruction";
<< " is already used by an all-reduce instruction";

// Add a new channel if needed.
// Add a new channel if needed.
if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) {
channels_.emplace_back();
channels_.emplace_back();
channels_.back().id = hlo->channel_id();
channels_.back().id = *hlo->channel_id();
channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
channel_id_map_[*hlo->channel_id()] = channels_.size() - 1;
max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
}
}
Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]];

if (hlo->opcode() == HloOpcode::kSend) {
if (hlo->opcode() == HloOpcode::kSend) {
TF_RET_CHECK(channel.send == nullptr)
TF_RET_CHECK(channel.send == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple send instructions";
<< " is used by multiple send instructions";
channel.send = hlo;
channel.send = hlo;
}
}
if (hlo->opcode() == HloOpcode::kRecv) {
if (hlo->opcode() == HloOpcode::kRecv) {
TF_RET_CHECK(channel.recv == nullptr)
TF_RET_CHECK(channel.recv == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple recv instructions";
<< " is used by multiple recv instructions";
channel.recv = hlo;
channel.recv = hlo;
}
}
if (hlo->opcode() == HloOpcode::kSendDone) {
if (hlo->opcode() == HloOpcode::kSendDone) {
TF_RET_CHECK(channel.send_done == nullptr)
TF_RET_CHECK(channel.send_done == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple send-done instructions";
<< " is used by multiple send-done instructions";
channel.send_done = hlo;
channel.send_done = hlo;
}
}
if (hlo->opcode() == HloOpcode::kRecvDone) {
if (hlo->opcode() == HloOpcode::kRecvDone) {
TF_RET_CHECK(channel.recv_done == nullptr)
TF_RET_CHECK(channel.recv_done == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple recv-done instructions";
<< " is used by multiple recv-done instructions";
channel.recv_done = hlo;
channel.recv_done = hlo;
}
}
@ -137,9 +137,8 @@ class HloModuleGroupMetadata {
// Returns if the given channel id exists in metadata.
// Returns if the given channel id exists in metadata.
bool HasChannel(int64 channel_id) const;
bool HasChannel(int64 channel_id) const;

// Returns the all-reduce instructions with the same all_reduce_id.
// Returns the all-reduce instructions with the same channel_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
const std::vector<HloInstruction*>& GetAllReduceGroup(int64 channel_id) const;
int64 all_reduce_id) const;

// Returns the computation that contains the peer channel instructions for
// Returns the computation that contains the peer channel instructions for
// the given instruction.
// the given instruction.
@ -205,7 +204,7 @@ class HloModuleGroupMetadata {
// Returns all channels in the module group.
// Returns all channels in the module group.
const std::vector<Channel>& channels() const { return channels_; }
const std::vector<Channel>& channels() const { return channels_; }

// Returns the maximum channel id or all_reduce_id used in the module group.
// Returns the maximum channel id used in the module group.
int64 max_channel_id() const { return max_channel_id_; }
int64 max_channel_id() const { return max_channel_id_; }

HloAliasAnalysis* alias_analysis(HloModule* module) const {
HloAliasAnalysis* alias_analysis(HloModule* module) const {
@ -62,7 +62,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
}
}
if (predecessor->IsCrossModuleAllReduce()) {
if (predecessor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
metadata_.GetAllReduceGroup(*predecessor->channel_id())) {
if (unique.insert(instr).second) {
if (unique.insert(instr).second) {
predecessors.push_back(instr);
predecessors.push_back(instr);
}
}
@ -82,8 +82,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
instruction_group.push_back(companion);
instruction_group.push_back(companion);
}
}
} else if (instruction->IsCrossModuleAllReduce()) {
} else if (instruction->IsCrossModuleAllReduce()) {
instruction_group =
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else {
} else {
instruction_group.push_back(instruction);
instruction_group.push_back(instruction);
}
}
@ -99,14 +98,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
if (instruction->opcode() == HloOpcode::kRecvDone &&
if (instruction->opcode() == HloOpcode::kRecvDone &&
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_predecessor(send);
add_unique_predecessor(send);
}
}
if (instruction->opcode() == HloOpcode::kSend &&
if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
metadata_.GetChannel(*instruction->channel_id()).recv_done;
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
CHECK_EQ(recv_done->operand_count(), 1);
CHECK_EQ(recv_done->operand_count(), 1);
HloInstruction* recv = recv_done->mutable_operand(0);
HloInstruction* recv = recv_done->mutable_operand(0);
@ -139,7 +139,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
}
}
if (successor->IsCrossModuleAllReduce()) {
if (successor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
metadata_.GetAllReduceGroup(*successor->channel_id())) {
if (unique.insert(instr).second) {
if (unique.insert(instr).second) {
successors.push_back(instr);
successors.push_back(instr);
}
}
@ -160,8 +160,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
instruction_group.push_back(companion);
instruction_group.push_back(companion);
}
}
} else if (instruction->IsCrossModuleAllReduce()) {
} else if (instruction->IsCrossModuleAllReduce()) {
instruction_group =
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else {
} else {
instruction_group.push_back(instruction);
instruction_group.push_back(instruction);
}
}
@ -179,14 +178,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
// Send is a remote successor of Recv.
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_successor(send);
add_unique_successor(send);
}
}
if (instruction->opcode() == HloOpcode::kSend &&
if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
metadata_.GetChannel(*instruction->channel_id()).recv_done;
add_unique_successor(recv_done);
add_unique_successor(recv_done);
}
}
return successors;
return successors;
@ -256,7 +256,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
instruction_group.push_back(companion);
instruction_group.push_back(companion);
}
}
} else if (hlo->IsCrossModuleAllReduce()) {
} else if (hlo->IsCrossModuleAllReduce()) {
instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
instruction_group = metadata_.GetAllReduceGroup(*hlo->channel_id());
} else {
} else {
instruction_group.push_back(hlo);
instruction_group.push_back(hlo);
}
}
@ -836,13 +836,12 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<std::vector<int64>>> tmp_groups;
optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<std::vector<int64>> replica_group_ids;
optional<int64> all_reduce_id;
optional<int64> channel_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
&all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
return false;
}
}
@ -851,7 +850,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
replica_groups = CreateReplicaGroups(*tmp_groups);
replica_groups = CreateReplicaGroups(*tmp_groups);
}
}
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, all_reduce_id));
shape, operands, *to_apply, replica_groups, channel_id));
break;
break;
}
}
case HloOpcode::kAllToAll: {
case HloOpcode::kAllToAll: {
@ -1416,8 +1416,8 @@ add {

ENTRY CRS {
ENTRY CRS {
input = f32[8]{0} parameter(0)
input = f32[8]{0} parameter(0)
crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
}
}

)"
)"
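The parser-test update above shows the new HLO text spelling; built programmatically, the same instruction would pass the channel id through the CreateAllReduce call updated in the parser hunk earlier. A rough sketch under the signatures shown in this diff (the surrounding names are placeholders, not taken from this change):

// Sketch only, not part of this commit: constructing a cross-module all-reduce
// with channel_id=1 instead of all_reduce_id=1. "computation", "operand" and
// "add_computation" are placeholder names; the test above uses replica_groups={{0}}.
std::vector<ReplicaGroup> replica_groups;
absl::optional<int64> channel_id = 1;  // a present id marks the op cross-module
HloInstruction* crs = computation->AddInstruction(HloInstruction::CreateAllReduce(
    operand->shape(), {operand}, add_computation, replica_groups, channel_id));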
@ -85,8 +85,8 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
std::vector<HloInstruction*> inputs;
std::vector<HloInstruction*> inputs;
const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
inputs.push_back(input);
inputs.push_back(input);
if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) {
if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) {
auto it = channel_group.find(*input->all_reduce_id());
auto it = channel_group.find(*input->channel_id());
if (it != channel_group.end()) {
if (it != channel_group.end()) {
inputs.insert(inputs.end(), it->second.begin(), it->second.end());
inputs.insert(inputs.end(), it->second.begin(), it->second.end());
}
}
@ -106,7 +106,7 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(

switch (hlo->opcode()) {
switch (hlo->opcode()) {
case HloOpcode::kRecvDone: {
case HloOpcode::kRecvDone: {
auto it = channel_group.find(hlo->channel_id());
auto it = channel_group.find(*hlo->channel_id());
if (it != channel_group.end()) {
if (it != channel_group.end()) {
for (HloInstruction* channel : it->second) {
for (HloInstruction* channel : it->second) {
if (channel->opcode() == HloOpcode::kSend) {
if (channel->opcode() == HloOpcode::kSend) {
@ -117,9 +117,9 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
break;
break;
}
}
case HloOpcode::kAllReduce: {
case HloOpcode::kAllReduce: {
auto all_reduce_id = hlo->all_reduce_id();
auto channel_id = hlo->channel_id();
if (all_reduce_id) {
if (channel_id) {
auto it = channel_group.find(all_reduce_id.value());
auto it = channel_group.find(channel_id.value());
if (it != channel_group.end()) {
if (it != channel_group.end()) {
for (HloInstruction* all_reduce : it->second) {
for (HloInstruction* all_reduce : it->second) {
add_dependencies(all_reduce);
add_dependencies(all_reduce);
@ -1210,8 +1210,8 @@ Status CheckSameChannel(const HloInstruction* instr1,
return InternalError(
return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
"Expected to have the same channel id, actual channel ids are: %s "
"(%d), %s (%d)",
"(%d), %s (%d)",
instr1->ToString(), instr1->channel_id(), instr2->ToString(),
instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
instr2->channel_id());
*instr2->channel_id());
}
}
return Status::OK();
return Status::OK();
}
}
@ -1282,14 +1282,14 @@ Status VerifySendsAndRecvs(const HloModule& module) {
DynCast<const HloSendRecvInstruction>(instruction);
DynCast<const HloSendRecvInstruction>(instruction);
if (sendrecv->is_host_transfer()) {
if (sendrecv->is_host_transfer()) {
auto it_inserted =
auto it_inserted =
host_channels.insert({sendrecv->channel_id(), sendrecv});
host_channels.insert({*sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) {
if (!it_inserted.second) {
return FailedPrecondition(
return FailedPrecondition(
"Channel %d is used for multiple host send/recv instructions: "
"Channel %d is used for multiple host send/recv instructions: "
"%s "
"%s "
"and "
"and "
"%s",
"%s",
sendrecv->channel_id(), sendrecv->ToString(),
*sendrecv->channel_id(), sendrecv->ToString(),
it_inserted.first->second->ToString());
it_inserted.first->second->ToString());
}
}
}
}
@ -1574,9 +1574,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}
}

Status HandleAllReduce(HloInstruction* crs) override {
Status HandleAllReduce(HloInstruction* crs) override {
if (crs->all_reduce_id().has_value()) {
if (crs->channel_id().has_value()) {
TF_RET_CHECK(crs->all_reduce_id().value() > 0)
TF_RET_CHECK(crs->channel_id().value() > 0)
<< "All reduce id must be greater than 0 for "
<< "All reduce channel id must be greater than 0 for "
<< crs->ToShortString();
<< crs->ToShortString();
}
}
return Status::OK();
return Status::OK();
@ -262,7 +262,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
abs1 = f32[4,3]{1,0} abs(add)
abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1)
log = f32[4,3]{1,0} log(abs1)
token0 = token[] after-all()
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
abs2 = f32[4,3]{1,0} abs(log)
abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})")
})")
@ -294,7 +294,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add1 = f32[4,3]{1,0} add(p0, p1)
add1 = f32[4,3]{1,0} add(p0, p1)
log = f32[4,3]{1,0} log(p0)
log = f32[4,3]{1,0} log(p0)
token0 = token[] after-all()
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
add2 = f32[4,3]{1,0} add(log, add1)
add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})")
})")
@ -329,7 +329,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add2 = f32[4,3]{1,0} add(add1, p1)
add2 = f32[4,3]{1,0} add(add1, p1)
log = f32[4,3]{1,0} log(add2)
log = f32[4,3]{1,0} log(add2)
token0 = token[] after-all()
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
@ -408,7 +408,7 @@ Status LayoutAssignment::BuildHostChannelConstraints(
TF_RET_CHECK(data_shape.IsArray());
TF_RET_CHECK(data_shape.IsArray());
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
send_recv_instr->channel_id(), data_shape.layout());
*send_recv_instr->channel_id(), data_shape.layout());
TF_RET_CHECK(prev_layout == nullptr)
TF_RET_CHECK(prev_layout == nullptr)
<< "Cannot constrain host transfer layout as it was set to "
<< "Cannot constrain host transfer layout as it was set to "
<< LayoutUtil::HumanString(*prev_layout) << ": "
<< LayoutUtil::HumanString(*prev_layout) << ": "
@ -480,7 +480,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
instruction->opcode() == HloOpcode::kRecv) {
instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
int64 channel_id = *instruction->channel_id();
if (!get_channel_constraints(instruction)
if (!get_channel_constraints(instruction)
->IsChannelConstrained(channel_id)) {
->IsChannelConstrained(channel_id)) {
continue;
continue;
@ -492,7 +492,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
Shape new_buffer_shape =
Shape new_buffer_shape =
get_channel_constraints(instruction)
get_channel_constraints(instruction)
->LayoutShapeForChannel(send_buffer_shape,
->LayoutShapeForChannel(send_buffer_shape,
instruction->channel_id());
*instruction->channel_id());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
new_buffer_shape, instruction->operand(0)));
new_buffer_shape, instruction->operand(0)));
} else {
} else {
@ -503,18 +503,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
const LogicalBuffer* buffer,
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
{0}));
{0}));
Shape new_shape = get_channel_constraints(instruction)
Shape new_shape =
->LayoutShapeForChannel(
get_channel_constraints(instruction)
recv_buffer_shape, instruction->channel_id());
->LayoutShapeForChannel(recv_buffer_shape,
*instruction->channel_id());
TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
}
} else if (instruction->IsCrossModuleAllReduce()) {
} else if (instruction->IsCrossModuleAllReduce()) {
CHECK(get_channel_constraints(instruction))
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 all_reduce_id = instruction->all_reduce_id().value();
int64 channel_id = instruction->channel_id().value();
if (!get_channel_constraints(instruction)
if (!get_channel_constraints(instruction)
->IsChannelConstrained(all_reduce_id)) {
->IsChannelConstrained(channel_id)) {
continue;
continue;
}
}
// TODO(b/68493863): Change to use SetOperandLayout().
// TODO(b/68493863): Change to use SetOperandLayout().
@ -522,7 +523,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RET_CHECK(buffer_shape.IsArray());
TF_RET_CHECK(buffer_shape.IsArray());
Shape new_buffer_shape =
Shape new_buffer_shape =
get_channel_constraints(instruction)
get_channel_constraints(instruction)
->LayoutShapeForChannel(buffer_shape, all_reduce_id);
->LayoutShapeForChannel(buffer_shape, channel_id);
TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(new_buffer_shape, instruction));
constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
@ -1833,7 +1834,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
const Layout* layout =
const Layout* layout =
get_channel_constraints(instruction)
get_channel_constraints(instruction)
->ConstrainChannel(
->ConstrainChannel(
instruction->channel_id(),
*instruction->channel_id(),
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr)
TF_RET_CHECK(layout == nullptr)
<< instruction->ToString()
<< instruction->ToString()
@ -1848,7 +1849,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
if (instruction->opcode() == HloOpcode::kSend) {
if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0);
HloInstruction* operand = instruction->mutable_operand(0);
const Layout* layout = get_channel_constraints(instruction)
const Layout* layout = get_channel_constraints(instruction)
->ConstrainChannel(instruction->channel_id(),
->ConstrainChannel(*instruction->channel_id(),
operand->shape().layout());
operand->shape().layout());
if (layout != nullptr) {
if (layout != nullptr) {
// We found an already constrained layout which does not match the one
// We found an already constrained layout which does not match the one
@ -1873,7 +1874,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
} else if (instruction->IsCrossModuleAllReduce()) {
} else if (instruction->IsCrossModuleAllReduce()) {
const Layout* layout =
const Layout* layout =
get_channel_constraints(instruction)
get_channel_constraints(instruction)
->ConstrainChannel(instruction->all_reduce_id().value(),
->ConstrainChannel(instruction->channel_id().value(),
instruction->shape().layout());
instruction->shape().layout());
if (layout != nullptr) {
if (layout != nullptr) {
// We found an already constrained layout which does not match the one
// We found an already constrained layout which does not match the one
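The layout-assignment hunks above mean send/recv pairs and cross-module all-reduces now share one channel-id key space in ChannelLayoutConstraints. A minimal sketch of that keying, assuming a cross-module all-reduce and using only the constraint methods that appear in this diff ("constraints" is a placeholder for the active ChannelLayoutConstraints object):

// Sketch only, not part of this commit: the constraint table is keyed by
// channel_id for every channel instruction; "constraints" is a placeholder.
int64 channel_id = instruction->channel_id().value();
if (!constraints->IsChannelConstrained(channel_id)) {
  // Record this instruction's layout as the required layout for the channel.
  constraints->ConstrainChannel(channel_id, instruction->shape().layout());
}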
@ -891,11 +891,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
param = (f32[2,2]) parameter(0)
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
gte = f32[2,2] get-tuple-element(param), index=0
ar.0 = f32[2,2] all-reduce(gte),
ar.0 = f32[2,2] all-reduce(gte),
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=0}
sharding={maximal device=0}
const = f32[2,2] constant({{0,1},{2,3}})
const = f32[2,2] constant({{0,1},{2,3}})
ROOT ar.1 = f32[2,2] all-reduce(const),
ROOT ar.1 = f32[2,2] all-reduce(const),
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=1}
sharding={maximal device=1}
})";
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,