Change all_reduce_id to use channel_id

PiperOrigin-RevId: 255315385
This commit is contained in:
HyoukJoong Lee 2019-06-26 19:47:50 -07:00 committed by TensorFlower Gardener
parent 6ae9600988
commit 852061b75b
26 changed files with 267 additions and 241 deletions

View File

@ -2068,7 +2068,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
}
if (channel_id.has_value()) {
instr.set_all_reduce_id(channel_id->handle());
instr.set_channel_id(channel_id->handle());
}
AddCalledComputation(computation, &instr);

View File

@ -319,7 +319,7 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
auto maybe_pair = MatchesArCrsPattern(instruction);
if (maybe_pair) {
auto pair = *maybe_pair;
int64 ar_id = *(instruction->all_reduce_id());
int64 ar_id = *(instruction->channel_id());
if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) {
continue;
}
@ -365,9 +365,10 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
for (auto it : all_reduce_map_) {
auto all_reduce_id = it.first;
VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking ar_id: "
<< all_reduce_id << "\n";
auto channel_id = it.first;
VLOG(2)
<< "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: "
<< channel_id << "\n";
auto pairs_vec = it.second;
CHECK_EQ(pairs_vec.size(), num_spatial_partitions_);
auto instr_0 = pairs_vec[0].ar;
@ -378,9 +379,10 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
absl::flat_hash_map<int64, int64> visited_pairs;
while (true) {
if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
all_reduce_map_.erase(all_reduce_id);
VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: "
<< all_reduce_id << "\n";
all_reduce_map_.erase(channel_id);
VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce "
"channel id: "
<< channel_id << "\n";
break;
}
if (next_0->IsCrossReplicaAllReduce()) {
@ -402,7 +404,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
for (auto pair : pairs_vec) {
auto all_reduce = pair.ar;
auto parent_computation = all_reduce->parent();
auto all_reduce_id = all_reduce->all_reduce_id();
auto channel_id = all_reduce->channel_id();
auto prev = all_reduce->mutable_operand(0);
auto next = all_reduce->users()[0];
TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
@ -447,7 +449,7 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
next = next->users()[0];
}
// The AllReduce and the CRS are combined to an all-core AllReduce.
next->set_all_reduce_id(all_reduce_id);
next->set_channel_id(channel_id);
}
}
return true;

View File

@ -36,7 +36,7 @@ namespace xla {
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by all_reduce_id. They must all be rewritten.
// 2) Group CMARs by channel_id. They must all be rewritten.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
// other operand by the number of spatial partitions.
@ -104,7 +104,7 @@ class ArCrsCombiner : public HloModulePass {
}
pieces.push_back(instruction->name());
pieces.push_back(")[id:");
pieces.push_back(std::to_string(*(ar->all_reduce_id())));
pieces.push_back(std::to_string(*(ar->channel_id())));
pieces.push_back(",dist:");
pieces.push_back(std::to_string(distance));
pieces.push_back("]");

View File

@ -414,7 +414,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = bf16[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[]
@ -429,7 +429,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[]
@ -486,7 +486,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
%all-reduce.ar.1 = f32[2,1]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.1,
sharding={maximal device=0}
%bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
@ -499,7 +499,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
%all-reduce.ar.2 = f32[2,1]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.1,
sharding={maximal device=1}
%bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
@ -549,7 +549,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.f32,
sharding={maximal device=0}
%multiply.1 = f32[]
@ -564,7 +564,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.f32,
sharding={maximal device=1}
%multiply.2 = f32[]
@ -624,7 +624,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[]
@ -642,7 +642,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[]
@ -709,7 +709,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[]
@ -727,7 +727,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[]
@ -772,7 +772,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.1,
sharding={maximal device=0}
%all-reduce.1 = f32[]
@ -787,7 +787,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.1,
sharding={maximal device=1}
%all-reduce.2 = f32[]
@ -840,7 +840,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=0}
%add.11 = f32[]
@ -858,7 +858,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=0}
%add.21 = f32[]
@ -919,7 +919,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.f32,
sharding={maximal device=0}
%sub.1 = f32[]
@ -934,7 +934,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.f32,
sharding={maximal device=1}
%sub.2 = f32[]
@ -991,7 +991,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar11 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=0}
%add11 = f32[]
@ -1000,7 +1000,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar12 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=2,
channel_id=2,
to_apply=%sum,
sharding={maximal device=0}
%add12 = f32[]
@ -1015,7 +1015,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar21 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=1}
%add21 = f32[]
@ -1024,7 +1024,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar22 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=2,
channel_id=2,
to_apply=%sum,
sharding={maximal device=1}
%add22 = f32[]
@ -1083,13 +1083,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar11 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=0}
%ar12 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=2,
channel_id=2,
to_apply=%sum,
sharding={maximal device=0}
%add11 = f32[]
@ -1107,13 +1107,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%ar21 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum,
sharding={maximal device=1}
%ar22 = f32[]
all-reduce(%p),
replica_groups={{0},{1}},
all_reduce_id=2,
channel_id=2,
to_apply=%sum,
sharding={maximal device=1}
%add21 = f32[]
@ -1182,7 +1182,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%all-reduce.ar.1 = bf16[]
all-reduce(%p),
replica_groups={{0}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[]
@ -1197,7 +1197,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%all-reduce.ar.2 = bf16[]
all-reduce(%constant.bf16),
replica_groups={{0}},
all_reduce_id=1,
channel_id=1,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[]
@ -1276,9 +1276,9 @@ HloModule foobar
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%p = bf16[] parameter(0)
%all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
%all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
to_apply=%sum.f32, sharding={maximal device=0}
%all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}},
%all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}},
to_apply=%sum.f32, sharding={maximal device=1}
%all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}},
to_apply=%sum.f32, sharding={maximal device=0}

View File

@ -239,7 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
/*replica_groups=*/{},
/*all_reduce_id=*/absl::nullopt));
/*channel_id=*/absl::nullopt));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(

View File

@ -257,7 +257,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
/*replica_groups=*/{},
/*all_reduce_id=*/absl::nullopt));
/*channel_id=*/absl::nullopt));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

View File

@ -211,7 +211,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
HloInstruction* all_reduce =
builder.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
/*replica_groups=*/{}, /*all_reduce_id=*/1));
/*replica_groups=*/{}, /*channel_id=*/1));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
HloInstruction* gte1 = builder.AddInstruction(

View File

@ -177,13 +177,13 @@ class NcclComm {
// * Only ops with the same opcode can communicate with each other. At the
// moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()),
// only ops with the same value for all_reduce_id() can communicate with each
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
// only ops with the same value for channel_id() can communicate with each
// other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
// !all_reduce_id().has_value()), only ops from the same module (as identified
// by its unique_id()) can communicate with each other.
// !channel_id().has_value()), only ops from the same module (as
// identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
enum AllReduceKind {
@ -196,8 +196,8 @@ struct RendezvousKey {
const HloAllReduceInstruction* instr)
: run_id(run_id), participating_replicas(participating_replicas) {
std::tie(all_reduce_kind, op_id) =
instr->all_reduce_id().has_value()
? std::make_pair(kCrossModule, instr->all_reduce_id().value())
instr->channel_id().has_value()
? std::make_pair(kCrossModule, instr->channel_id().value())
: std::make_pair(
kCrossReplica,
static_cast<int64>(instr->GetModule()->unique_id()));

View File

@ -126,8 +126,9 @@ message HloInstructionProto {
// Only present for kBatchNormTraining.
int64 feature_index = 25;
// Represents a unique identifier for each Send/Recv instruction pair.
// Only present for kSend or kRecv.
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id is equivalent to no channel id.
int64 channel_id = 26;
// The string representation of the infeed configuration.
@ -174,7 +175,9 @@ message HloInstructionProto {
// Cross replica op fields.
repeated ReplicaGroup replica_groups = 49;
int64 all_reduce_id = 45;
// Deprecated, but keeping it for backward compatibility. Use channel_id.
// Non-positive all_reduce_id is equivalent to no all_reduce_id.
int64 all_reduce_id = 45 [deprecated = true];
// Whether this Send/Recv instruction transfers data to/from the host. Only
// present for Send and Recv instructions and their SendDone and RecvDone

View File

@ -373,7 +373,7 @@ void HloComputation::ComputeInstructionPostOrder(
case HloOpcode::kRecvDone:
return inst->channel_id();
case HloOpcode::kAllReduce:
return inst->all_reduce_id();
return inst->channel_id();
default:
return absl::nullopt;
}
@ -428,13 +428,10 @@ HloComputation::ComputeChannelDependencies() const {
switch (instruction->opcode()) {
case HloOpcode::kSend:
case HloOpcode::kRecvDone:
channel_dependency_group[instruction->channel_id()].push_back(
instruction.get());
break;
case HloOpcode::kAllReduce: {
auto all_reduce_id = instruction->all_reduce_id();
if (all_reduce_id) {
channel_dependency_group[all_reduce_id.value()].push_back(
auto channel_id = instruction->channel_id();
if (channel_id) {
channel_dependency_group[channel_id.value()].push_back(
instruction.get());
}
break;

View File

@ -688,10 +688,10 @@ add {
ENTRY entry {
param = f32[128] parameter(0), sharding={maximal device=0}
crs0 = f32[128] all-reduce(param),
replica_groups={{0}}, all_reduce_id=1, to_apply=add,
replica_groups={{0}}, channel_id=1, to_apply=add,
sharding={maximal device=0}
crs1 = f32[128] all-reduce(param),
replica_groups={{0}}, all_reduce_id=1, to_apply=add,
replica_groups={{0}}, channel_id=1, to_apply=add,
sharding={maximal device=1}
add = f32[128] add(crs0, crs0), sharding={maximal device=0}
ROOT t = (f32[128], f32[128]) tuple(add, crs1)

View File

@ -383,16 +383,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "AllReduce should have 1 called computation but sees "
<< proto.called_computation_ids_size();
absl::optional<int64> all_reduce_id;
TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
<< "AllReduce cannot have both channel_id() and all_reduce_id()";
absl::optional<int64> channel_id;
if (proto.channel_id() > 0) {
channel_id = proto.channel_id();
}
if (proto.all_reduce_id() > 0) {
all_reduce_id = proto.all_reduce_id();
channel_id = proto.all_reduce_id();
}
instruction = CreateAllReduce(
shape, all_operands(), computations(0),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
/*all_reduce_id=*/all_reduce_id);
/*channel_id=*/channel_id);
break;
}
case HloOpcode::kAllToAll: {
@ -860,9 +865,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& all_reduce_id) {
const absl::optional<int64>& channel_id) {
return absl::make_unique<HloAllReduceInstruction>(
shape, operands, reduce_computation, replica_groups, all_reduce_id);
shape, operands, reduce_computation, replica_groups, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
@ -1279,7 +1284,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kTrace:
return true;
case HloOpcode::kAllReduce:
return all_reduce_id().has_value();
return channel_id().has_value();
case HloOpcode::kCustomCall:
return Cast<HloCustomCallInstruction>(this)
->custom_call_has_side_effect();
@ -2232,11 +2237,11 @@ bool HloInstruction::IsElementwiseImpl(
}
bool HloInstruction::IsCrossModuleAllReduce() const {
return opcode() == HloOpcode::kAllReduce && all_reduce_id();
return opcode() == HloOpcode::kAllReduce && channel_id();
}
bool HloInstruction::IsCrossReplicaAllReduce() const {
return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
return opcode() == HloOpcode::kAllReduce && !channel_id();
}
string HloInstruction::ToStringWithCanonicalNameMap(
@ -3332,10 +3337,6 @@ const std::vector<int64>& HloInstruction::fft_length() const {
return Cast<HloFftInstruction>(this)->fft_length();
}
int64 HloInstruction::channel_id() const {
return Cast<HloSendRecvInstruction>(this)->channel_id();
}
int64 HloInstruction::concatenate_dimension() const {
return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
}
@ -3535,13 +3536,12 @@ HloInstruction::source_target_pairs() const {
return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}
absl::optional<int64> HloInstruction::all_reduce_id() const {
return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
absl::optional<int64> HloInstruction::channel_id() const {
return Cast<HloChannelInstruction>(this)->channel_id();
}
void HloInstruction::set_all_reduce_id(
const absl::optional<int64>& all_reduce_id) {
return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}
const ConvolutionDimensionNumbers&

View File

@ -497,14 +497,14 @@ class HloInstruction {
// For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// `all_reduce_id`: for Allreduce nodes from different modules, if they have
// the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
// not be applied cross modules.
// `channel_id`: AllReduce nodes from different modules that share the same
// channel_id are reduced together. If empty, the AllReduce is not applied
// across modules.
static std::unique_ptr<HloInstruction> CreateAllReduce(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& all_reduce_id);
const absl::optional<int64>& channel_id);
// An all-to-all op takes N array operands of the same shape and scatters them
// to N replicas. Each replica gathers the results into a tuple.
@ -952,7 +952,7 @@ class HloInstruction {
return false;
}
// Two AllReduces are Identical if they have the same all_reduce_id.
// Two AllReduces are Identical if they have the same channel_id.
// Their operands don't have to be Identical.
if (!IsCrossModuleAllReduce()) {
// Use an explicit loop rather than ContainerEquals, because copying
@ -1428,8 +1428,9 @@ class HloInstruction {
// Delegates to HloFftInstruction::fft_length.
const std::vector<int64>& fft_length() const;
// Delegates to HloSendRecvInstruction::channel_id.
int64 channel_id() const;
// Delegates to HloChannelInstruction::channel_id.
absl::optional<int64> channel_id() const;
void set_channel_id(const absl::optional<int64>& channel_id);
// Returns the dimension sizes or numbers associated with this instruction.
virtual const std::vector<int64>& dimensions() const {
@ -1571,10 +1572,6 @@ class HloInstruction {
// Delegates to HloCollectivePermuteInstruction::source_target_pairs.
const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
// Delegates to HloAllReduceInstruction::all_reduce_id.
absl::optional<int64> all_reduce_id() const;
void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
// Returns data on the window in a windowed operation such as
// convolution.
virtual const Window& window() const {

View File

@ -361,25 +361,60 @@ HloCholeskyInstruction::CloneWithNewOperandsImpl(
cholesky_options());
}
// Builds a channel instruction: forwards `opcode` and `shape` to the
// HloInstruction base and stores the (possibly empty) `channel_id`.
HloChannelInstruction::HloChannelInstruction(
    HloOpcode opcode, const Shape& shape,
    const absl::optional<int64>& channel_id)
    : HloInstruction(opcode, shape),
      channel_id_(channel_id) {
}
// Overwrites the stored channel id with `channel_id`; an empty optional leaves
// the instruction without a channel id.
void HloChannelInstruction::set_channel_id(
    const absl::optional<int64>& channel_id) {
  this->channel_id_ = channel_id;
}
// Serializes this instruction. Starts from the HloInstruction base proto and,
// when a channel id is set, writes it into the proto's channel_id field.
// A set channel id must be strictly positive; anything else trips the CHECK.
HloInstructionProto HloChannelInstruction::ToProto() const {
  HloInstructionProto result = HloInstruction::ToProto();
  // Nothing channel-specific to record when no id is present.
  if (!channel_id_.has_value()) {
    return result;
  }
  CHECK_GT(channel_id_.value(), 0)
      << "Non-positive channel id is equivalent to no channel id";
  result.set_channel_id(channel_id_.value());
  return result;
}
std::vector<string> HloChannelInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& /*options*/) const {
std::vector<string> result;
if (channel_id_) {
result.push_back(StrCat("channel_id=", *channel_id_));
}
return result;
}
bool HloChannelInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
/*eq_computations*/) const {
const auto& casted_other = static_cast<const HloChannelInstruction&>(other);
return channel_id() == casted_other.channel_id();
}
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
int64 channel_id,
bool is_host_transfer)
: HloInstruction(opcode, shape),
channel_id_(channel_id),
: HloChannelInstruction(opcode, shape, channel_id),
is_host_transfer_(is_host_transfer) {}
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_channel_id(channel_id_);
HloInstructionProto proto = HloChannelInstruction::ToProto();
proto.set_is_host_transfer(is_host_transfer_);
return proto;
}
std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> attrs;
attrs.push_back(StrCat("channel_id=", channel_id_));
std::vector<string> attrs =
HloChannelInstruction::ExtraAttributesToStringImpl(options);
if (is_host_transfer()) {
attrs.push_back("is_host_transfer=true");
}
@ -413,13 +448,13 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloSendInstruction>(
new_operands[0], new_operands[1], channel_id(), is_host_transfer());
new_operands[0], new_operands[1], *channel_id(), is_host_transfer());
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand,
bool is_host_transfer)
: HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
CHECK_NOTNULL(operand)->channel_id(),
CHECK_NOTNULL(operand)->channel_id().value(),
is_host_transfer) {
AppendOperand(operand);
}
@ -450,7 +485,7 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloRecvInstruction>(
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(),
ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(),
is_host_transfer());
}
@ -461,7 +496,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand,
ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(operand->shape(), 0),
ShapeUtil::MakeTokenShape()}),
CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) {
CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) {
AppendOperand(operand);
}
@ -477,32 +512,39 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl(
HloCollectiveInstruction::HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloInstruction(opcode, shape), replica_groups_(replica_groups) {
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& channel_id)
: HloChannelInstruction(opcode, shape, channel_id),
replica_groups_({replica_groups.begin(), replica_groups.end()}) {
for (auto operand : operands) {
AppendOperand(operand);
}
}
HloInstructionProto HloCollectiveInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
HloInstructionProto proto = HloChannelInstruction::ToProto();
*proto.mutable_replica_groups() = {replica_groups_.begin(),
replica_groups_.end()};
return proto;
}
std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& /*options*/) const {
return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))};
const HloPrintOptions& options) const {
std::vector<string> result =
HloChannelInstruction::ExtraAttributesToStringImpl(options);
result.push_back(
StrCat("replica_groups=", ReplicaGroupsToString(replica_groups())));
return result;
}
bool HloCollectiveInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
/*eq_computations*/) const {
eq_computations) const {
const auto& casted_other =
static_cast<const HloCollectiveInstruction&>(other);
return absl::c_equal(replica_groups(), casted_other.replica_groups(),
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
absl::c_equal(replica_groups(), casted_other.replica_groups(),
[](const ReplicaGroup& a, const ReplicaGroup& b) {
return absl::c_equal(a.replica_ids(), b.replica_ids());
});
@ -512,44 +554,19 @@ HloAllReduceInstruction::HloAllReduceInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& all_reduce_id)
const absl::optional<int64>& channel_id)
: HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands,
replica_groups),
all_reduce_id_(all_reduce_id) {
replica_groups, channel_id) {
AppendComputation(reduce_computation);
}
void HloAllReduceInstruction::set_all_reduce_id(
const absl::optional<int64>& all_reduce_id) {
all_reduce_id_ = all_reduce_id;
}
HloInstructionProto HloAllReduceInstruction::ToProto() const {
HloInstructionProto proto = HloCollectiveInstruction::ToProto();
// Proto3 is so sad.
if (all_reduce_id_) {
proto.set_all_reduce_id(*all_reduce_id_);
}
return proto;
}
bool HloAllReduceInstruction::IsNoop() const {
for (auto replica_group : replica_groups()) {
if (replica_group.replica_ids().size() != 1) {
return false;
}
}
return !all_reduce_id();
}
std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> result =
HloCollectiveInstruction::ExtraAttributesToStringImpl(options);
if (all_reduce_id_) {
result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
}
return result;
return !channel_id();
}
bool HloAllReduceInstruction::IdenticalSlowPath(
@ -558,8 +575,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) &&
eq_computations(to_apply(), casted_other.to_apply()) &&
all_reduce_id() == casted_other.all_reduce_id();
eq_computations(to_apply(), casted_other.to_apply());
}
std::unique_ptr<HloInstruction>
@ -567,14 +583,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllReduceInstruction>(
shape, new_operands, to_apply(), replica_groups(), all_reduce_id());
shape, new_operands, to_apply(), replica_groups(), channel_id());
}
HloAllToAllInstruction::HloAllToAllInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups)
: HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands,
replica_groups) {}
replica_groups, absl::nullopt) {}
std::unique_ptr<HloInstruction>
HloAllToAllInstruction::CloneWithNewOperandsImpl(
@ -587,13 +603,14 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl(
HloCollectivePermuteInstruction::HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand,
const std::vector<std::pair<int64, int64>>& source_target_pairs)
: HloInstruction(HloOpcode::kCollectivePermute, shape),
: HloChannelInstruction(HloOpcode::kCollectivePermute, shape,
absl::nullopt),
source_target_pairs_(source_target_pairs) {
AppendOperand(operand);
}
HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
HloInstructionProto proto = HloChannelInstruction::ToProto();
for (const auto& pair : source_target_pairs()) {
auto* proto_pair = proto.add_source_target_pairs();
proto_pair->set_source(pair.first);
@ -604,8 +621,9 @@ HloInstructionProto HloCollectivePermuteInstruction::ToProto() const {
std::vector<string>
HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& /*options*/) const {
std::vector<string> result;
const HloPrintOptions& options) const {
std::vector<string> result =
HloChannelInstruction::ExtraAttributesToStringImpl(options);
std::vector<string> strs;
for (const auto& pair : source_target_pairs()) {
strs.push_back(StrCat("{", pair.first, ",", pair.second, "}"));
@ -617,10 +635,11 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl(
bool HloCollectivePermuteInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
/*eq_computations*/) const {
eq_computations) const {
const auto& casted_other =
static_cast<const HloCollectivePermuteInstruction&>(other);
return absl::c_equal(source_target_pairs(),
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
absl::c_equal(source_target_pairs(),
casted_other.source_target_pairs(),
[](const std::pair<int64, int64>& a,
const std::pair<int64, int64>& b) { return a == b; });

View File

@ -206,13 +206,37 @@ class HloCholeskyInstruction : public HloInstruction {
CholeskyOptions cholesky_options_;
};
class HloSendRecvInstruction : public HloInstruction {
// Class that represents instructions that synchronize and transfer data between
// partitioned devices. Send/Recv and collective instructions (AllReduce,
// AllToAll, CollectivePermute) belong to this instruction type. A group of
// instructions (of the same opcode) with the same channel_id communicate during
// execution.
class HloChannelInstruction : public HloInstruction {
public:
// Returns the channel id associated with the instruction. The id is
// shared between each Send/Recv pair and is globally unique to identify each
// channel.
int64 channel_id() const { return channel_id_; }
// shared between each Send/Recv pair or a group of collective instructions
// and is globally unique to identify each channel.
absl::optional<int64> channel_id() const { return channel_id_; }
void set_channel_id(const absl::optional<int64>& channel_id);
protected:
explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
const absl::optional<int64>& channel_id);
HloInstructionProto ToProto() const override;
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
absl::optional<int64> channel_id_;
};
class HloSendRecvInstruction : public HloChannelInstruction {
public:
// Returns whether this send/recv instruction sends data to/from the host.
bool is_host_transfer() const { return is_host_transfer_; }
@ -230,9 +254,6 @@ class HloSendRecvInstruction : public HloInstruction {
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
// Represents a unique identifier for each Send/Recv instruction pair.
int64 channel_id_;
// Whether this send/recv instruction sends data to/from the host.
bool is_host_transfer_;
};
@ -285,7 +306,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction {
HloCloneContext* context) const override;
};
class HloCollectiveInstruction : public HloInstruction {
class HloCollectiveInstruction : public HloChannelInstruction {
public:
const std::vector<ReplicaGroup>& replica_groups() const {
return replica_groups_;
@ -295,7 +316,8 @@ class HloCollectiveInstruction : public HloInstruction {
explicit HloCollectiveInstruction(
HloOpcode opcode, const Shape& shape,
absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups);
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& channel_id);
HloInstructionProto ToProto() const override;
@ -315,21 +337,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* reduce_computation,
const std::vector<ReplicaGroup>& replica_groups,
const absl::optional<int64>& all_reduce_id);
absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
const absl::optional<int64>& channel_id);
// Returns true if the AllReduce does no communication, so it's equivalent
// to a mem copy.
bool IsNoop() const;
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
@ -339,11 +353,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// For Allreduce nodes from different modules, if they have the same
// all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross modules.
absl::optional<int64> all_reduce_id_;
};
class HloAllToAllInstruction : public HloCollectiveInstruction {
@ -359,7 +368,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction {
HloCloneContext* context) const override;
};
class HloCollectivePermuteInstruction : public HloInstruction {
class HloCollectivePermuteInstruction : public HloChannelInstruction {
public:
explicit HloCollectivePermuteInstruction(
const Shape& shape, HloInstruction* operand,

View File

@ -155,11 +155,11 @@ ENTRY entry {
%p = f32[1000, 1000] parameter(0)
%token.0 = token[] after-all()
%send = (f32[1000, 1000], token[]) send(%p, %token.0),
channel_id=0, is_host_transfer=true
channel_id=1, is_host_transfer=true
%n1 = f32[1000, 1000] negate(%p)
%n2 = f32[1000, 1000] negate(%n1)
%n3 = f32[1000, 1000] negate(%n2)
%send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true
%send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
}
)";

View File

@ -83,7 +83,7 @@ Status HloModuleGroupMetadata::Build() {
if (IsChannelInstruction(hlo)) {
peers.push_back(PeerComputation(hlo));
} else if (hlo->IsCrossModuleAllReduce()) {
for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) {
for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) {
if (instr == hlo) {
continue;
}
@ -235,7 +235,7 @@ bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
const Channel& channel = GetChannel(instruction->channel_id());
const Channel& channel = GetChannel(*instruction->channel_id());
switch (instruction->opcode()) {
case HloOpcode::kSend:
case HloOpcode::kSendDone:
@ -249,8 +249,8 @@ HloComputation* HloModuleGroupMetadata::PeerComputation(
}
const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
int64 all_reduce_id) const {
auto it = all_reduce_map_.find(all_reduce_id);
int64 channel_id) const {
auto it = all_reduce_map_.find(channel_id);
CHECK(it != all_reduce_map_.end());
return it->second;
}
@ -330,14 +330,14 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
// Group cross module all-reduce instructions by the all_reduce id.
// Group cross module all-reduce instructions by the channel id.
if (hlo->IsCrossModuleAllReduce()) {
TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) ==
TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) ==
channel_id_map_.end())
<< "all_reduce_id " << *hlo->all_reduce_id()
<< "channel_id " << *hlo->channel_id()
<< " is already used by a send/recv instruction";
all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo);
max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id());
all_reduce_map_[*hlo->channel_id()].push_back(hlo);
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
return Status::OK();
}
@ -345,41 +345,41 @@ Status HloModuleGroupMetadata::RecordInstructions() {
return Status::OK();
}
TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) ==
TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) ==
all_reduce_map_.end())
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is already used by an all-reduce instruction";
// Add a new channel if needed.
if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) {
if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) {
channels_.emplace_back();
channels_.back().id = hlo->channel_id();
channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
channels_.back().id = *hlo->channel_id();
channel_id_map_[*hlo->channel_id()] = channels_.size() - 1;
max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
}
Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]];
if (hlo->opcode() == HloOpcode::kSend) {
TF_RET_CHECK(channel.send == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple send instructions";
channel.send = hlo;
}
if (hlo->opcode() == HloOpcode::kRecv) {
TF_RET_CHECK(channel.recv == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple recv instructions";
channel.recv = hlo;
}
if (hlo->opcode() == HloOpcode::kSendDone) {
TF_RET_CHECK(channel.send_done == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple send-done instructions";
channel.send_done = hlo;
}
if (hlo->opcode() == HloOpcode::kRecvDone) {
TF_RET_CHECK(channel.recv_done == nullptr)
<< "channel id " << hlo->channel_id()
<< "channel id " << *hlo->channel_id()
<< " is used by multiple recv-done instructions";
channel.recv_done = hlo;
}

View File

@ -137,9 +137,8 @@ class HloModuleGroupMetadata {
// Returns if the given channel id exists in metadata.
bool HasChannel(int64 channel_id) const;
// Returns the all-reduce instructions with the same all_reduce_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
int64 all_reduce_id) const;
// Returns the all-reduce instructions with the same channel_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(int64 channel_id) const;
// Returns the computation that contains the peer channel instructions for
// the given instruction.
@ -205,7 +204,7 @@ class HloModuleGroupMetadata {
// Returns all channels in the module group.
const std::vector<Channel>& channels() const { return channels_; }
// Returns the maximum channel id or all_reduce_id used in the module group.
// Returns the maximum channel id used in the module group.
int64 max_channel_id() const { return max_channel_id_; }
HloAliasAnalysis* alias_analysis(HloModule* module) const {

View File

@ -62,7 +62,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
}
if (predecessor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
metadata_.GetAllReduceGroup(*predecessor->channel_id())) {
if (unique.insert(instr).second) {
predecessors.push_back(instr);
}
@ -82,8 +82,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
instruction_group.push_back(companion);
}
} else if (instruction->IsCrossModuleAllReduce()) {
instruction_group =
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
} else {
instruction_group.push_back(instruction);
}
@ -99,14 +98,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
if (instruction->opcode() == HloOpcode::kRecvDone &&
!DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_predecessor(send);
}
if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
metadata_.GetChannel(*instruction->channel_id()).recv_done;
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
CHECK_EQ(recv_done->operand_count(), 1);
HloInstruction* recv = recv_done->mutable_operand(0);
@ -139,7 +139,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
}
if (successor->IsCrossModuleAllReduce()) {
for (HloInstruction* instr :
metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
metadata_.GetAllReduceGroup(*successor->channel_id())) {
if (unique.insert(instr).second) {
successors.push_back(instr);
}
@ -160,8 +160,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
instruction_group.push_back(companion);
}
} else if (instruction->IsCrossModuleAllReduce()) {
instruction_group =
metadata_.GetAllReduceGroup(*instruction->all_reduce_id());
instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id());
} else {
instruction_group.push_back(instruction);
}
@ -179,14 +178,15 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
HloInstruction* send =
metadata_.GetChannel(*instruction->channel_id()).send;
add_unique_successor(send);
}
if (instruction->opcode() == HloOpcode::kSend &&
!DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
metadata_.GetChannel(*instruction->channel_id()).recv_done;
add_unique_successor(recv_done);
}
return successors;
@ -256,7 +256,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder(
instruction_group.push_back(companion);
}
} else if (hlo->IsCrossModuleAllReduce()) {
instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id());
instruction_group = metadata_.GetAllReduceGroup(*hlo->channel_id());
} else {
instruction_group.push_back(hlo);
}

View File

@ -836,13 +836,12 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<int64> all_reduce_id;
optional<int64> channel_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
&all_reduce_id};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -851,7 +850,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, all_reduce_id));
shape, operands, *to_apply, replica_groups, channel_id));
break;
}
case HloOpcode::kAllToAll: {

View File

@ -1416,8 +1416,8 @@ add {
ENTRY CRS {
input = f32[8]{0} parameter(0)
crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
}
)"

View File

@ -85,8 +85,8 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
std::vector<HloInstruction*> inputs;
const auto add_input = [&channel_group, &inputs](HloInstruction* input) {
inputs.push_back(input);
if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) {
auto it = channel_group.find(*input->all_reduce_id());
if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) {
auto it = channel_group.find(*input->channel_id());
if (it != channel_group.end()) {
inputs.insert(inputs.end(), it->second.begin(), it->second.end());
}
@ -106,7 +106,7 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
switch (hlo->opcode()) {
case HloOpcode::kRecvDone: {
auto it = channel_group.find(hlo->channel_id());
auto it = channel_group.find(*hlo->channel_id());
if (it != channel_group.end()) {
for (HloInstruction* channel : it->second) {
if (channel->opcode() == HloOpcode::kSend) {
@ -117,9 +117,9 @@ std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
break;
}
case HloOpcode::kAllReduce: {
auto all_reduce_id = hlo->all_reduce_id();
if (all_reduce_id) {
auto it = channel_group.find(all_reduce_id.value());
auto channel_id = hlo->channel_id();
if (channel_id) {
auto it = channel_group.find(channel_id.value());
if (it != channel_group.end()) {
for (HloInstruction* all_reduce : it->second) {
add_dependencies(all_reduce);

View File

@ -1210,8 +1210,8 @@ Status CheckSameChannel(const HloInstruction* instr1,
return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
"(%d), %s (%d)",
instr1->ToString(), instr1->channel_id(), instr2->ToString(),
instr2->channel_id());
instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
*instr2->channel_id());
}
return Status::OK();
}
@ -1282,14 +1282,14 @@ Status VerifySendsAndRecvs(const HloModule& module) {
DynCast<const HloSendRecvInstruction>(instruction);
if (sendrecv->is_host_transfer()) {
auto it_inserted =
host_channels.insert({sendrecv->channel_id(), sendrecv});
host_channels.insert({*sendrecv->channel_id(), sendrecv});
if (!it_inserted.second) {
return FailedPrecondition(
"Channel %d is used for multiple host send/recv instructions: "
"%s "
"and "
"%s",
sendrecv->channel_id(), sendrecv->ToString(),
*sendrecv->channel_id(), sendrecv->ToString(),
it_inserted.first->second->ToString());
}
}
@ -1574,9 +1574,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
}
Status HandleAllReduce(HloInstruction* crs) override {
if (crs->all_reduce_id().has_value()) {
TF_RET_CHECK(crs->all_reduce_id().value() > 0)
<< "All reduce id must be greater than 0 for "
if (crs->channel_id().has_value()) {
TF_RET_CHECK(crs->channel_id().value() > 0)
<< "All reduce channel id must be greater than 0 for "
<< crs->ToShortString();
}
return Status::OK();

View File

@ -262,7 +262,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})")
@ -294,7 +294,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add1 = f32[4,3]{1,0} add(p0, p1)
log = f32[4,3]{1,0} log(p0)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})")
@ -329,7 +329,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
add2 = f32[4,3]{1,0} add(add1, p1)
log = f32[4,3]{1,0} log(add2)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=0
send = f32[4,3]{1,0} send(log, token0), channel_id=1
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)

View File

@ -408,7 +408,7 @@ Status LayoutAssignment::BuildHostChannelConstraints(
TF_RET_CHECK(data_shape.IsArray());
TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
send_recv_instr->channel_id(), data_shape.layout());
*send_recv_instr->channel_id(), data_shape.layout());
TF_RET_CHECK(prev_layout == nullptr)
<< "Cannot constrain host transfer layout as it was set to "
<< LayoutUtil::HumanString(*prev_layout) << ": "
@ -480,7 +480,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
int64 channel_id = *instruction->channel_id();
if (!get_channel_constraints(instruction)
->IsChannelConstrained(channel_id)) {
continue;
@ -492,7 +492,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
Shape new_buffer_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(send_buffer_shape,
instruction->channel_id());
*instruction->channel_id());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
new_buffer_shape, instruction->operand(0)));
} else {
@ -503,18 +503,19 @@ Status LayoutAssignment::AddMandatoryConstraints(
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
{0}));
Shape new_shape = get_channel_constraints(instruction)
->LayoutShapeForChannel(
recv_buffer_shape, instruction->channel_id());
Shape new_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(recv_buffer_shape,
*instruction->channel_id());
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
} else if (instruction->IsCrossModuleAllReduce()) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 all_reduce_id = instruction->all_reduce_id().value();
int64 channel_id = instruction->channel_id().value();
if (!get_channel_constraints(instruction)
->IsChannelConstrained(all_reduce_id)) {
->IsChannelConstrained(channel_id)) {
continue;
}
// TODO(b/68493863): Change to use SetOperandLayout().
@ -522,7 +523,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RET_CHECK(buffer_shape.IsArray());
Shape new_buffer_shape =
get_channel_constraints(instruction)
->LayoutShapeForChannel(buffer_shape, all_reduce_id);
->LayoutShapeForChannel(buffer_shape, channel_id);
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
@ -1833,7 +1834,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
const Layout* layout =
get_channel_constraints(instruction)
->ConstrainChannel(
instruction->channel_id(),
*instruction->channel_id(),
ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr)
<< instruction->ToString()
@ -1848,7 +1849,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0);
const Layout* layout = get_channel_constraints(instruction)
->ConstrainChannel(instruction->channel_id(),
->ConstrainChannel(*instruction->channel_id(),
operand->shape().layout());
if (layout != nullptr) {
// We found an already constrained layout which does not match the one
@ -1873,7 +1874,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
} else if (instruction->IsCrossModuleAllReduce()) {
const Layout* layout =
get_channel_constraints(instruction)
->ConstrainChannel(instruction->all_reduce_id().value(),
->ConstrainChannel(instruction->channel_id().value(),
instruction->shape().layout());
if (layout != nullptr) {
// We found an already constrained layout which does not match the one

View File

@ -891,11 +891,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
ar.0 = f32[2,2] all-reduce(gte),
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=0}
const = f32[2,2] constant({{0,1},{2,3}})
ROOT ar.1 = f32[2,2] all-reduce(const),
all_reduce_id=1, replica_groups={{0}}, to_apply=add,
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=1}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,