Change all_reduce_id to use channel_id
PiperOrigin-RevId: 255315385
parent 6ae9600988
commit 852061b75b
@@ -2068,7 +2068,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
   }
 
   if (channel_id.has_value()) {
-    instr.set_all_reduce_id(channel_id->handle());
+    instr.set_channel_id(channel_id->handle());
   }
 
   AddCalledComputation(computation, &instr);
@@ -319,7 +319,7 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
     auto maybe_pair = MatchesArCrsPattern(instruction);
     if (maybe_pair) {
       auto pair = *maybe_pair;
-      int64 ar_id = *(instruction->all_reduce_id());
+      int64 ar_id = *(instruction->channel_id());
       if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
         continue;
       }
@@ -365,9 +365,10 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
 
 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
   for (auto it : all_reduce_map_) {
-    auto all_reduce_id = it.first;
-    VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking ar_id: "
-            << all_reduce_id << "\n";
+    auto channel_id = it.first;
+    VLOG(2)
+        << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
+        << channel_id << "\n";
     auto pairs_vec = it.second;
     CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
     auto instr_0 = pairs_vec[0].ar;
@@ -378,9 +379,10 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
     absl::flat_hash_map<int64, int64> visited_pairs;
     while (true) {
       if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
-        all_reduce_map_.erase(all_reduce_id);
-        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: "
-                << all_reduce_id << "\n";
+        all_reduce_map_.erase(channel_id);
+        VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
+                   "channel id: "
+                << channel_id << "\n";
         break;
       }
       if (next_0->IsCrossReplicaAllReduce()) {
@@ -402,7 +404,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
     for (auto pair : pairs_vec) {
       auto all_reduce = pair.ar;
       auto parent_computation = all_reduce->parent();
-      auto all_reduce_id = all_reduce->all_reduce_id();
+      auto channel_id = all_reduce->channel_id();
       auto prev = all_reduce->mutable_operand(0);
       auto next = all_reduce->users()[0];
       TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
@@ -447,7 +449,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
         next = next->users()[0];
       }
       // The AllReduce and the CRS are combined to an all-core AllReduce.
-      next->set_all_reduce_id(all_reduce_id);
+      next->set_channel_id(channel_id);
     }
   }
   return true;
@@ -36,7 +36,7 @@ namespace xla {
 //
 // The steps are:
 // 1) Find CMARs followed by simple ops followed by CRARs.
-// 2) Group CMARs by all_reduce_id. They must all be rewritten.
+// 2) Group CMARs by channel_id. They must all be rewritten.
 // 3) Prove that the CMAR patterns in each core produce the same result.
 // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
 //    other operand by the number of spatial partitions.
@@ -104,7 +104,7 @@ class ArCrsCombiner : public HloModulePass {
       }
       pieces.push_back(instruction->name());
       pieces.push_back(")[id:");
-      pieces.push_back(std::to_string(*(ar->all_reduce_id())));
+      pieces.push_back(std::to_string(*(ar->channel_id())));
       pieces.push_back(",dist:");
      pieces.push_back(std::to_string(distance));
      pieces.push_back("]");
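The pass's all_reduce_map_ is now keyed by the AllReduce channel id rather than the removed all_reduce_id. A minimal standalone sketch of that grouping step, using hypothetical simplified types in place of the real HloInstruction and combiner members:

#include <cstdint>
#include <map>
#include <vector>

// Hypothetical stand-in for an AllReduce instruction; the real pass reads
// *instruction->channel_id() and appends to all_reduce_map_ the same way.
struct FakeAllReduce {
  int64_t channel_id;
};

std::map<int64_t, std::vector<FakeAllReduce*>> GroupByChannelId(
    const std::vector<FakeAllReduce*>& all_reduces) {
  std::map<int64_t, std::vector<FakeAllReduce*>> groups;
  for (FakeAllReduce* ar : all_reduces) {
    groups[ar->channel_id].push_back(ar);  // one group per channel id
  }
  return groups;
}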
@ -414,7 +414,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = bf16[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[]
|
||||
@ -429,7 +429,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[]
|
||||
@ -486,7 +486,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
||||
%all-reduce.ar.1 = f32[2,1]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=0}
|
||||
%bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
|
||||
@ -499,7 +499,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
|
||||
%all-reduce.ar.2 = f32[2,1]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=1}
|
||||
%bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
|
||||
@ -549,7 +549,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
%multiply.1 = f32[]
|
||||
@ -564,7 +564,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
%multiply.2 = f32[]
|
||||
@ -624,7 +624,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[]
|
||||
@ -642,7 +642,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[]
|
||||
@ -709,7 +709,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[]
|
||||
@ -727,7 +727,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[]
|
||||
@ -772,7 +772,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=0}
|
||||
%all-reduce.1 = f32[]
|
||||
@ -787,7 +787,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.1,
|
||||
sharding={maximal device=1}
|
||||
%all-reduce.2 = f32[]
|
||||
@ -840,7 +840,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%add.11 = f32[]
|
||||
@ -858,7 +858,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%add.21 = f32[]
|
||||
@ -919,7 +919,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=0}
|
||||
%sub.1 = f32[]
|
||||
@ -934,7 +934,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.f32,
|
||||
sharding={maximal device=1}
|
||||
%sub.2 = f32[]
|
||||
@ -991,7 +991,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar11 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%add11 = f32[]
|
||||
@ -1000,7 +1000,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar12 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=2,
|
||||
channel_id=2,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%add12 = f32[]
|
||||
@ -1015,7 +1015,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar21 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=1}
|
||||
%add21 = f32[]
|
||||
@ -1024,7 +1024,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar22 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=2,
|
||||
channel_id=2,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=1}
|
||||
%add22 = f32[]
|
||||
@ -1083,13 +1083,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar11 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%ar12 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=2,
|
||||
channel_id=2,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=0}
|
||||
%add11 = f32[]
|
||||
@ -1107,13 +1107,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
|
||||
%ar21 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=1}
|
||||
%ar22 = f32[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0},{1}},
|
||||
all_reduce_id=2,
|
||||
channel_id=2,
|
||||
to_apply=%sum,
|
||||
sharding={maximal device=1}
|
||||
%add21 = f32[]
|
||||
@ -1182,7 +1182,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.1 = bf16[]
|
||||
all-reduce(%p),
|
||||
replica_groups={{0}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=0}
|
||||
%convert.1 = f32[]
|
||||
@ -1197,7 +1197,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%all-reduce.ar.2 = bf16[]
|
||||
all-reduce(%constant.bf16),
|
||||
replica_groups={{0}},
|
||||
all_reduce_id=1,
|
||||
channel_id=1,
|
||||
to_apply=%sum.bf16,
|
||||
sharding={maximal device=1}
|
||||
%convert.2 = f32[]
|
||||
@ -1276,9 +1276,9 @@ HloModule foobar
|
||||
|
||||
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
|
||||
%p = bf16[] parameter(0)
|
||||
%all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
|
||||
%all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
|
||||
to_apply=%sum.f32, sharding={maximal device=0}
|
||||
%all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
|
||||
%all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
|
||||
to_apply=%sum.f32, sharding={maximal device=1}
|
||||
%all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
|
||||
to_apply=%sum.f32, sharding={maximal device=0}
|
||||
|
@@ -239,7 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
       /*replica_groups=*/{},
-      /*all_reduce_id=*/absl::nullopt));
+      /*channel_id=*/absl::nullopt));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
   HloInstruction* gte_b = builder.AddInstruction(
@@ -257,7 +257,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
   HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
       ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
       /*replica_groups=*/{},
-      /*all_reduce_id=*/absl::nullopt));
+      /*channel_id=*/absl::nullopt));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
@@ -211,7 +211,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
   HloInstruction* all_reduce =
       builder.AddInstruction(HloInstruction::CreateAllReduce(
           ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
-          /*replica_groups=*/{}, /*all_reduce_id=*/1));
+          /*replica_groups=*/{}, /*channel_id=*/1));
   HloInstruction* gte0 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
   HloInstruction* gte1 = builder.AddInstruction(
@@ -177,13 +177,13 @@ class NcclComm {
 // * Only ops with the same opcode can communicate with each other. At the
 //   moment we only support kAllReduce, so we don't check for this explicitly.
 //
-// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()),
-//   only ops with the same value for all_reduce_id() can communicate with each
+// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
+//   only ops with the same value for channel_id() can communicate with each
 //   other.
 //
 // * For cross-replica (i.e. same-module) all-reduces (i.e.
-//   !all_reduce_id().has_value()), only ops from the same module (as identified
-//   by its unique_id()) can communicate with each other.
+//   !channel_id().has_value()), only ops from the same module (as
+//   identified by its unique_id()) can communicate with each other.
 //
 struct RendezvousKey {
   enum AllReduceKind {
@@ -196,8 +196,8 @@ struct RendezvousKey {
                 const HloAllReduceInstruction* instr)
       : run_id(run_id), participating_replicas(participating_replicas) {
     std::tie(all_reduce_kind, op_id) =
-        instr->all_reduce_id().has_value()
-            ? std::make_pair(kCrossModule, instr->all_reduce_id().value())
+        instr->channel_id().has_value()
+            ? std::make_pair(kCrossModule, instr->channel_id().value())
            : std::make_pair(
                  kCrossReplica,
                  static_cast<int64>(instr->GetModule()->unique_id()));
@@ -126,8 +126,9 @@ message HloInstructionProto {
   // Only present for kBatchNormTraining.
   int64 feature_index = 25;
 
-  // Represents a unique identifier for each Send/Recv instruction pair.
-  // Only present for kSend or kRecv.
+  // Represents a unique identifier for each Send/Recv instruction pair or
+  // optionally for collective instructions (AllReduce, CollectivePermute,
+  // AllToAll). Non-positive channel_id is equivalent to no channel id.
   int64 channel_id = 26;
 
   // The string representation of the infeed configuration.
@@ -174,7 +175,9 @@ message HloInstructionProto {
 
   // Cross replica op fields.
   repeated ReplicaGroup replica_groups = 49;
-  int64 all_reduce_id = 45;
+  // Deprecated, but keeping it for backward compatibility. Use channel_id.
+  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
+  int64 all_reduce_id = 45 [deprecated = true];
 
   // Whether this Send/Recv instruction transfers data to/from the host. Only
   // present for Send and Recv instructions and their SendDone and RecvDone
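The deprecated all_reduce_id field stays in the proto so older serialized modules still deserialize; the CreateFromProto change later in this diff prefers channel_id and falls back to all_reduce_id. A rough standalone sketch of that fallback, with a hypothetical struct standing in for HloInstructionProto:

#include <cstdint>
#include <optional>

// Hypothetical mirror of the two proto fields above; non-positive values mean
// "no id", matching the hlo.proto comments.
struct AllReduceProtoIds {
  int64_t channel_id = 0;     // field 26
  int64_t all_reduce_id = 0;  // deprecated field 45
};

std::optional<int64_t> ResolveChannelId(const AllReduceProtoIds& ids) {
  if (ids.channel_id > 0) return ids.channel_id;        // prefer the new field
  if (ids.all_reduce_id > 0) return ids.all_reduce_id;  // legacy fallback
  return std::nullopt;                                  // cross-replica only
}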
@@ -373,7 +373,7 @@ void HloComputation::ComputeInstructionPostOrder(
       case HloOpcode::kRecvDone:
         return inst->channel_id();
       case HloOpcode::kAllReduce:
-        return inst->all_reduce_id();
+        return inst->channel_id();
       default:
         return absl::nullopt;
     }
@@ -428,13 +428,10 @@ HloComputation::ComputeChannelDependencies() const {
     switch (instruction->opcode()) {
       case HloOpcode::kSend:
       case HloOpcode::kRecvDone:
        channel_dependency_group[instruction->channel_id()].push_back(
            instruction.get());
        break;
       case HloOpcode::kAllReduce: {
-        auto all_reduce_id = instruction->all_reduce_id();
-        if (all_reduce_id) {
-          channel_dependency_group[all_reduce_id.value()].push_back(
+        auto channel_id = instruction->channel_id();
+        if (channel_id) {
+          channel_dependency_group[channel_id.value()].push_back(
              instruction.get());
        }
        break;
@@ -688,10 +688,10 @@ add {
 ENTRY entry {
   param = f32[128] parameter(0), sharding={maximal device=0}
   crs0 = f32[128] all-reduce(param),
-    replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+    replica_groups={{0}}, channel_id=1, to_apply=add,
     sharding={maximal device=0}
   crs1 = f32[128] all-reduce(param),
-    replica_groups={{0}}, all_reduce_id=1, to_apply=add,
+    replica_groups={{0}}, channel_id=1, to_apply=add,
     sharding={maximal device=1}
   add = f32[128] add(crs0, crs0), sharding={maximal device=0}
   ROOT t = (f32[128], f32[128]) tuple(add, crs1)
@@ -383,16 +383,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
           << "AllReduce should have 1 called computation but sees "
           << proto.called_computation_ids_size();
-      absl::optional<int64> all_reduce_id;
+      TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
+          << "AllReduce cannot have both channel_id() and all_reduce_id()";
+      absl::optional<int64> channel_id;
+      if (proto.channel_id() > 0) {
+        channel_id = proto.channel_id();
+      }
       if (proto.all_reduce_id() > 0) {
-        all_reduce_id = proto.all_reduce_id();
+        channel_id = proto.all_reduce_id();
       }
       instruction = CreateAllReduce(
           shape, all_operands(), computations(0),
           /*replica_groups=*/
           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                     proto.replica_groups().end()),
-          /*all_reduce_id=*/all_reduce_id);
+          /*channel_id=*/channel_id);
       break;
     }
     case HloOpcode::kAllToAll: {
@@ -860,9 +865,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     HloComputation* reduce_computation,
     const std::vector<ReplicaGroup>& replica_groups,
-    const absl::optional<int64>& all_reduce_id) {
+    const absl::optional<int64>& channel_id) {
   return absl::make_unique<HloAllReduceInstruction>(
-      shape, operands, reduce_computation, replica_groups, all_reduce_id);
+      shape, operands, reduce_computation, replica_groups, channel_id);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@@ -1279,7 +1284,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
     case HloOpcode::kTrace:
       return true;
     case HloOpcode::kAllReduce:
-      return all_reduce_id().has_value();
+      return channel_id().has_value();
     case HloOpcode::kCustomCall:
       return Cast<HloCustomCallInstruction>(this)
           ->custom_call_has_side_effect();
@@ -2232,11 +2237,11 @@ bool HloInstruction::IsElementwiseImpl(
 }
 
 bool HloInstruction::IsCrossModuleAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && all_reduce_id();
+  return opcode() == HloOpcode::kAllReduce && channel_id();
 }
 
 bool HloInstruction::IsCrossReplicaAllReduce() const {
-  return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
+  return opcode() == HloOpcode::kAllReduce && !channel_id();
 }
 
 string HloInstruction::ToStringWithCanonicalNameMap(
@@ -3332,10 +3337,6 @@ const std::vector<int64>& HloInstruction::fft_length() const {
   return Cast<HloFftInstruction>(this)->fft_length();
 }
 
-int64 HloInstruction::channel_id() const {
-  return Cast<HloSendRecvInstruction>(this)->channel_id();
-}
-
 int64 HloInstruction::concatenate_dimension() const {
   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
 }
@@ -3535,13 +3536,12 @@ HloInstruction::source_target_pairs() const {
   return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
 }
 
-absl::optional<int64> HloInstruction::all_reduce_id() const {
-  return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
+absl::optional<int64> HloInstruction::channel_id() const {
+  return Cast<HloChannelInstruction>(this)->channel_id();
 }
 
-void HloInstruction::set_all_reduce_id(
-    const absl::optional<int64>& all_reduce_id) {
-  return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
+void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
+  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
 }
 
 const ConvolutionDimensionNumbers&
@@ -497,14 +497,14 @@ class HloInstruction {
   // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
   // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
   //
-  // `all_reduce_id`: for Allreduce nodes from different modules, if they have
-  // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
-  // not be applied cross modules.
+  // `channel_id`: for Allreduce nodes from different modules, if
+  // they have the same channel_id, they will be 'Allreduce'd. If
+  // empty, Allreduce will not be applied cross modules.
   static std::unique_ptr<HloInstruction> CreateAllReduce(
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       HloComputation* reduce_computation,
       const std::vector<ReplicaGroup>& replica_groups,
-      const absl::optional<int64>& all_reduce_id);
+      const absl::optional<int64>& channel_id);
 
   // An all-to-all op takes N array operands of the same shape and scatters them
   // to N replicas. Each replica gathers the results into a tuple.
@@ -952,7 +952,7 @@ class HloInstruction {
       return false;
     }
 
-    // Two AllReduces are Identical if they have the same all_reduce_id.
+    // Two AllReduces are Identical if they have the same channel_id.
     // Their operands don't have to be Identical.
     if (!IsCrossModuleAllReduce()) {
       // Use an explicit loop rather than ContainerEquals, because copying
@@ -1428,8 +1428,9 @@ class HloInstruction {
   // Delegates to HloFftInstruction::fft_length.
   const std::vector<int64>& fft_length() const;
 
-  // Delegates to HloSendRecvInstruction::channel_id.
-  int64 channel_id() const;
+  // Delegates to HloChannelInstruction::channel_id.
+  absl::optional<int64> channel_id() const;
+  void set_channel_id(const absl::optional<int64>& channel_id);
 
   // Returns the dimension sizes or numbers associated with this instruction.
   virtual const std::vector<int64>& dimensions() const {
@@ -1571,10 +1572,6 @@ class HloInstruction {
   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
 
-  // Delegates to HloAllReduceInstruction::all_reduce_id.
-  absl::optional<int64> all_reduce_id() const;
-  void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
-
   // Returns data on the window in a windowed operation such as
   // convolution.
   virtual const Window& window() const {
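HloInstruction::channel_id() now returns absl::optional<int64> for every channel instruction, so call sites dereference it only after checking has_value(). A minimal sketch of the calling pattern the rest of this diff adopts (assumes the XLA headers, e.g. hlo_instruction.h, are available):

// Sketch only; `hlo` is assumed to be a valid HloInstruction pointer.
void LogAllReduceKind(const HloInstruction* hlo) {
  if (hlo->opcode() != HloOpcode::kAllReduce) return;
  if (hlo->channel_id().has_value()) {
    // Cross-module all-reduce: ops sharing this channel id communicate.
    VLOG(2) << "cross-module all-reduce, channel id " << *hlo->channel_id();
  } else {
    // Cross-replica all-reduce: no channel id is set.
    VLOG(2) << "cross-replica all-reduce";
  }
}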
@ -361,25 +361,60 @@ HloCholeskyInstruction::CloneWithNewOperandsImpl(
|
||||
cholesky_options());
|
||||
}
|
||||
|
||||
HloChannelInstruction::HloChannelInstruction(
|
||||
HloOpcode opcode, const Shape& shape,
|
||||
const absl::optional<int64>& channel_id)
|
||||
: HloInstruction(opcode, shape), channel_id_(channel_id) {}
|
||||
|
||||
void HloChannelInstruction::set_channel_id(
|
||||
const absl::optional<int64>& channel_id) {
|
||||
channel_id_ = channel_id;
|
||||
}
|
||||
|
||||
HloInstructionProto HloChannelInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
if (channel_id_) {
|
||||
CHECK_GT(channel_id_.value(), 0)
|
||||
<< "Non-positive channel id is equivalent to no channel id";
|
||||
proto.set_channel_id(*channel_id_);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& /*options*/) const {
|
||||
std::vector<string> result;
|
||||
if (channel_id_) {
|
||||
result.push_back(StrCat("channel_id=", *channel_id_));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HloChannelInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
/*eq_computations*/) const {
|
||||
const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
|
||||
return channel_id() == casted_other.channel_id();
|
||||
}
|
||||
|
||||
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
|
||||
const Shape& shape,
|
||||
int64 channel_id,
|
||||
bool is_host_transfer)
|
||||
: HloInstruction(opcode, shape),
|
||||
channel_id_(channel_id),
|
||||
: HloChannelInstruction(opcode, shape, channel_id),
|
||||
is_host_transfer_(is_host_transfer) {}
|
||||
|
||||
HloInstructionProto HloSendRecvInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
proto.set_channel_id(channel_id_);
|
||||
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||
proto.set_is_host_transfer(is_host_transfer_);
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> attrs;
|
||||
attrs.push_back(StrCat("channel_id=", channel_id_));
|
||||
std::vector<string> attrs =
|
||||
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||
if (is_host_transfer()) {
|
||||
attrs.push_back("is_host_transfer=true");
|
||||
}
|
||||
@ -413,13 +448,13 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
|
||||
HloCloneContext* context) const {
|
||||
CHECK_EQ(new_operands.size(), 2);
|
||||
return absl::make_unique<HloSendInstruction>(
|
||||
new_operands[0], new_operands[1], channel_id(), is_host_transfer());
|
||||
new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
|
||||
}
|
||||
|
||||
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
|
||||
bool is_host_transfer)
|
||||
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
|
||||
CHECK_NOTNULL(operand)->channel_id(),
|
||||
CHECK_NOTNULL(operand)->channel_id().value(),
|
||||
is_host_transfer) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
@ -450,7 +485,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
|
||||
HloCloneContext* context) const {
|
||||
CHECK_EQ(new_operands.size(), 1);
|
||||
return absl::make_unique<HloRecvInstruction>(
|
||||
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
|
||||
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
|
||||
is_host_transfer());
|
||||
}
|
||||
|
||||
@ -461,7 +496,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
|
||||
ShapeUtil::MakeTokenShape()}),
|
||||
CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
|
||||
CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
|
||||
@ -477,32 +512,39 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
|
||||
HloCollectiveInstruction::HloCollectiveInstruction(
|
||||
HloOpcode opcode, const Shape& shape,
|
||||
absl::Span<HloInstruction* const> operands,
|
||||
const std::vector<ReplicaGroup>& replica_groups)
|
||||
: HloInstruction(opcode, shape), replica_groups_(replica_groups) {
|
||||
const std::vector<ReplicaGroup>& replica_groups,
|
||||
const absl::optional<int64>& channel_id)
|
||||
: HloChannelInstruction(opcode, shape, channel_id),
|
||||
replica_groups_({replica_groups.begin(), replica_groups.end()}) {
|
||||
for (auto operand : operands) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
}
|
||||
|
||||
HloInstructionProto HloCollectiveInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||
*proto.mutable_replica_groups() = {replica_groups_.begin(),
|
||||
replica_groups_.end()};
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& /*options*/) const {
|
||||
return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))};
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> result =
|
||||
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||
result.push_back(
|
||||
StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HloCollectiveInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
/*eq_computations*/) const {
|
||||
eq_computations) const {
|
||||
const auto& casted_other =
|
||||
static_cast<const HloCollectiveInstruction&>(other);
|
||||
return absl::c_equal(replica_groups(), casted_other.replica_groups(),
|
||||
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||
absl::c_equal(replica_groups(), casted_other.replica_groups(),
|
||||
[](const ReplicaGroup& a, const ReplicaGroup& b) {
|
||||
return absl::c_equal(a.replica_ids(), b.replica_ids());
|
||||
});
|
||||
@ -512,44 +554,19 @@ HloAllReduceInstruction::HloAllReduceInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* reduce_computation,
|
||||
const std::vector<ReplicaGroup>& replica_groups,
|
||||
const absl::optional<int64>& all_reduce_id)
|
||||
const absl::optional<int64>& channel_id)
|
||||
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
|
||||
replica_groups),
|
||||
all_reduce_id_(all_reduce_id) {
|
||||
replica_groups, channel_id) {
|
||||
AppendComputation(reduce_computation);
|
||||
}
|
||||
|
||||
void HloAllReduceInstruction::set_all_reduce_id(
|
||||
const absl::optional<int64>& all_reduce_id) {
|
||||
all_reduce_id_ = all_reduce_id;
|
||||
}
|
||||
|
||||
HloInstructionProto HloAllReduceInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
|
||||
// Proto3 is so sad.
|
||||
if (all_reduce_id_) {
|
||||
proto.set_all_reduce_id(*all_reduce_id_);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
bool HloAllReduceInstruction::IsNoop() const {
|
||||
for (auto replica_group : replica_groups()) {
|
||||
if (replica_group.replica_ids().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return !all_reduce_id();
|
||||
}
|
||||
|
||||
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> result =
|
||||
HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
|
||||
if (all_reduce_id_) {
|
||||
result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
|
||||
}
|
||||
return result;
|
||||
return !channel_id();
|
||||
}
|
||||
|
||||
bool HloAllReduceInstruction::IdenticalSlowPath(
|
||||
@ -558,8 +575,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
|
||||
eq_computations) const {
|
||||
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
|
||||
return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||
eq_computations(to_apply(), casted_other.to_apply()) &&
|
||||
all_reduce_id() == casted_other.all_reduce_id();
|
||||
eq_computations(to_apply(), casted_other.to_apply());
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction>
|
||||
@ -567,14 +583,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* /*context*/) const {
|
||||
return absl::make_unique<HloAllReduceInstruction>(
|
||||
shape, new_operands, to_apply(), replica_groups(), all_reduce_id());
|
||||
shape, new_operands, to_apply(), replica_groups(), channel_id());
|
||||
}
|
||||
|
||||
HloAllToAllInstruction::HloAllToAllInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
const std::vector<ReplicaGroup>& replica_groups)
|
||||
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
|
||||
replica_groups) {}
|
||||
replica_groups, absl::nullopt) {}
|
||||
|
||||
std::unique_ptr<HloInstruction>
|
||||
HloAllToAllInstruction::CloneWithNewOperandsImpl(
|
||||
@ -587,13 +603,14 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
|
||||
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
const std::vector<std::pair<int64, int64>>& source_target_pairs)
|
||||
: HloInstruction(HloOpcode::kCollectivePermute, shape),
|
||||
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape,
|
||||
absl::nullopt),
|
||||
source_target_pairs_(source_target_pairs) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
|
||||
HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
HloInstructionProto proto = HloChannelInstruction::ToProto();
|
||||
for (const auto& pair : source_target_pairs()) {
|
||||
auto* proto_pair = proto.add_source_target_pairs();
|
||||
proto_pair->set_source(pair.first);
|
||||
@ -604,8 +621,9 @@ HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
|
||||
|
||||
std::vector<string>
|
||||
HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& /*options*/) const {
|
||||
std::vector<string> result;
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> result =
|
||||
HloChannelInstruction::ExtraAttributesToStringImpl(options);
|
||||
std::vector<string> strs;
|
||||
for (const auto& pair : source_target_pairs()) {
|
||||
strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
|
||||
@ -617,10 +635,11 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
|
||||
bool HloCollectivePermuteInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
/*eq_computations*/) const {
|
||||
eq_computations) const {
|
||||
const auto& casted_other =
|
||||
static_cast<const HloCollectivePermuteInstruction&>(other);
|
||||
return absl::c_equal(source_target_pairs(),
|
||||
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||
absl::c_equal(source_target_pairs(),
|
||||
casted_other.source_target_pairs(),
|
||||
[](const std::pair<int64, int64>& a,
|
||||
const std::pair<int64, int64>& b) { return a == b; });
|
||||
|
@@ -206,13 +206,37 @@ class HloCholeskyInstruction : public HloInstruction {
   CholeskyOptions cholesky_options_;
 };
 
-class HloSendRecvInstruction : public HloInstruction {
+// Class that represents instructions that synchronize and transfer data between
+// partitioned devices. Send/Recv and collective instructions (AllReduce,
+// AllToAll, CollectivePermute) belong to this instruction type. A group of
+// instructions (of the same opcode) with the same channel_id communicate during
+// execution.
+class HloChannelInstruction : public HloInstruction {
  public:
   // Returns the channel id associated with the instruction. The id is
-  // shared between each Send/Recv pair and is globally unique to identify each
-  // channel.
-  int64 channel_id() const { return channel_id_; }
+  // shared between each Send/Recv pair or a group of collective instructions
+  // and is globally unique to identify each channel.
+  absl::optional<int64> channel_id() const { return channel_id_; }
+  void set_channel_id(const absl::optional<int64>& channel_id);
+
+ protected:
+  explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
+                                 const absl::optional<int64>& channel_id);
+
+  HloInstructionProto ToProto() const override;
+
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  absl::optional<int64> channel_id_;
+};
+
+class HloSendRecvInstruction : public HloChannelInstruction {
+ public:
   // Returns whether this send/recv instruction sends data to/from the host.
   bool is_host_transfer() const { return is_host_transfer_; }
 
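After this refactoring, Send/Recv and the collective ops share one base class that owns the channel id. A compressed, self-contained sketch of the hierarchy this diff sets up (simplified stand-ins, not the real XLA classes):

#include <cstdint>
#include <optional>

struct HloInstructionBase {};  // stand-in for HloInstruction
struct ChannelInstruction : HloInstructionBase {
  std::optional<int64_t> channel_id;  // shared by Send/Recv and collectives
};
struct SendRecvInstruction : ChannelInstruction {};
struct CollectiveInstruction : ChannelInstruction {};  // adds replica groups
struct AllReduceInstruction : CollectiveInstruction {};
struct AllToAllInstruction : CollectiveInstruction {};
struct CollectivePermuteInstruction : ChannelInstruction {};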
@ -230,9 +254,6 @@ class HloSendRecvInstruction : public HloInstruction {
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const override;
|
||||
// Represents a unique identifier for each Send/Recv instruction pair.
|
||||
int64 channel_id_;
|
||||
|
||||
// Whether this send/recv instruction sends data to/from the host.
|
||||
bool is_host_transfer_;
|
||||
};
|
||||
@ -285,7 +306,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
|
||||
HloCloneContext* context) const override;
|
||||
};
|
||||
|
||||
class HloCollectiveInstruction : public HloInstruction {
|
||||
class HloCollectiveInstruction : public HloChannelInstruction {
|
||||
public:
|
||||
const std::vector<ReplicaGroup>& replica_groups() const {
|
||||
return replica_groups_;
|
||||
@ -295,7 +316,8 @@ class HloCollectiveInstruction : public HloInstruction {
|
||||
explicit HloCollectiveInstruction(
|
||||
HloOpcode opcode, const Shape& shape,
|
||||
absl::Span<HloInstruction* const> operands,
|
||||
const std::vector<ReplicaGroup>& replica_groups);
|
||||
const std::vector<ReplicaGroup>& replica_groups,
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
@ -315,21 +337,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* reduce_computation,
|
||||
const std::vector<ReplicaGroup>& replica_groups,
|
||||
const absl::optional<int64>& all_reduce_id);
|
||||
|
||||
absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
|
||||
void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
|
||||
|
||||
// Returns a serialized representation of this instruction.
|
||||
HloInstructionProto ToProto() const override;
|
||||
const absl::optional<int64>& channel_id);
|
||||
|
||||
// Returns true if the AllReduce does no communication, so it's equivalent
|
||||
// to a mem copy.
|
||||
bool IsNoop() const;
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const override;
|
||||
bool IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
@ -339,11 +353,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
|
||||
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const override;
|
||||
|
||||
// For Allreduce nodes from different modules, if they have the same
|
||||
// all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
|
||||
// applied cross modules.
|
||||
absl::optional<int64> all_reduce_id_;
|
||||
};
|
||||
|
||||
class HloAllToAllInstruction : public HloCollectiveInstruction {
|
||||
@ -359,7 +368,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
|
||||
HloCloneContext* context) const override;
|
||||
};
|
||||
|
||||
class HloCollectivePermuteInstruction : public HloInstruction {
|
||||
class HloCollectivePermuteInstruction : public HloChannelInstruction {
|
||||
public:
|
||||
explicit HloCollectivePermuteInstruction(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
|
@@ -155,11 +155,11 @@ ENTRY entry {
   %p = f32[1000, 1000] parameter(0)
   %token.0 = token[] after-all()
   %send = (f32[1000, 1000], token[]) send(%p, %token.0),
-    channel_id=0, is_host_transfer=true
+    channel_id=1, is_host_transfer=true
   %n1 = f32[1000, 1000] negate(%p)
   %n2 = f32[1000, 1000] negate(%n1)
   %n3 = f32[1000, 1000] negate(%n2)
-  %send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true
+  %send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
 }
 )";
|
@ -83,7 +83,7 @@ Status HloModuleGroupMetadata::Build() {
|
||||
if (IsChannelInstruction(hlo)) {
|
||||
peers.push_back(PeerComputation(hlo));
|
||||
} else if (hlo->IsCrossModuleAllReduce()) {
|
||||
for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
|
||||
for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) {
|
||||
if (instr == hlo) {
|
||||
continue;
|
||||
}
|
||||
@ -235,7 +235,7 @@ bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
|
||||
HloComputation* HloModuleGroupMetadata::PeerComputation(
|
||||
const HloInstruction* instruction) const {
|
||||
CHECK(IsChannelInstruction(instruction));
|
||||
const Channel& channel = GetChannel(instruction->channel_id());
|
||||
const Channel& channel = GetChannel(*instruction->channel_id());
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
@ -249,8 +249,8 @@ HloComputation* HloModuleGroupMetadata::PeerComputation(
|
||||
}
|
||||
|
||||
const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
|
||||
int64 all_reduce_id) const {
|
||||
auto it = all_reduce_map_.find(all_reduce_id);
|
||||
int64 channel_id) const {
|
||||
auto it = all_reduce_map_.find(channel_id);
|
||||
CHECK(it != all_reduce_map_.end());
|
||||
return it->second;
|
||||
}
|
||||
@ -330,14 +330,14 @@ Status HloModuleGroupMetadata::RecordInstructions() {
|
||||
TrackedInstruction(hlo, ComputationKind::kCallFunction);
|
||||
}
|
||||
|
||||
// Group cross module all-reduce instructions by the all_reduce id.
|
||||
// Group cross module all-reduce instructions by the channel id.
|
||||
if (hlo->IsCrossModuleAllReduce()) {
|
||||
TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
|
||||
TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) ==
|
||||
channel_id_map_.end())
|
||||
<< "all_reduce_id " << *hlo->all_reduce_id()
|
||||
<< "channel_id " << *hlo->channel_id()
|
||||
<< " is already used by a send/recv instruction";
|
||||
all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
|
||||
max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
|
||||
all_reduce_map_[*hlo->channel_id()].push_back(hlo);
|
||||
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -345,41 +345,41 @@ Status HloModuleGroupMetadata::RecordInstructions() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
|
||||
TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) ==
|
||||
all_reduce_map_.end())
|
||||
<< "channel id " << hlo->channel_id()
|
||||
<< "channel id " << *hlo->channel_id()
|
||||
<< " is already used by an all-reduce instruction";
|
||||
|
||||
// Add a new channel if needed.
|
||||
if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
|
||||
if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) {
|
||||
channels_.emplace_back();
|
||||
channels_.back().id = hlo->channel_id();
|
||||
channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
|
||||
max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
|
||||
channels_.back().id = *hlo->channel_id();
|
||||
channel_id_map_[*hlo->channel_id()] = channels_.size() - 1;
|
||||
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
|
||||
}
|
||||
Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
|
||||
Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]];
|
||||
|
||||
if (hlo->opcode() == HloOpcode::kSend) {
|
||||
TF_RET_CHECK(channel.send == nullptr)
|
||||
<< "channel id " << hlo->channel_id()
|
||||
<< "channel id " << *hlo->channel_id()
|
||||
<< " is used by multiple send instructions";
|
||||
channel.send = hlo;
|
||||
}
|
||||
if (hlo->opcode() == HloOpcode::kRecv) {
|
||||
TF_RET_CHECK(channel.recv == nullptr)
|
||||
<< "channel id " << hlo->channel_id()
|
||||
<< "channel id " << *hlo->channel_id()
|
||||
<< " is used by multiple recv instructions";
|
||||
channel.recv = hlo;
|
||||
}
|
||||
if (hlo->opcode() == HloOpcode::kSendDone) {
|
||||
TF_RET_CHECK(channel.send_done == nullptr)
|
||||
<< "channel id " << hlo->channel_id()
|
||||
<< "channel id " << *hlo->channel_id()
|
||||
<< " is used by multiple send-done instructions";
|
||||
channel.send_done = hlo;
|
||||
}
|
||||
if (hlo->opcode() == HloOpcode::kRecvDone) {
|
||||
TF_RET_CHECK(channel.recv_done == nullptr)
|
||||
<< "channel id " << hlo->channel_id()
|
||||
<< "channel id " << *hlo->channel_id()
|
||||
<< " is used by multiple recv-done instructions";
|
||||
channel.recv_done = hlo;
|
||||
}
|
||||
|
@ -137,9 +137,8 @@ class HloModuleGroupMetadata {
|
||||
// Returns if the given channel id exists in metadata.
|
||||
bool HasChannel(int64 channel_id) const;
|
||||
|
||||
// Returns the all-reduce instructions with the same all_reduce_id.
|
||||
const std::vector<HloInstruction*>& GetAllReduceGroup(
|
||||
int64 all_reduce_id) const;
|
||||
// Returns the all-reduce instructions with the same channel_id.
|
||||
const std::vector<HloInstruction*>& GetAllReduceGroup(int64 channel_id) const;
|
||||
|
||||
// Returns the computation that contains the peer channel instructions for
|
||||
// the given instruction.
|
||||
@ -205,7 +204,7 @@ class HloModuleGroupMetadata {
|
||||
// Returns all channels in the module group.
|
||||
const std::vector<Channel>& channels() const { return channels_; }
|
||||
|
||||
// Returns the maximum channel id or all_reduce_id used in the module group.
|
||||
// Returns the maximum channel id used in the module group.
|
||||
int64 max_channel_id() const { return max_channel_id_; }
|
||||
|
||||
HloAliasAnalysis* alias_analysis(HloModule* module) const {
|
||||
|
@ -62,7 +62,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
|
||||
}
|
||||
if (predecessor->IsCrossModuleAllReduce()) {
|
||||
for (HloInstruction* instr :
|
||||
metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
|
||||
metadata_.GetAllReduceGroup(*predecessor->channel_id())) {
|
||||
if (unique.insert(instr).second) {
|
||||
predecessors.push_back(instr);
|
||||
}
|
||||
@ -82,8 +82,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
|
||||
instruction_group.push_back(companion);
|
||||
}
|
||||
} else if (instruction->IsCrossModuleAllReduce()) {
|
||||
instruction_group =
|
||||
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
|
||||
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
|
||||
} else {
|
||||
instruction_group.push_back(instruction);
|
||||
}
|
||||
@ -99,14 +98,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
|
||||
if (instruction->opcode() == HloOpcode::kRecvDone &&
|
||||
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
|
||||
// Send is a remote predecessor of RecvDone.
|
||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||
HloInstruction* send =
|
||||
metadata_.GetChannel(*instruction->channel_id()).send;
|
||||
add_unique_predecessor(send);
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kSend &&
|
||||
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||
// Recv is a remote predecessor of Send.
|
||||
HloInstruction* recv_done =
|
||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||
metadata_.GetChannel(*instruction->channel_id()).recv_done;
|
||||
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
|
||||
CHECK_EQ(recv_done->operand_count(), 1);
|
||||
HloInstruction* recv = recv_done->mutable_operand(0);
|
||||
@ -139,7 +139,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
|
||||
}
|
||||
if (successor->IsCrossModuleAllReduce()) {
|
||||
for (HloInstruction* instr :
|
||||
metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
|
||||
metadata_.GetAllReduceGroup(*successor->channel_id())) {
|
||||
if (unique.insert(instr).second) {
|
||||
successors.push_back(instr);
|
||||
}
|
||||
@ -160,8 +160,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
|
||||
instruction_group.push_back(companion);
|
||||
}
|
||||
} else if (instruction->IsCrossModuleAllReduce()) {
|
||||
instruction_group =
|
||||
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
|
||||
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
|
||||
} else {
|
||||
instruction_group.push_back(instruction);
|
||||
}
|
||||
@ -179,14 +178,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
|
||||
// Send is a remote successor of Recv.
|
||||
const HloInstruction* recv_done = instruction->users().front();
|
||||
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
|
||||
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
|
||||
HloInstruction* send =
|
||||
metadata_.GetChannel(*instruction->channel_id()).send;
|
||||
add_unique_successor(send);
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kSend &&
|
||||
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
|
||||
// RecvDone is a remote successor of Send.
|
||||
HloInstruction* recv_done =
|
||||
metadata_.GetChannel(instruction->channel_id()).recv_done;
|
||||
metadata_.GetChannel(*instruction->channel_id()).recv_done;
|
||||
add_unique_successor(recv_done);
|
||||
}
|
||||
return successors;
|
||||
@ -256,7 +256,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
|
||||
instruction_group.push_back(companion);
|
||||
}
|
||||
} else if (hlo->IsCrossModuleAllReduce()) {
|
||||
instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
|
||||
instruction_group = metadata_.GetAllReduceGroup(*hlo->channel_id());
|
||||
} else {
|
||||
instruction_group.push_back(hlo);
|
||||
}
|
||||
|
@ -836,13 +836,12 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
optional<std::vector<std::vector<int64>>> tmp_groups;
|
||||
optional<HloComputation*> to_apply;
|
||||
optional<std::vector<int64>> replica_group_ids;
|
||||
optional<int64> all_reduce_id;
|
||||
optional<int64> channel_id;
|
||||
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
|
||||
&to_apply};
|
||||
attrs["replica_groups"] = {/*required=*/false,
|
||||
AttrTy::kBracedInt64ListList, &tmp_groups};
|
||||
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
|
||||
&all_reduce_id};
|
||||
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
@ -851,7 +850,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
replica_groups = CreateReplicaGroups(*tmp_groups);
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
|
||||
shape, operands, *to_apply, replica_groups, all_reduce_id));
|
||||
shape, operands, *to_apply, replica_groups, channel_id));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kAllToAll: {
|
||||
|
@ -1416,8 +1416,8 @@ add {
|
||||
|
||||
ENTRY CRS {
|
||||
input = f32[8]{0} parameter(0)
|
||||
crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
|
||||
ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
|
||||
crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
|
||||
ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
|
||||
}
|
||||
|
||||
)"
|
||||
|
@ -85,8 +85,8 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
|
||||
std::vector<HloInstruction*> inputs;
|
||||
const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
|
||||
inputs.push_back(input);
|
||||
if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) {
|
||||
auto it = channel_group.find(*input->all_reduce_id());
|
||||
if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) {
|
||||
auto it = channel_group.find(*input->channel_id());
|
||||
if (it != channel_group.end()) {
|
||||
inputs.insert(inputs.end(), it->second.begin(), it->second.end());
|
||||
}
|
||||
@ -106,7 +106,7 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
|
||||
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kRecvDone: {
|
||||
auto it = channel_group.find(hlo->channel_id());
|
||||
auto it = channel_group.find(*hlo->channel_id());
|
||||
if (it != channel_group.end()) {
|
||||
for (HloInstruction* channel : it->second) {
|
||||
if (channel->opcode() == HloOpcode::kSend) {
|
||||
@ -117,9 +117,9 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kAllReduce: {
|
||||
auto all_reduce_id = hlo->all_reduce_id();
|
||||
if (all_reduce_id) {
|
||||
auto it = channel_group.find(all_reduce_id.value());
|
||||
auto channel_id = hlo->channel_id();
|
||||
if (channel_id) {
|
||||
auto it = channel_group.find(channel_id.value());
|
||||
if (it != channel_group.end()) {
|
||||
for (HloInstruction* all_reduce : it->second) {
|
||||
add_dependencies(all_reduce);
|
||||
|
@ -1210,8 +1210,8 @@ Status CheckSameChannel(const HloInstruction* instr1,
|
||||
return InternalError(
|
||||
"Expected to have the same channel id, actual channel ids are: %s "
|
||||
"(%d), %s (%d)",
|
||||
instr1->ToString(), instr1->channel_id(), instr2->ToString(),
|
||||
instr2->channel_id());
|
||||
instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
|
||||
*instr2->channel_id());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1282,14 +1282,14 @@ Status VerifySendsAndRecvs(const HloModule& module) {
|
||||
DynCast<const HloSendRecvInstruction>(instruction);
|
||||
if (sendrecv->is_host_transfer()) {
|
||||
auto it_inserted =
|
||||
host_channels.insert({sendrecv->channel_id(), sendrecv});
|
||||
host_channels.insert({*sendrecv->channel_id(), sendrecv});
|
||||
if (!it_inserted.second) {
|
||||
return FailedPrecondition(
|
||||
"Channel %d is used for multiple host send/recv instructions: "
|
||||
"%s "
|
||||
"and "
|
||||
"%s",
|
||||
sendrecv->channel_id(), sendrecv->ToString(),
|
||||
*sendrecv->channel_id(), sendrecv->ToString(),
|
||||
it_inserted.first->second->ToString());
|
||||
}
|
||||
}
|
||||
@ -1574,9 +1574,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
Status HandleAllReduce(HloInstruction* crs) override {
|
||||
if (crs->all_reduce_id().has_value()) {
|
||||
TF_RET_CHECK(crs->all_reduce_id().value() > 0)
|
||||
<< "All reduce id must be greater than 0 for "
|
||||
if (crs->channel_id().has_value()) {
|
||||
TF_RET_CHECK(crs->channel_id().value() > 0)
|
||||
<< "All reduce channel id must be greater than 0 for "
|
||||
<< crs->ToShortString();
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -262,7 +262,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
abs1 = f32[4,3]{1,0} abs(add)
|
||||
log = f32[4,3]{1,0} log(abs1)
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=1
|
||||
abs2 = f32[4,3]{1,0} abs(log)
|
||||
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
|
||||
})")
|
||||
@ -294,7 +294,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
add1 = f32[4,3]{1,0} add(p0, p1)
|
||||
log = f32[4,3]{1,0} log(p0)
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=1
|
||||
add2 = f32[4,3]{1,0} add(log, add1)
|
||||
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
|
||||
})")
|
||||
@ -329,7 +329,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
add2 = f32[4,3]{1,0} add(add1, p1)
|
||||
log = f32[4,3]{1,0} log(add2)
|
||||
token0 = token[] after-all()
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=0
|
||||
send = f32[4,3]{1,0} send(log, token0), channel_id=1
|
||||
sub1 = f32[4,3]{1,0} subtract(log, add2)
|
||||
sub2 = f32[4,3]{1,0} subtract(add2, add1)
|
||||
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
|
||||
|
@ -408,7 +408,7 @@ Status LayoutAssignment::BuildHostChannelConstraints(
|
||||
TF_RET_CHECK(data_shape.IsArray());
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
|
||||
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
|
||||
send_recv_instr->channel_id(), data_shape.layout());
|
||||
*send_recv_instr->channel_id(), data_shape.layout());
|
||||
TF_RET_CHECK(prev_layout == nullptr)
|
||||
<< "Cannot constrain host transfer layout as it was set to "
|
||||
<< LayoutUtil::HumanString(*prev_layout) << ": "
|
||||
@ -480,7 +480,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
instruction->opcode() == HloOpcode::kRecv) {
|
||||
CHECK(get_channel_constraints(instruction))
|
||||
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
|
||||
int64 channel_id = instruction->channel_id();
|
||||
int64 channel_id = *instruction->channel_id();
|
||||
if (!get_channel_constraints(instruction)
|
||||
->IsChannelConstrained(channel_id)) {
|
||||
continue;
|
||||
@ -492,7 +492,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
Shape new_buffer_shape =
|
||||
get_channel_constraints(instruction)
|
||||
->LayoutShapeForChannel(send_buffer_shape,
|
||||
instruction->channel_id());
|
||||
*instruction->channel_id());
|
||||
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
|
||||
new_buffer_shape, instruction->operand(0)));
|
||||
} else {
|
||||
@ -503,18 +503,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
const LogicalBuffer* buffer,
|
||||
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
|
||||
{0}));
|
||||
Shape new_shape = get_channel_constraints(instruction)
|
||||
->LayoutShapeForChannel(
|
||||
recv_buffer_shape, instruction->channel_id());
|
||||
Shape new_shape =
|
||||
get_channel_constraints(instruction)
|
||||
->LayoutShapeForChannel(recv_buffer_shape,
|
||||
*instruction->channel_id());
|
||||
TF_RETURN_IF_ERROR(
|
||||
constraints->SetBufferLayout(new_shape.layout(), *buffer));
|
||||
}
|
||||
} else if (instruction->IsCrossModuleAllReduce()) {
|
||||
CHECK(get_channel_constraints(instruction))
|
||||
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
|
||||
int64 all_reduce_id = instruction->all_reduce_id().value();
|
||||
int64 channel_id = instruction->channel_id().value();
|
||||
if (!get_channel_constraints(instruction)
|
||||
->IsChannelConstrained(all_reduce_id)) {
|
||||
->IsChannelConstrained(channel_id)) {
|
||||
continue;
|
||||
}
|
||||
// TODO(b/68493863): Change to use SetOperandLayout().
|
||||
@ -522,7 +523,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
TF_RET_CHECK(buffer_shape.IsArray());
|
||||
Shape new_buffer_shape =
|
||||
get_channel_constraints(instruction)
|
||||
->LayoutShapeForChannel(buffer_shape, all_reduce_id);
|
||||
->LayoutShapeForChannel(buffer_shape, channel_id);
|
||||
TF_RETURN_IF_ERROR(
|
||||
constraints->SetInstructionLayout(new_buffer_shape, instruction));
|
||||
}
|
||||
@ -1833,7 +1834,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
|
||||
const Layout* layout =
|
||||
get_channel_constraints(instruction)
|
||||
->ConstrainChannel(
|
||||
instruction->channel_id(),
|
||||
*instruction->channel_id(),
|
||||
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
|
||||
TF_RET_CHECK(layout == nullptr)
|
||||
<< instruction->ToString()
|
||||
@ -1848,7 +1849,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
|
||||
if (instruction->opcode() == HloOpcode::kSend) {
|
||||
HloInstruction* operand = instruction->mutable_operand(0);
|
||||
const Layout* layout = get_channel_constraints(instruction)
|
||||
->ConstrainChannel(instruction->channel_id(),
|
||||
->ConstrainChannel(*instruction->channel_id(),
|
||||
operand->shape().layout());
|
||||
if (layout != nullptr) {
|
||||
// We found an already constrained layout which does not match the one
|
||||
@ -1873,7 +1874,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
|
||||
} else if (instruction->IsCrossModuleAllReduce()) {
|
||||
const Layout* layout =
|
||||
get_channel_constraints(instruction)
|
||||
->ConstrainChannel(instruction->all_reduce_id().value(),
|
||||
->ConstrainChannel(instruction->channel_id().value(),
|
||||
instruction->shape().layout());
|
||||
if (layout != nullptr) {
|
||||
// We found an already constrained layout which does not match the one
|
||||
|
@ -891,11 +891,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
|
||||
param = (f32[2,2]) parameter(0)
|
||||
gte = f32[2,2] get-tuple-element(param), index=0
|
||||
ar.0 = f32[2,2] all-reduce(gte),
|
||||
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
|
||||
channel_id=1, replica_groups={{0}}, to_apply=add,
|
||||
sharding={maximal device=0}
|
||||
const = f32[2,2] constant({{0,1},{2,3}})
|
||||
ROOT ar.1 = f32[2,2] all-reduce(const),
|
||||
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
|
||||
channel_id=1, replica_groups={{0}}, to_apply=add,
|
||||
sharding={maximal device=1}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
|
||||
|