Change all_reduce_id to use channel_id

PiperOrigin-RevId: 255315385
commit 852061b75b
parent 6ae9600988
Author: HyoukJoong Lee
Date: 2019-06-26 19:47:50 -07:00 (committed by TensorFlower Gardener)
26 changed files with 267 additions and 241 deletions
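The change is mechanical at call sites: the optional cross-module identifier formerly read through HloAllReduceInstruction::all_reduce_id() is now read through channel_id(), which moves to a shared base class. A minimal before/after sketch (variable names illustrative, not from this commit):

    // Before: absl::optional<int64> id = all_reduce->all_reduce_id();
    // After: the same optional id is exposed via the channel interface.
    absl::optional<int64> id = all_reduce->channel_id();
    if (id.has_value()) {
      // Cross-module: all-reduces in other modules sharing *id communicate.
    }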

View File

@@ -2068,7 +2068,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
   }
   if (channel_id.has_value()) {
-    instr.set_all_reduce_id(channel_id->handle());
+    instr.set_channel_id(channel_id->handle());
   }
   AddCalledComputation(computation, &instr);

View File

@@ -319,7 +319,7 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
     auto maybe_pair = MatchesArCrsPattern(instruction);
     if (maybe_pair) {
       auto pair = *maybe_pair;
-      int64 ar_id = *(instruction->all_reduce_id());
+      int64 ar_id = *(instruction->channel_id());
       if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
         continue;
       }
@@ -365,9 +365,10 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
   for (auto it : all_reduce_map_) {
-    auto all_reduce_id = it.first;
-    VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking ar_id: "
-            << all_reduce_id << "\n";
+    auto channel_id = it.first;
+    VLOG(2)
+        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
+        << channel_id << "\n";
     auto pairs_vec = it.second;
     CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
     auto instr_0 = pairs_vec[0].ar;
@@ -378,9 +379,10 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
       absl::flat_hash_map<int64, int64> visited_pairs;
       while (true) {
         if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
-          all_reduce_map_.erase(all_reduce_id);
-          VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: "
-                  << all_reduce_id << "\n";
+          all_reduce_map_.erase(channel_id);
+          VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
+                     "channel id: "
+                  << channel_id << "\n";
           break;
         }
         if (next_0->IsCrossReplicaAllReduce()) {
@@ -402,7 +404,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
     for (auto pair : pairs_vec) {
       auto all_reduce = pair.ar;
       auto parent_computation = all_reduce->parent();
-      auto all_reduce_id = all_reduce->all_reduce_id();
+      auto channel_id = all_reduce->channel_id();
       auto prev = all_reduce->mutable_operand(0);
       auto next = all_reduce->users()[0];
       TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
@@ -447,7 +449,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
         next = next->users()[0];
       }
       // The AllReduce and the CRS are combined to an all-core AllReduce.
-      next->set_all_reduce_id(all_reduce_id);
+      next->set_channel_id(channel_id);
     }
   }
   return true;

View File

@@ -36,7 +36,7 @@ namespace xla {
 //
 // The steps are:
 // 1) Find CMARs followed by simple ops followed by CRARs.
-// 2) Group CMARs by all_reduce_id. They must all be rewritten.
+// 2) Group CMARs by channel_id. They must all be rewritten.
 // 3) Prove that the CMAR patterns in each core produce the same result.
 // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
 //    other operand by the number of spatial partitions.
@@ -104,7 +104,7 @@ class ArCrsCombiner : public HloModulePass {
       }
       pieces.push_back(instruction->name());
       pieces.push_back(")[id:");
-      pieces.push_back(std::to_string(*(ar->all_reduce_id())));
+      pieces.push_back(std::to_string(*(ar->channel_id())));
       pieces.push_back(",dist:");
       pieces.push_back(std::to_string(distance));
       pieces.push_back("]");

View File

@@ -414,7 +414,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = bf16[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=0}
   %convert.1 = f32[]
@@ -429,7 +429,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=1}
   %convert.2 = f32[]
@@ -486,7 +486,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
   %all-reduce.ar.1 = f32[2,1]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.1,
       sharding={maximal device=0}
   %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
@@ -499,7 +499,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
   %all-reduce.ar.2 = f32[2,1]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.1,
       sharding={maximal device=1}
   %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
@@ -549,7 +549,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.f32,
       sharding={maximal device=0}
   %multiply.1 = f32[]
@@ -564,7 +564,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.f32,
       sharding={maximal device=1}
   %multiply.2 = f32[]
@@ -624,7 +624,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=0}
   %convert.1 = f32[]
@@ -642,7 +642,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=1}
   %convert.2 = f32[]
@@ -709,7 +709,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=0}
   %convert.1 = f32[]
@@ -727,7 +727,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=1}
   %convert.2 = f32[]
@@ -772,7 +772,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.1,
       sharding={maximal device=0}
   %all-reduce.1 = f32[]
@@ -787,7 +787,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.1,
       sharding={maximal device=1}
   %all-reduce.2 = f32[]
@@ -840,7 +840,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=0}
   %add.11 = f32[]
@@ -858,7 +858,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=0}
   %add.21 = f32[]
@@ -919,7 +919,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.f32,
       sharding={maximal device=0}
   %sub.1 = f32[]
@@ -934,7 +934,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.f32,
       sharding={maximal device=1}
   %sub.2 = f32[]
@@ -991,7 +991,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar11 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=0}
   %add11 = f32[]
@@ -1000,7 +1000,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar12 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=2,
+      channel_id=2,
       to_apply=%sum,
       sharding={maximal device=0}
   %add12 = f32[]
@@ -1015,7 +1015,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar21 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=1}
   %add21 = f32[]
@@ -1024,7 +1024,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar22 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=2,
+      channel_id=2,
       to_apply=%sum,
       sharding={maximal device=1}
   %add22 = f32[]
@@ -1083,13 +1083,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar11 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=0}
   %ar12 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=2,
+      channel_id=2,
       to_apply=%sum,
       sharding={maximal device=0}
   %add11 = f32[]
@@ -1107,13 +1107,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
   %ar21 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum,
       sharding={maximal device=1}
   %ar22 = f32[]
       all-reduce(%p),
       replica_groups={{0},{1}},
-      all_reduce_id=2,
+      channel_id=2,
       to_apply=%sum,
       sharding={maximal device=1}
   %add21 = f32[]
@@ -1182,7 +1182,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   %all-reduce.ar.1 = bf16[]
       all-reduce(%p),
       replica_groups={{0}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=0}
   %convert.1 = f32[]
@@ -1197,7 +1197,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   %all-reduce.ar.2 = bf16[]
       all-reduce(%constant.bf16),
       replica_groups={{0}},
-      all_reduce_id=1,
+      channel_id=1,
       to_apply=%sum.bf16,
       sharding={maximal device=1}
   %convert.2 = f32[]
@@ -1276,9 +1276,9 @@ HloModule foobar
 ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
   %p = bf16[] parameter(0)
-  %all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
+  %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
     to_apply=%sum.f32, sharding={maximal device=0}
-  %all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
+  %all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
     to_apply=%sum.f32, sharding={maximal device=1}
   %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
     to_apply=%sum.f32, sharding={maximal device=0}

View File

@@ -239,7 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
       /*replica_groups=*/{},
-      /*all_reduce_id=*/absl::nullopt));
+      /*channel_id=*/absl::nullopt));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
   HloInstruction* gte_b = builder.AddInstruction(

View File

@@ -257,7 +257,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
       /*replica_groups=*/{},
-      /*all_reduce_id=*/absl::nullopt));
+      /*channel_id=*/absl::nullopt));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

View File

@@ -211,7 +211,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
   HloInstruction* all_reduce =
       builder.AddInstruction(HloInstruction::CreateAllReduce(
           ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
-          /*replica_groups=*/{}, /*all_reduce_id=*/1));
+          /*replica_groups=*/{}, /*channel_id=*/1));
   HloInstruction* gte0 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
   HloInstruction* gte1 = builder.AddInstruction(

View File

@@ -177,13 +177,13 @@ class NcclComm {
 // * Only ops with the same opcode can communicate with each other. At the
 //   moment we only support kAllReduce, so we don't check for this explicitly.
 //
-// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()),
-//   only ops with the same value for all_reduce_id() can communicate with each
-//   other.
+// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
+//   only ops with the same value for channel_id() can communicate with each
+//   other.
 //
-// * For cross-replica (i.e. same-module) all-reduces (i.e.
-//   !all_reduce_id().has_value()), only ops from the same module (as identified
-//   by its unique_id()) can communicate with each other.
+// * For cross-replica (i.e. same-module) all-reduces (i.e.
+//   !channel_id().has_value()), only ops from the same module (as
+//   identified by its unique_id()) can communicate with each other.
 //
 struct RendezvousKey {
   enum AllReduceKind {
@@ -196,8 +196,8 @@ struct RendezvousKey {
                 const HloAllReduceInstruction* instr)
       : run_id(run_id), participating_replicas(participating_replicas) {
     std::tie(all_reduce_kind, op_id) =
-        instr->all_reduce_id().has_value()
-            ? std::make_pair(kCrossModule, instr->all_reduce_id().value())
+        instr->channel_id().has_value()
+            ? std::make_pair(kCrossModule, instr->channel_id().value())
            : std::make_pair(
                  kCrossReplica,
                  static_cast<int64>(instr->GetModule()->unique_id()));
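A hedged sketch of the rules above, using the constructor shown; run_id, participating_replicas, ar0, and ar1 are assumed to exist:

    // Two cross-module all-reduces rendezvous only if their keys carry the
    // same (kCrossModule, channel_id) pair.
    RendezvousKey key0(run_id, participating_replicas, ar0);
    RendezvousKey key1(run_id, participating_replicas, ar1);
    // key0.op_id == key1.op_id iff ar0 and ar1 share a channel_id (or, for
    // cross-replica ops, iff they come from the same module).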

View File

@@ -126,8 +126,9 @@ message HloInstructionProto {
   // Only present for kBatchNormTraining.
   int64 feature_index = 25;
 
-  // Represents a unique identifier for each Send/Recv instruction pair.
-  // Only present for kSend or kRecv.
+  // Represents a unique identifier for each Send/Recv instruction pair or
+  // optionally for collective instructions (AllReduce, CollectivePermute,
+  // AllToAll). Non-positive channel_id is equivalent to no channel id.
   int64 channel_id = 26;
 
   // The string representation of the infeed configuration.
@@ -174,7 +175,9 @@ message HloInstructionProto {
   // Cross replica op fields.
   repeated ReplicaGroup replica_groups = 49;
-  int64 all_reduce_id = 45;
+  // Deprecated, but keeping it for backward compatibility. Use channel_id.
+  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
+  int64 all_reduce_id = 45 [deprecated = true];
 
   // Whether this Send/Recv instruction transfers data to/from the host. Only
   // present for Send and Recv instructions and their SendDone and RecvDone
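A reader sketch of the convention above, mirroring the CreateFromProto change later in this commit: prefer channel_id, fall back to the deprecated all_reduce_id, and treat non-positive values as absent:

    // `proto` is an HloInstructionProto.
    absl::optional<int64> channel_id;
    if (proto.channel_id() > 0) {
      channel_id = proto.channel_id();
    } else if (proto.all_reduce_id() > 0) {
      channel_id = proto.all_reduce_id();  // backward compatibility
    }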

View File

@@ -373,7 +373,7 @@ void HloComputation::ComputeInstructionPostOrder(
     case HloOpcode::kRecvDone:
       return inst->channel_id();
     case HloOpcode::kAllReduce:
-      return inst->all_reduce_id();
+      return inst->channel_id();
     default:
       return absl::nullopt;
   }
@@ -428,13 +428,10 @@ HloComputation::ComputeChannelDependencies() const {
     switch (instruction->opcode()) {
       case HloOpcode::kSend:
       case HloOpcode::kRecvDone:
-        channel_dependency_group[instruction->channel_id()].push_back(
-            instruction.get());
-        break;
       case HloOpcode::kAllReduce: {
-        auto all_reduce_id = instruction->all_reduce_id();
-        if (all_reduce_id) {
-          channel_dependency_group[all_reduce_id.value()].push_back(
+        auto channel_id = instruction->channel_id();
+        if (channel_id) {
+          channel_dependency_group[channel_id.value()].push_back(
               instruction.get());
         }
         break;

View File

@@ -688,10 +688,10 @@ add {
 ENTRY entry {
   param = f32[128] parameter(0), sharding={maximal device=0}
   crs0 = f32[128] all-reduce(param),
-    replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+    replica_groups={{0}}, channel_id=1, to_apply=add,
     sharding={maximal device=0}
   crs1 = f32[128] all-reduce(param),
-    replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+    replica_groups={{0}}, channel_id=1, to_apply=add,
     sharding={maximal device=1}
   add = f32[128] add(crs0, crs0), sharding={maximal device=0}
   ROOT t = (f32[128], f32[128]) tuple(add, crs1)

View File

@@ -383,16 +383,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
           << "AllReduce should have 1 called computation but sees "
           << proto.called_computation_ids_size();
-      absl::optional<int64> all_reduce_id;
+      TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
+          << "AllReduce cannot have both channel_id() and all_reduce_id()";
+      absl::optional<int64> channel_id;
+      if (proto.channel_id() > 0) {
+        channel_id = proto.channel_id();
+      }
       if (proto.all_reduce_id() > 0) {
-        all_reduce_id = proto.all_reduce_id();
+        channel_id = proto.all_reduce_id();
       }
       instruction = CreateAllReduce(
           shape, all_operands(), computations(0),
           /*replica_groups=*/
           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                     proto.replica_groups().end()),
-          /*all_reduce_id=*/all_reduce_id);
+          /*channel_id=*/channel_id);
       break;
     }
     case HloOpcode::kAllToAll: {
@@ -860,9 +865,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     HloComputation* reduce_computation,
     const std::vector<ReplicaGroup>& replica_groups,
-    const absl::optional<int64>& all_reduce_id) {
+    const absl::optional<int64>& channel_id) {
   return absl::make_unique<HloAllReduceInstruction>(
-      shape, operands, reduce_computation, replica_groups, all_reduce_id);
+      shape, operands, reduce_computation, replica_groups, channel_id);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@@ -1279,7 +1284,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
     case HloOpcode::kTrace:
       return true;
     case HloOpcode::kAllReduce:
-      return all_reduce_id().has_value();
+      return channel_id().has_value();
     case HloOpcode::kCustomCall:
       return Cast<HloCustomCallInstruction>(this)
           ->custom_call_has_side_effect();
@@ -2232,11 +2237,11 @@ bool HloInstruction::IsElementwiseImpl(
 }
 
 bool HloInstruction::IsCrossModuleAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && all_reduce_id();
+  return opcode() == HloOpcode::kAllReduce && channel_id();
 }
 
 bool HloInstruction::IsCrossReplicaAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
+  return opcode() == HloOpcode::kAllReduce && !channel_id();
 }
 
 string HloInstruction::ToStringWithCanonicalNameMap(
@@ -3332,10 +3337,6 @@ const std::vector<int64>& HloInstruction::fft_length() const {
   return Cast<HloFftInstruction>(this)->fft_length();
 }
 
-int64 HloInstruction::channel_id() const {
-  return Cast<HloSendRecvInstruction>(this)->channel_id();
-}
-
 int64 HloInstruction::concatenate_dimension() const {
   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
 }
@@ -3535,13 +3536,12 @@ HloInstruction::source_target_pairs() const {
   return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
 }
 
-absl::optional<int64> HloInstruction::all_reduce_id() const {
-  return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
-}
-
-void HloInstruction::set_all_reduce_id(
-    const absl::optional<int64>& all_reduce_id) {
-  return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
-}
+absl::optional<int64> HloInstruction::channel_id() const {
+  return Cast<HloChannelInstruction>(this)->channel_id();
+}
+
+void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
+  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
+}
 
 const ConvolutionDimensionNumbers&

View File

@@ -497,14 +497,14 @@ class HloInstruction {
   // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
   // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
   //
-  // `all_reduce_id`: for Allreduce nodes from different modules, if they have
-  // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
-  // not be applied cross modules.
+  // `channel_id`: for Allreduce nodes from different modules, if
+  // they have the same channel_id, they will be 'Allreduce'd. If
+  // empty, Allreduce will not be applied cross modules.
   static std::unique_ptr<HloInstruction> CreateAllReduce(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* reduce_computation,
       const std::vector<ReplicaGroup>& replica_groups,
-      const absl::optional<int64>& all_reduce_id);
+      const absl::optional<int64>& channel_id);
 
   // An all-to-all op takes N array operands of the same shape and scatters them
   // to N replicas. Each replica gathers the results into a tuple.
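A hedged usage sketch of the renamed parameter; shape, operand, and sum_computation are assumed to exist:

    // Cross-module all-reduce: AllReduce instructions in other modules built
    // with the same channel id (here 1) will be 'Allreduce'd together.
    auto ar = HloInstruction::CreateAllReduce(
        shape, {operand}, sum_computation,
        /*replica_groups=*/{}, /*channel_id=*/absl::optional<int64>(1));
    // For a purely cross-replica all-reduce, pass absl::nullopt instead.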
@@ -952,7 +952,7 @@ class HloInstruction {
       return false;
     }
 
-    // Two AllReduces are Identical if they have the same all_reduce_id.
+    // Two AllReduces are Identical if they have the same channel_id.
     // Their operands don't have to be Identical.
     if (!IsCrossModuleAllReduce()) {
       // Use an explicit loop rather than ContainerEquals, because copying
@@ -1428,8 +1428,9 @@ class HloInstruction {
   // Delegates to HloFftInstruction::fft_length.
   const std::vector<int64>& fft_length() const;
 
-  // Delegates to HloSendRecvInstruction::channel_id.
-  int64 channel_id() const;
+  // Delegates to HloChannelInstruction::channel_id.
+  absl::optional<int64> channel_id() const;
+  void set_channel_id(const absl::optional<int64>& channel_id);
 
   // Returns the dimension sizes or numbers associated with this instruction.
   virtual const std::vector<int64>& dimensions() const {
@@ -1571,10 +1572,6 @@ class HloInstruction {
   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
 
-  // Delegates to HloAllReduceInstruction::all_reduce_id.
-  absl::optional<int64> all_reduce_id() const;
-  void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
-
   // Returns data on the window in a windowed operation such as
   // convolution.
   virtual const Window& window() const {

View File

@@ -361,25 +361,60 @@ HloCholeskyInstruction::CloneWithNewOperandsImpl(
       cholesky_options());
 }
 
+HloChannelInstruction::HloChannelInstruction(
+    HloOpcode opcode, const Shape& shape,
+    const absl::optional<int64>& channel_id)
+    : HloInstruction(opcode, shape), channel_id_(channel_id) {}
+
+void HloChannelInstruction::set_channel_id(
+    const absl::optional<int64>& channel_id) {
+  channel_id_ = channel_id;
+}
+
+HloInstructionProto HloChannelInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  if (channel_id_) {
+    CHECK_GT(channel_id_.value(), 0)
+        << "Non-positive channel id is equivalent to no channel id";
+    proto.set_channel_id(*channel_id_);
+  }
+  return proto;
+}
+
+std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& /*options*/) const {
+  std::vector<string> result;
+  if (channel_id_) {
+    result.push_back(StrCat("channel_id=", *channel_id_));
+  }
+  return result;
+}
+
+bool HloChannelInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+    /*eq_computations*/) const {
+  const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
+  return channel_id() == casted_other.channel_id();
+}
+
 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
                                                const Shape& shape,
                                                int64 channel_id,
                                                bool is_host_transfer)
-    : HloInstruction(opcode, shape),
-      channel_id_(channel_id),
+    : HloChannelInstruction(opcode, shape, channel_id),
      is_host_transfer_(is_host_transfer) {}
 
 HloInstructionProto HloSendRecvInstruction::ToProto() const {
-  HloInstructionProto proto = HloInstruction::ToProto();
-  proto.set_channel_id(channel_id_);
+  HloInstructionProto proto = HloChannelInstruction::ToProto();
   proto.set_is_host_transfer(is_host_transfer_);
   return proto;
 }
 
 std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
-  std::vector<string> attrs;
-  attrs.push_back(StrCat("channel_id=", channel_id_));
+  std::vector<string> attrs =
+      HloChannelInstruction::ExtraAttributesToStringImpl(options);
   if (is_host_transfer()) {
     attrs.push_back("is_host_transfer=true");
   }
@@ -413,13 +448,13 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
     HloCloneContext* context) const {
   CHECK_EQ(new_operands.size(), 2);
   return absl::make_unique<HloSendInstruction>(
-      new_operands[0], new_operands[1], channel_id(), is_host_transfer());
+      new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
 }
 
 HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
                                                bool is_host_transfer)
     : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
-                             CHECK_NOTNULL(operand)->channel_id(),
+                             CHECK_NOTNULL(operand)->channel_id().value(),
                              is_host_transfer) {
   AppendOperand(operand);
 }
@@ -450,7 +485,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
     HloCloneContext* context) const {
   CHECK_EQ(new_operands.size(), 1);
   return absl::make_unique<HloRecvInstruction>(
-      ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
+      ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
      is_host_transfer());
 }
@@ -461,7 +496,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
           ShapeUtil::MakeTupleShape(
               {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
                ShapeUtil::MakeTokenShape()}),
-          CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
+          CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
   AppendOperand(operand);
 }
@@ -477,32 +512,39 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
 HloCollectiveInstruction::HloCollectiveInstruction(
     HloOpcode opcode, const Shape& shape,
     absl::Span<HloInstruction* const> operands,
-    const std::vector<ReplicaGroup>& replica_groups)
-    : HloInstruction(opcode, shape), replica_groups_(replica_groups) {
+    const std::vector<ReplicaGroup>& replica_groups,
+    const absl::optional<int64>& channel_id)
+    : HloChannelInstruction(opcode, shape, channel_id),
+      replica_groups_({replica_groups.begin(), replica_groups.end()}) {
   for (auto operand : operands) {
     AppendOperand(operand);
   }
 }
 
 HloInstructionProto HloCollectiveInstruction::ToProto() const {
-  HloInstructionProto proto = HloInstruction::ToProto();
+  HloInstructionProto proto = HloChannelInstruction::ToProto();
   *proto.mutable_replica_groups() = {replica_groups_.begin(),
                                      replica_groups_.end()};
   return proto;
 }
 
 std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
-    const HloPrintOptions& /*options*/) const {
-  return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))};
+    const HloPrintOptions& options) const {
+  std::vector<string> result =
+      HloChannelInstruction::ExtraAttributesToStringImpl(options);
+  result.push_back(
+      StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
+  return result;
 }
 
 bool HloCollectiveInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
-    /*eq_computations*/) const {
+        eq_computations) const {
   const auto& casted_other =
       static_cast<const HloCollectiveInstruction&>(other);
-  return absl::c_equal(replica_groups(), casted_other.replica_groups(),
+  return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
+         absl::c_equal(replica_groups(), casted_other.replica_groups(),
                        [](const ReplicaGroup& a, const ReplicaGroup& b) {
                          return absl::c_equal(a.replica_ids(), b.replica_ids());
                        });
@@ -512,44 +554,19 @@ HloAllReduceInstruction::HloAllReduceInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     HloComputation* reduce_computation,
     const std::vector<ReplicaGroup>& replica_groups,
-    const absl::optional<int64>& all_reduce_id)
+    const absl::optional<int64>& channel_id)
     : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
-                               replica_groups),
-      all_reduce_id_(all_reduce_id) {
+                               replica_groups, channel_id) {
   AppendComputation(reduce_computation);
 }
 
-void HloAllReduceInstruction::set_all_reduce_id(
-    const absl::optional<int64>& all_reduce_id) {
-  all_reduce_id_ = all_reduce_id;
-}
-
-HloInstructionProto HloAllReduceInstruction::ToProto() const {
-  HloInstructionProto proto = HloCollectiveInstruction::ToProto();
-  // Proto3 is so sad.
-  if (all_reduce_id_) {
-    proto.set_all_reduce_id(*all_reduce_id_);
-  }
-  return proto;
-}
-
 bool HloAllReduceInstruction::IsNoop() const {
   for (auto replica_group : replica_groups()) {
     if (replica_group.replica_ids().size() != 1) {
       return false;
     }
   }
-  return !all_reduce_id();
+  return !channel_id();
 }
 
-std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
-    const HloPrintOptions& options) const {
-  std::vector<string> result =
-      HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
-  if (all_reduce_id_) {
-    result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
-  }
-  return result;
-}
-
 bool HloAllReduceInstruction::IdenticalSlowPath(
@@ -558,8 +575,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
     eq_computations) const {
   const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
   return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
-         eq_computations(to_apply(), casted_other.to_apply()) &&
-         all_reduce_id() == casted_other.all_reduce_id();
+         eq_computations(to_apply(), casted_other.to_apply());
 }
 
 std::unique_ptr<HloInstruction>
@@ -567,14 +583,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* /*context*/) const {
   return absl::make_unique<HloAllReduceInstruction>(
-      shape, new_operands, to_apply(), replica_groups(), all_reduce_id());
+      shape, new_operands, to_apply(), replica_groups(), channel_id());
 }
 
 HloAllToAllInstruction::HloAllToAllInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     const std::vector<ReplicaGroup>& replica_groups)
     : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
-                               replica_groups) {}
+                               replica_groups, absl::nullopt) {}
 
 std::unique_ptr<HloInstruction>
 HloAllToAllInstruction::CloneWithNewOperandsImpl(
@@ -587,13 +603,14 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
 HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
     const Shape& shape, HloInstruction* operand,
     const std::vector<std::pair<int64, int64>>& source_target_pairs)
-    : HloInstruction(HloOpcode::kCollectivePermute, shape),
+    : HloChannelInstruction(HloOpcode::kCollectivePermute, shape,
+                            absl::nullopt),
       source_target_pairs_(source_target_pairs) {
   AppendOperand(operand);
 }
 
 HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
-  HloInstructionProto proto = HloInstruction::ToProto();
+  HloInstructionProto proto = HloChannelInstruction::ToProto();
   for (const auto& pair : source_target_pairs()) {
     auto* proto_pair = proto.add_source_target_pairs();
     proto_pair->set_source(pair.first);
@@ -604,8 +621,9 @@ HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
 std::vector<string>
 HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
-    const HloPrintOptions& /*options*/) const {
-  std::vector<string> result;
+    const HloPrintOptions& options) const {
+  std::vector<string> result =
+      HloChannelInstruction::ExtraAttributesToStringImpl(options);
   std::vector<string> strs;
   for (const auto& pair : source_target_pairs()) {
     strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
@@ -617,10 +635,11 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
 bool HloCollectivePermuteInstruction::IdenticalSlowPath(
     const HloInstruction& other,
     const std::function<bool(const HloComputation*, const HloComputation*)>&
-    /*eq_computations*/) const {
+        eq_computations) const {
   const auto& casted_other =
       static_cast<const HloCollectivePermuteInstruction&>(other);
-  return absl::c_equal(source_target_pairs(),
+  return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
+         absl::c_equal(source_target_pairs(),
                        casted_other.source_target_pairs(),
                        [](const std::pair<int64, int64>& a,
                           const std::pair<int64, int64>& b) { return a == b; });

View File

@@ -206,13 +206,37 @@ class HloCholeskyInstruction : public HloInstruction {
   CholeskyOptions cholesky_options_;
 };
 
-class HloSendRecvInstruction : public HloInstruction {
+// Class that represents instructions that synchronize and transfer data between
+// partitioned devices. Send/Recv and collective instructions (AllReduce,
+// AllToAll, CollectivePermute) belong to this instruction type. A group of
+// instructions (of the same opcode) with the same channel_id communicate during
+// execution.
+class HloChannelInstruction : public HloInstruction {
  public:
-  // Returns the channel id associated with the instruction. The id is
-  // shared between each Send/Recv pair and is globally unique to identify each
-  // channel.
-  int64 channel_id() const { return channel_id_; }
+  // Returns the channel id associated with the instruction. The id is
+  // shared between each Send/Recv pair or a group of collective instructions
+  // and is globally unique to identify each channel.
+  absl::optional<int64> channel_id() const { return channel_id_; }
+  void set_channel_id(const absl::optional<int64>& channel_id);
+
+ protected:
+  explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
+                                 const absl::optional<int64>& channel_id);
+
+  HloInstructionProto ToProto() const override;
+
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  absl::optional<int64> channel_id_;
+};
+
+class HloSendRecvInstruction : public HloChannelInstruction {
+ public:
   // Returns whether this send/recv instruction sends data to/from the host.
   bool is_host_transfer() const { return is_host_transfer_; }
@@ -230,9 +254,6 @@ class HloSendRecvInstruction : public HloInstruction {
       const HloInstruction& other,
       const std::function<bool(const HloComputation*, const HloComputation*)>&
           eq_computations) const override;
-
-  // Represents a unique identifier for each Send/Recv instruction pair.
-  int64 channel_id_;
 
   // Whether this send/recv instruction sends data to/from the host.
   bool is_host_transfer_;
 };
@@ -285,7 +306,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
       HloCloneContext* context) const override;
 };
 
-class HloCollectiveInstruction : public HloInstruction {
+class HloCollectiveInstruction : public HloChannelInstruction {
  public:
   const std::vector<ReplicaGroup>& replica_groups() const {
     return replica_groups_;
@@ -295,7 +316,8 @@ class HloCollectiveInstruction : public HloInstruction {
   explicit HloCollectiveInstruction(
       HloOpcode opcode, const Shape& shape,
      absl::Span<HloInstruction* const> operands,
-      const std::vector<ReplicaGroup>& replica_groups);
+      const std::vector<ReplicaGroup>& replica_groups,
+      const absl::optional<int64>& channel_id);
 
   HloInstructionProto ToProto() const override;
@@ -315,21 +337,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* reduce_computation,
       const std::vector<ReplicaGroup>& replica_groups,
-      const absl::optional<int64>& all_reduce_id);
-
-  absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
-  void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
-
-  // Returns a serialized representation of this instruction.
-  HloInstructionProto ToProto() const override;
+      const absl::optional<int64>& channel_id);
 
   // Returns true if the AllReduce does no communication, so it's equivalent
   // to a mem copy.
   bool IsNoop() const;
 
  private:
-  std::vector<string> ExtraAttributesToStringImpl(
-      const HloPrintOptions& options) const override;
   bool IdenticalSlowPath(
       const HloInstruction& other,
       const std::function<bool(const HloComputation*, const HloComputation*)>&
@@ -339,11 +353,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
-
-  // For Allreduce nodes from different modules, if they have the same
-  // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
-  // applied cross modules.
-  absl::optional<int64> all_reduce_id_;
 };
 
 class HloAllToAllInstruction : public HloCollectiveInstruction {
@@ -359,7 +368,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
       HloCloneContext* context) const override;
 };
 
-class HloCollectivePermuteInstruction : public HloInstruction {
+class HloCollectivePermuteInstruction : public HloChannelInstruction {
  public:
   explicit HloCollectivePermuteInstruction(
       const Shape& shape, HloInstruction* operand,
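One consequence of the new base class, sketched under the assumption that DynCast from hlo_casting_utils is available: channel-bearing instructions can now be queried uniformly, whatever their opcode:

    // `hlo` is an assumed HloInstruction*. Works for Send/Recv, AllReduce,
    // AllToAll, and CollectivePermute alike.
    if (auto* channel_instr = DynCast<HloChannelInstruction>(hlo)) {
      if (channel_instr->channel_id().has_value()) {
        VLOG(1) << hlo->name() << " communicates on channel "
                << *channel_instr->channel_id();
      }
    }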

View File

@@ -155,11 +155,11 @@ ENTRY entry {
   %p = f32[1000, 1000] parameter(0)
   %token.0 = token[] after-all()
   %send = (f32[1000, 1000], token[]) send(%p, %token.0),
-    channel_id=0, is_host_transfer=true
+    channel_id=1, is_host_transfer=true
   %n1 = f32[1000, 1000] negate(%p)
   %n2 = f32[1000, 1000] negate(%n1)
   %n3 = f32[1000, 1000] negate(%n2)
-  %send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true
+  %send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
 }
 )";

View File

@ -83,7 +83,7 @@ Status HloModuleGroupMetadata::Build() {
if (IsChannelInstruction(hlo)) { if (IsChannelInstruction(hlo)) {
peers.push_back(PeerComputation(hlo)); peers.push_back(PeerComputation(hlo));
} else if (hlo->IsCrossModuleAllReduce()) { } else if (hlo->IsCrossModuleAllReduce()) {
for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) {
if (instr == hlo) { if (instr == hlo) {
continue; continue;
} }
@ -235,7 +235,7 @@ bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
HloComputation* HloModuleGroupMetadata::PeerComputation( HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const { const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction)); CHECK(IsChannelInstruction(instruction));
const Channel& channel = GetChannel(instruction->channel_id()); const Channel& channel = GetChannel(*instruction->channel_id());
switch (instruction->opcode()) { switch (instruction->opcode()) {
case HloOpcode::kSend: case HloOpcode::kSend:
case HloOpcode::kSendDone: case HloOpcode::kSendDone:
@ -249,8 +249,8 @@ HloComputation* HloModuleGroupMetadata::PeerComputation(
} }
const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup( const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
int64 all_reduce_id) const { int64 channel_id) const {
auto it = all_reduce_map_.find(all_reduce_id); auto it = all_reduce_map_.find(channel_id);
CHECK(it != all_reduce_map_.end()); CHECK(it != all_reduce_map_.end());
return it->second; return it->second;
} }
@ -330,14 +330,14 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TrackedInstruction(hlo, ComputationKind::kCallFunction); TrackedInstruction(hlo, ComputationKind::kCallFunction);
} }
// Group cross module all-reduce instructions by the all_reduce id. // Group cross module all-reduce instructions by the channel id.
if (hlo->IsCrossModuleAllReduce()) { if (hlo->IsCrossModuleAllReduce()) {
TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) == TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) ==
channel_id_map_.end()) channel_id_map_.end())
<< "all_reduce_id " << *hlo->all_reduce_id() << "channel_id " << *hlo->channel_id()
<< " is already used by a send/recv instruction"; << " is already used by a send/recv instruction";
all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo); all_reduce_map_[*hlo->channel_id()].push_back(hlo);
max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id()); max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
return Status::OK(); return Status::OK();
} }
@ -345,41 +345,41 @@ Status HloModuleGroupMetadata::RecordInstructions() {
return Status::OK(); return Status::OK();
} }
TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) == TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) ==
all_reduce_map_.end()) all_reduce_map_.end())
<< "channel id " << hlo->channel_id() << "channel id " << *hlo->channel_id()
<< " is already used by an all-reduce instruction"; << " is already used by an all-reduce instruction";
// Add a new channel if needed. // Add a new channel if needed.
if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) {
channels_.emplace_back(); channels_.emplace_back();
channels_.back().id = hlo->channel_id(); channels_.back().id = *hlo->channel_id();
channel_id_map_[hlo->channel_id()] = channels_.size() - 1; channel_id_map_[*hlo->channel_id()] = channels_.size() - 1;
max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
} }
Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]];
if (hlo->opcode() == HloOpcode::kSend) { if (hlo->opcode() == HloOpcode::kSend) {
TF_RET_CHECK(channel.send == nullptr) TF_RET_CHECK(channel.send == nullptr)
<< "channel id " << hlo->channel_id() << "channel id " << *hlo->channel_id()
<< " is used by multiple send instructions"; << " is used by multiple send instructions";
channel.send = hlo; channel.send = hlo;
} }
if (hlo->opcode() == HloOpcode::kRecv) { if (hlo->opcode() == HloOpcode::kRecv) {
TF_RET_CHECK(channel.recv == nullptr) TF_RET_CHECK(channel.recv == nullptr)
<< "channel id " << hlo->channel_id() << "channel id " << *hlo->channel_id()
<< " is used by multiple recv instructions"; << " is used by multiple recv instructions";
channel.recv = hlo; channel.recv = hlo;
} }
if (hlo->opcode() == HloOpcode::kSendDone) { if (hlo->opcode() == HloOpcode::kSendDone) {
TF_RET_CHECK(channel.send_done == nullptr) TF_RET_CHECK(channel.send_done == nullptr)
<< "channel id " << hlo->channel_id() << "channel id " << *hlo->channel_id()
<< " is used by multiple send-done instructions"; << " is used by multiple send-done instructions";
channel.send_done = hlo; channel.send_done = hlo;
} }
if (hlo->opcode() == HloOpcode::kRecvDone) { if (hlo->opcode() == HloOpcode::kRecvDone) {
TF_RET_CHECK(channel.recv_done == nullptr) TF_RET_CHECK(channel.recv_done == nullptr)
<< "channel id " << hlo->channel_id() << "channel id " << *hlo->channel_id()
<< " is used by multiple recv-done instructions"; << " is used by multiple recv-done instructions";
channel.recv_done = hlo; channel.recv_done = hlo;
} }
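
A minimal, self-contained sketch of the invariant these hunks establish: cross-module all-reduces now share one channel-id namespace with send/recv, so registering an id in one map must check the other, and a single max_channel_id_ is tracked across both. Toy types only (Channel here carries instruction names purely for illustration), not the XLA classes:

#include <algorithm>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for the Channel struct and the two maps above.
struct Channel {
  int64_t id = -1;
  std::string send, recv;  // names of the paired instructions
};

class ChannelRegistry {
 public:
  void RecordAllReduce(int64_t channel_id, const std::string& name) {
    // Mirrors the TF_RET_CHECK: the id must not belong to a send/recv pair.
    if (channels_.count(channel_id) > 0)
      throw std::runtime_error("channel id already used by send/recv");
    all_reduce_groups_[channel_id].push_back(name);
    max_channel_id_ = std::max(max_channel_id_, channel_id);
  }
  void RecordSend(int64_t channel_id, const std::string& name) {
    // And symmetrically: the id must not belong to an all-reduce group.
    if (all_reduce_groups_.count(channel_id) > 0)
      throw std::runtime_error("channel id already used by an all-reduce");
    Channel& channel = channels_[channel_id];  // adds the channel if needed
    channel.id = channel_id;
    channel.send = name;
    max_channel_id_ = std::max(max_channel_id_, channel_id);
  }
  int64_t max_channel_id() const { return max_channel_id_; }

 private:
  std::map<int64_t, Channel> channels_;                            // send/recv
  std::map<int64_t, std::vector<std::string>> all_reduce_groups_;  // all-reduce
  int64_t max_channel_id_ = 0;
};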

View File

@@ -137,9 +137,8 @@ class HloModuleGroupMetadata {
// Returns if the given channel id exists in metadata. // Returns if the given channel id exists in metadata.
bool HasChannel(int64 channel_id) const; bool HasChannel(int64 channel_id) const;
// Returns the all-reduce instructions with the same all_reduce_id. // Returns the all-reduce instructions with the same channel_id.
const std::vector<HloInstruction*>& GetAllReduceGroup( const std::vector<HloInstruction*>& GetAllReduceGroup(int64 channel_id) const;
int64 all_reduce_id) const;
// Returns the computation that contains the peer channel instructions for // Returns the computation that contains the peer channel instructions for
// the given instruction. // the given instruction.
@@ -205,7 +204,7 @@ class HloModuleGroupMetadata {
// Returns all channels in the module group. // Returns all channels in the module group.
const std::vector<Channel>& channels() const { return channels_; } const std::vector<Channel>& channels() const { return channels_; }
// Returns the maximum channel id or all_reduce_id used in the module group. // Returns the maximum channel id used in the module group.
int64 max_channel_id() const { return max_channel_id_; } int64 max_channel_id() const { return max_channel_id_; }
HloAliasAnalysis* alias_analysis(HloModule* module) const { HloAliasAnalysis* alias_analysis(HloModule* module) const {
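
For context, a toy model of the accessor this header now assumes (not the real HloInstruction), showing why one optional channel_id can serve send/recv and cross-module all-reduce alike, so the max_channel_id() comment no longer needs to mention all_reduce_id:

#include <cstdint>
#include <optional>

enum class Opcode { kAllReduce, kSend, kRecv };

struct Instruction {  // toy stand-in, for illustration only
  Opcode opcode;
  std::optional<int64_t> channel_id;  // one id namespace for all channels
  bool IsCrossModuleAllReduce() const {
    // An all-reduce is cross-module exactly when it carries a channel id.
    return opcode == Opcode::kAllReduce && channel_id.has_value();
  }
};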

View File

@@ -62,7 +62,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
} }
if (predecessor->IsCrossModuleAllReduce()) { if (predecessor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr : for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { metadata_.GetAllReduceGroup(*predecessor->channel_id())) {
if (unique.insert(instr).second) { if (unique.insert(instr).second) {
predecessors.push_back(instr); predecessors.push_back(instr);
} }
@@ -82,8 +82,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
instruction_group.push_back(companion); instruction_group.push_back(companion);
} }
} else if (instruction->IsCrossModuleAllReduce()) { } else if (instruction->IsCrossModuleAllReduce()) {
instruction_group = instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else { } else {
instruction_group.push_back(instruction); instruction_group.push_back(instruction);
} }
@@ -99,14 +98,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
if (instruction->opcode() == HloOpcode::kRecvDone && if (instruction->opcode() == HloOpcode::kRecvDone &&
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) { !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone. // Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_predecessor(send); add_unique_predecessor(send);
} }
if (instruction->opcode() == HloOpcode::kSend && if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) { !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send. // Recv is a remote predecessor of Send.
HloInstruction* recv_done = HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done; metadata_.GetChannel(*instruction->channel_id()).recv_done;
CHECK(recv_done->opcode() == HloOpcode::kRecvDone); CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
CHECK_EQ(recv_done->operand_count(), 1); CHECK_EQ(recv_done->operand_count(), 1);
HloInstruction* recv = recv_done->mutable_operand(0); HloInstruction* recv = recv_done->mutable_operand(0);
@@ -139,7 +139,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
} }
if (successor->IsCrossModuleAllReduce()) { if (successor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr : for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { metadata_.GetAllReduceGroup(*successor->channel_id())) {
if (unique.insert(instr).second) { if (unique.insert(instr).second) {
successors.push_back(instr); successors.push_back(instr);
} }
@@ -160,8 +160,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
instruction_group.push_back(companion); instruction_group.push_back(companion);
} }
} else if (instruction->IsCrossModuleAllReduce()) { } else if (instruction->IsCrossModuleAllReduce()) {
instruction_group = instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
} else { } else {
instruction_group.push_back(instruction); instruction_group.push_back(instruction);
} }
@@ -179,14 +178,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
// Send is a remote successor of Recv. // Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front(); const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone); CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_successor(send); add_unique_successor(send);
} }
if (instruction->opcode() == HloOpcode::kSend && if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) { !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send. // RecvDone is a remote successor of Send.
HloInstruction* recv_done = HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done; metadata_.GetChannel(*instruction->channel_id()).recv_done;
add_unique_successor(recv_done); add_unique_successor(recv_done);
} }
return successors; return successors;
@@ -256,7 +256,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
instruction_group.push_back(companion); instruction_group.push_back(companion);
} }
} else if (hlo->IsCrossModuleAllReduce()) { } else if (hlo->IsCrossModuleAllReduce()) {
instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id()); instruction_group = metadata_.GetAllReduceGroup(*hlo->channel_id());
} else { } else {
instruction_group.push_back(hlo); instruction_group.push_back(hlo);
} }
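
A rough sketch, with toy types, of the group-expansion step these hunks touch: a cross-module all-reduce is traversed together with every member of its group, and the group is now looked up by channel id rather than all_reduce_id:

#include <cstdint>
#include <map>
#include <vector>

struct Instr {  // hypothetical stand-in for HloInstruction
  bool is_cross_module_all_reduce = false;
  int64_t channel_id = 0;  // meaningful only when the flag above is set
};

using AllReduceGroups = std::map<int64_t, std::vector<Instr*>>;

// Returns the set of instructions that must be visited as one node.
std::vector<Instr*> InstructionGroup(Instr* instr,
                                     const AllReduceGroups& groups) {
  if (instr->is_cross_module_all_reduce) {
    auto it = groups.find(instr->channel_id);
    if (it != groups.end()) return it->second;
  }
  return {instr};  // ordinary instructions form a singleton group
}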

View File

@@ -836,13 +836,12 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<std::vector<int64>>> tmp_groups; optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply; optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids; optional<std::vector<int64>> replica_group_ids;
optional<int64> all_reduce_id; optional<int64> channel_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply}; &to_apply};
attrs["replica_groups"] = {/*required=*/false, attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups}; AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
&all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false; return false;
} }
@@ -851,7 +850,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
replica_groups = CreateReplicaGroups(*tmp_groups); replica_groups = CreateReplicaGroups(*tmp_groups);
} }
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, all_reduce_id)); shape, operands, *to_apply, replica_groups, channel_id));
break; break;
} }
case HloOpcode::kAllToAll: { case HloOpcode::kAllToAll: {
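
Illustrative only: the parser registers channel_id as an optional int64 attribute, so replica-local all-reduces still parse without it. A toy string-based version of that optional lookup (ParseChannelId is hypothetical, not the parser's API):

#include <cstdint>
#include <optional>
#include <string>

std::optional<int64_t> ParseChannelId(const std::string& attrs) {
  const std::string key = "channel_id=";
  auto pos = attrs.find(key);
  if (pos == std::string::npos) return std::nullopt;  // attribute absent
  // std::stoll stops at the first non-digit, e.g. a trailing comma.
  return std::stoll(attrs.substr(pos + key.size()));
}

// e.g. ParseChannelId("channel_id=1, replica_groups={{0}}") yields 1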

View File

@@ -1416,8 +1416,8 @@ add {
ENTRY CRS { ENTRY CRS {
input = f32[8]{0} parameter(0) input = f32[8]{0} parameter(0)
crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
} }
)" )"

View File

@@ -85,8 +85,8 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
std::vector<HloInstruction*> inputs; std::vector<HloInstruction*> inputs;
const auto add_input = [&channel_group, &inputs](HloInstruction* input) { const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
inputs.push_back(input); inputs.push_back(input);
if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) {
auto it = channel_group.find(*input->all_reduce_id()); auto it = channel_group.find(*input->channel_id());
if (it != channel_group.end()) { if (it != channel_group.end()) {
inputs.insert(inputs.end(), it->second.begin(), it->second.end()); inputs.insert(inputs.end(), it->second.begin(), it->second.end());
} }
@@ -106,7 +106,7 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
switch (hlo->opcode()) { switch (hlo->opcode()) {
case HloOpcode::kRecvDone: { case HloOpcode::kRecvDone: {
auto it = channel_group.find(hlo->channel_id()); auto it = channel_group.find(*hlo->channel_id());
if (it != channel_group.end()) { if (it != channel_group.end()) {
for (HloInstruction* channel : it->second) { for (HloInstruction* channel : it->second) {
if (channel->opcode() == HloOpcode::kSend) { if (channel->opcode() == HloOpcode::kSend) {
@@ -117,9 +117,9 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
break; break;
} }
case HloOpcode::kAllReduce: { case HloOpcode::kAllReduce: {
auto all_reduce_id = hlo->all_reduce_id(); auto channel_id = hlo->channel_id();
if (all_reduce_id) { if (channel_id) {
auto it = channel_group.find(all_reduce_id.value()); auto it = channel_group.find(channel_id.value());
if (it != channel_group.end()) { if (it != channel_group.end()) {
for (HloInstruction* all_reduce : it->second) { for (HloInstruction* all_reduce : it->second) {
add_dependencies(all_reduce); add_dependencies(all_reduce);
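
A simplified model of the grouping above (toy Node type, not HloInstruction): any all-reduce that carries a channel id pulls the rest of its channel group into its reachability inputs, so every member of the group reaches whatever any member reaches:

#include <cstdint>
#include <map>
#include <optional>
#include <vector>

struct Node {
  std::optional<int64_t> channel_id;  // set on cross-module collectives
  std::vector<Node*> operands;
};

std::vector<Node*> Inputs(
    Node* node, const std::map<int64_t, std::vector<Node*>>& channel_group) {
  std::vector<Node*> inputs = node->operands;
  if (node->channel_id) {
    auto it = channel_group.find(*node->channel_id);
    if (it != channel_group.end())  // add the whole group as inputs
      inputs.insert(inputs.end(), it->second.begin(), it->second.end());
  }
  return inputs;
}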

View File

@@ -1210,8 +1210,8 @@ Status CheckSameChannel(const HloInstruction* instr1,
return InternalError( return InternalError(
"Expected to have the same channel id, actual channel ids are: %s " "Expected to have the same channel id, actual channel ids are: %s "
"(%d), %s (%d)", "(%d), %s (%d)",
instr1->ToString(), instr1->channel_id(), instr2->ToString(), instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
instr2->channel_id()); *instr2->channel_id());
} }
return Status::OK(); return Status::OK();
} }
@@ -1282,14 +1282,14 @@ Status VerifySendsAndRecvs(const HloModule& module) {
DynCast<const HloSendRecvInstruction>(instruction); DynCast<const HloSendRecvInstruction>(instruction);
if (sendrecv->is_host_transfer()) { if (sendrecv->is_host_transfer()) {
auto it_inserted = auto it_inserted =
host_channels.insert({sendrecv->channel_id(), sendrecv}); host_channels.insert({*sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) { if (!it_inserted.second) {
return FailedPrecondition( return FailedPrecondition(
"Channel %d is used for multiple host send/recv instructions: " "Channel %d is used for multiple host send/recv instructions: "
"%s " "%s "
"and " "and "
"%s", "%s",
sendrecv->channel_id(), sendrecv->ToString(), *sendrecv->channel_id(), sendrecv->ToString(),
it_inserted.first->second->ToString()); it_inserted.first->second->ToString());
} }
} }
@@ -1574,9 +1574,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
} }
Status HandleAllReduce(HloInstruction* crs) override { Status HandleAllReduce(HloInstruction* crs) override {
if (crs->all_reduce_id().has_value()) { if (crs->channel_id().has_value()) {
TF_RET_CHECK(crs->all_reduce_id().value() > 0) TF_RET_CHECK(crs->channel_id().value() > 0)
<< "All reduce id must be greater than 0 for " << "All reduce channel id must be greater than 0 for "
<< crs->ToShortString(); << crs->ToShortString();
} }
return Status::OK(); return Status::OK();
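
The same two checks with plain std types, as a sketch (exception-based here rather than TF_RET_CHECK / InternalError): a present channel id must be positive, and paired instructions must agree on it; the optionals are only dereferenced once has_value() is known:

#include <cstdint>
#include <optional>
#include <stdexcept>

void VerifyChannelIdPositive(const std::optional<int64_t>& id) {
  if (id.has_value() && id.value() <= 0)
    throw std::invalid_argument("all-reduce channel id must be > 0");
}

void VerifySameChannel(const std::optional<int64_t>& a,
                       const std::optional<int64_t>& b) {
  // std::optional compares value-wise, treating "no id" as its own state.
  if (a != b) throw std::invalid_argument("expected the same channel id");
}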

View File

@@ -262,7 +262,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
abs1 = f32[4,3]{1,0} abs(add) abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1) log = f32[4,3]{1,0} log(abs1)
token0 = token[] after-all() token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0 send = f32[4,3]{1,0} send(log, token0), channel_id=1
abs2 = f32[4,3]{1,0} abs(log) abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add) ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})") })")
@@ -294,7 +294,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add1 = f32[4,3]{1,0} add(p0, p1) add1 = f32[4,3]{1,0} add(p0, p1)
log = f32[4,3]{1,0} log(p0) log = f32[4,3]{1,0} log(p0)
token0 = token[] after-all() token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0 send = f32[4,3]{1,0} send(log, token0), channel_id=1
add2 = f32[4,3]{1,0} add(log, add1) add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2) ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})") })")
@@ -329,7 +329,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add2 = f32[4,3]{1,0} add(add1, p1) add2 = f32[4,3]{1,0} add(add1, p1)
log = f32[4,3]{1,0} log(add2) log = f32[4,3]{1,0} log(add2)
token0 = token[] after-all() token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0 send = f32[4,3]{1,0} send(log, token0), channel_id=1
sub1 = f32[4,3]{1,0} subtract(log, add2) sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1) sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)

View File

@@ -408,7 +408,7 @@ Status LayoutAssignment::BuildHostChannelConstraints(
TF_RET_CHECK(data_shape.IsArray()); TF_RET_CHECK(data_shape.IsArray());
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
send_recv_instr->channel_id(), data_shape.layout()); *send_recv_instr->channel_id(), data_shape.layout());
TF_RET_CHECK(prev_layout == nullptr) TF_RET_CHECK(prev_layout == nullptr)
<< "Cannot constrain host transfer layout as it was set to " << "Cannot constrain host transfer layout as it was set to "
<< LayoutUtil::HumanString(*prev_layout) << ": " << LayoutUtil::HumanString(*prev_layout) << ": "
@@ -480,7 +480,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
instruction->opcode() == HloOpcode::kRecv) { instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction)) CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints"; << "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id(); int64 channel_id = *instruction->channel_id();
if (!get_channel_constraints(instruction) if (!get_channel_constraints(instruction)
->IsChannelConstrained(channel_id)) { ->IsChannelConstrained(channel_id)) {
continue; continue;
@@ -492,7 +492,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
Shape new_buffer_shape = Shape new_buffer_shape =
get_channel_constraints(instruction) get_channel_constraints(instruction)
->LayoutShapeForChannel(send_buffer_shape, ->LayoutShapeForChannel(send_buffer_shape,
instruction->channel_id()); *instruction->channel_id());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
new_buffer_shape, instruction->operand(0))); new_buffer_shape, instruction->operand(0)));
} else { } else {
@@ -503,18 +503,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
const LogicalBuffer* buffer, const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(instruction, constraints->points_to_analysis().GetBufferDefinedAt(instruction,
{0})); {0}));
Shape new_shape = get_channel_constraints(instruction) Shape new_shape =
->LayoutShapeForChannel( get_channel_constraints(instruction)
recv_buffer_shape, instruction->channel_id()); ->LayoutShapeForChannel(recv_buffer_shape,
*instruction->channel_id());
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer)); constraints->SetBufferLayout(new_shape.layout(), *buffer));
} }
} else if (instruction->IsCrossModuleAllReduce()) { } else if (instruction->IsCrossModuleAllReduce()) {
CHECK(get_channel_constraints(instruction)) CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints"; << "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 all_reduce_id = instruction->all_reduce_id().value(); int64 channel_id = instruction->channel_id().value();
if (!get_channel_constraints(instruction) if (!get_channel_constraints(instruction)
->IsChannelConstrained(all_reduce_id)) { ->IsChannelConstrained(channel_id)) {
continue; continue;
} }
// TODO(b/68493863): Change to use SetOperandLayout(). // TODO(b/68493863): Change to use SetOperandLayout().
@@ -522,7 +523,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RET_CHECK(buffer_shape.IsArray()); TF_RET_CHECK(buffer_shape.IsArray());
Shape new_buffer_shape = Shape new_buffer_shape =
get_channel_constraints(instruction) get_channel_constraints(instruction)
->LayoutShapeForChannel(buffer_shape, all_reduce_id); ->LayoutShapeForChannel(buffer_shape, channel_id);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(new_buffer_shape, instruction)); constraints->SetInstructionLayout(new_buffer_shape, instruction));
} }
@@ -1833,7 +1834,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
const Layout* layout = const Layout* layout =
get_channel_constraints(instruction) get_channel_constraints(instruction)
->ConstrainChannel( ->ConstrainChannel(
instruction->channel_id(), *instruction->channel_id(),
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr) TF_RET_CHECK(layout == nullptr)
<< instruction->ToString() << instruction->ToString()
@@ -1848,7 +1849,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
if (instruction->opcode() == HloOpcode::kSend) { if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0); HloInstruction* operand = instruction->mutable_operand(0);
const Layout* layout = get_channel_constraints(instruction) const Layout* layout = get_channel_constraints(instruction)
->ConstrainChannel(instruction->channel_id(), ->ConstrainChannel(*instruction->channel_id(),
operand->shape().layout()); operand->shape().layout());
if (layout != nullptr) { if (layout != nullptr) {
// We found an already constrained layout which does not match the one // We found an already constrained layout which does not match the one
@@ -1873,7 +1874,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
} else if (instruction->IsCrossModuleAllReduce()) { } else if (instruction->IsCrossModuleAllReduce()) {
const Layout* layout = const Layout* layout =
get_channel_constraints(instruction) get_channel_constraints(instruction)
->ConstrainChannel(instruction->all_reduce_id().value(), ->ConstrainChannel(instruction->channel_id().value(),
instruction->shape().layout()); instruction->shape().layout());
if (layout != nullptr) { if (layout != nullptr) {
// We found an already constrained layout which does not match the one // We found an already constrained layout which does not match the one
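
A toy rendering of the ConstrainChannel contract relied on throughout this file, with layouts reduced to strings (the real API hands back a const Layout*): the first constraint for a channel id wins and nothing is returned; a later conflicting constraint returns the existing layout so the caller can reconcile or report the mismatch:

#include <cstdint>
#include <map>
#include <optional>
#include <string>

class ChannelLayoutConstraintsSketch {
 public:
  // Returns the previously constrained layout on a conflict, else nullopt.
  std::optional<std::string> ConstrainChannel(int64_t channel_id,
                                              const std::string& layout) {
    auto [it, inserted] = constraints_.emplace(channel_id, layout);
    if (inserted || it->second == layout) return std::nullopt;
    return it->second;
  }
  bool IsChannelConstrained(int64_t channel_id) const {
    return constraints_.count(channel_id) > 0;
  }

 private:
  std::map<int64_t, std::string> constraints_;  // channel id -> layout
};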

View File

@@ -891,11 +891,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
param = (f32[2,2]) parameter(0) param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0 gte = f32[2,2] get-tuple-element(param), index=0
ar.0 = f32[2,2] all-reduce(gte), ar.0 = f32[2,2] all-reduce(gte),
all_reduce_id=1, replica_groups={{0}}, to_apply=add, channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=0} sharding={maximal device=0}
const = f32[2,2] constant({{0,1},{2,3}}) const = f32[2,2] constant({{0,1},{2,3}})
ROOT ar.1 = f32[2,2] all-reduce(const), ROOT ar.1 = f32[2,2] all-reduce(const),
all_reduce_id=1, replica_groups={{0}}, to_apply=add, channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=1} sharding={maximal device=1}
})"; })";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,