From 852061b75b2b75fea85e5bc271a1f87af83c54ed Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Wed, 26 Jun 2019 19:47:50 -0700 Subject: [PATCH] Change all_reduce_id to use channel_id PiperOrigin-RevId: 255315385 --- tensorflow/compiler/xla/client/xla_builder.cc | 2 +- .../compiler/xla/service/ar_crs_combiner.cc | 20 +-- .../compiler/xla/service/ar_crs_combiner.h | 4 +- .../xla/service/ar_crs_combiner_test.cc | 56 ++++---- .../bfloat16_conversion_folding_test.cc | 2 +- .../service/bfloat16_normalization_test.cc | 2 +- .../xla/service/bfloat16_propagation_test.cc | 2 +- .../xla/service/gpu/nccl_all_reduce_thunk.cc | 12 +- tensorflow/compiler/xla/service/hlo.proto | 9 +- .../compiler/xla/service/hlo_computation.cc | 11 +- .../xla/service/hlo_computation_test.cc | 4 +- .../compiler/xla/service/hlo_instruction.cc | 34 ++--- .../compiler/xla/service/hlo_instruction.h | 19 ++- .../compiler/xla/service/hlo_instructions.cc | 129 ++++++++++-------- .../compiler/xla/service/hlo_instructions.h | 57 ++++---- .../xla/service/hlo_memory_scheduler_test.cc | 4 +- .../xla/service/hlo_module_group_metadata.cc | 40 +++--- .../xla/service/hlo_module_group_metadata.h | 7 +- .../xla/service/hlo_module_group_util.cc | 22 +-- tensorflow/compiler/xla/service/hlo_parser.cc | 7 +- .../compiler/xla/service/hlo_parser_test.cc | 4 +- .../compiler/xla/service/hlo_reachability.cc | 12 +- .../compiler/xla/service/hlo_verifier.cc | 14 +- .../xla/service/instruction_fusion_test.cc | 6 +- .../compiler/xla/service/layout_assignment.cc | 25 ++-- .../xla/service/layout_assignment_test.cc | 4 +- 26 files changed, 267 insertions(+), 241 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a488b8ca5b1..738c488349d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2068,7 +2068,7 @@ XlaOp XlaBuilder::CrossReplicaSum( } if (channel_id.has_value()) { - instr.set_all_reduce_id(channel_id->handle()); + instr.set_channel_id(channel_id->handle()); } AddCalledComputation(computation, &instr); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 24f910caa7c..ae39906ef52 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -319,7 +319,7 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { auto maybe_pair = MatchesArCrsPattern(instruction); if (maybe_pair) { auto pair = *maybe_pair; - int64 ar_id = *(instruction->all_reduce_id()); + int64 ar_id = *(instruction->channel_id()); if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) { continue; } @@ -365,9 +365,10 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { - auto all_reduce_id = it.first; - VLOG(2) << "KeepProvablyEqualInstructionGroups. Checking ar_id: " - << all_reduce_id << "\n"; + auto channel_id = it.first; + VLOG(2) + << "KeepProvablyEqualInstructionGroups. Checking AllReduce channel id: " + << channel_id << "\n"; auto pairs_vec = it.second; CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); auto instr_0 = pairs_vec[0].ar; @@ -378,9 +379,10 @@ void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { absl::flat_hash_map visited_pairs; while (true) { if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { - all_reduce_map_.erase(all_reduce_id); - VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased ar_id: " - << all_reduce_id << "\n"; + all_reduce_map_.erase(channel_id); + VLOG(2) << "KeepProvablyEqualInstructionGroups. Erased AllReduce " + "channel id: " + << channel_id << "\n"; break; } if (next_0->IsCrossReplicaAllReduce()) { @@ -402,7 +404,7 @@ StatusOr ArCrsCombiner::RewriteGraph() { for (auto pair : pairs_vec) { auto all_reduce = pair.ar; auto parent_computation = all_reduce->parent(); - auto all_reduce_id = all_reduce->all_reduce_id(); + auto channel_id = all_reduce->channel_id(); auto prev = all_reduce->mutable_operand(0); auto next = all_reduce->users()[0]; TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); @@ -447,7 +449,7 @@ StatusOr ArCrsCombiner::RewriteGraph() { next = next->users()[0]; } // The AllReduce and the CRS are combined to an all-core AllReduce. - next->set_all_reduce_id(all_reduce_id); + next->set_channel_id(channel_id); } } return true; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index 250252b6390..a85e18d328c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -36,7 +36,7 @@ namespace xla { // // The steps are: // 1) Find CMARs followed by simple ops followed by CRARs. -// 2) Group CMARs by all_reduce_id. They must all be rewritten. +// 2) Group CMARs by channel_id. They must all be rewritten. // 3) Prove that the CMAR patterns in each core produce the same result. // 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the // other operand by the number of spatial partitions. @@ -104,7 +104,7 @@ class ArCrsCombiner : public HloModulePass { } pieces.push_back(instruction->name()); pieces.push_back(")[id:"); - pieces.push_back(std::to_string(*(ar->all_reduce_id()))); + pieces.push_back(std::to_string(*(ar->channel_id()))); pieces.push_back(",dist:"); pieces.push_back(std::to_string(distance)); pieces.push_back("]"); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 0be31899d53..accc0684e8e 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -414,7 +414,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=0} %convert.1 = f32[] @@ -429,7 +429,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %all-reduce.ar.2 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=1} %convert.2 = f32[] @@ -486,7 +486,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { %all-reduce.ar.1 = f32[2,1] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.1, sharding={maximal device=0} %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1) @@ -499,7 +499,7 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { %all-reduce.ar.2 = f32[2,1] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.1, sharding={maximal device=1} %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2) @@ -549,7 +549,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.f32, sharding={maximal device=0} %multiply.1 = f32[] @@ -564,7 +564,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.f32, sharding={maximal device=1} %multiply.2 = f32[] @@ -624,7 +624,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=0} %convert.1 = f32[] @@ -642,7 +642,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=1} %convert.2 = f32[] @@ -709,7 +709,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=0} %convert.1 = f32[] @@ -727,7 +727,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = bf16[] all-reduce(%constant.bf16), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=1} %convert.2 = f32[] @@ -772,7 +772,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.1, sharding={maximal device=0} %all-reduce.1 = f32[] @@ -787,7 +787,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.1, sharding={maximal device=1} %all-reduce.2 = f32[] @@ -840,7 +840,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=0} %add.11 = f32[] @@ -858,7 +858,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=0} %add.21 = f32[] @@ -919,7 +919,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.1 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.f32, sharding={maximal device=0} %sub.1 = f32[] @@ -934,7 +934,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %all-reduce.ar.2 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.f32, sharding={maximal device=1} %sub.2 = f32[] @@ -991,7 +991,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=0} %add11 = f32[] @@ -1000,7 +1000,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=2, + channel_id=2, to_apply=%sum, sharding={maximal device=0} %add12 = f32[] @@ -1015,7 +1015,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar21 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=1} %add21 = f32[] @@ -1024,7 +1024,7 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar22 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=2, + channel_id=2, to_apply=%sum, sharding={maximal device=1} %add22 = f32[] @@ -1083,13 +1083,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar11 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=0} %ar12 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=2, + channel_id=2, to_apply=%sum, sharding={maximal device=0} %add11 = f32[] @@ -1107,13 +1107,13 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { %ar21 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum, sharding={maximal device=1} %ar22 = f32[] all-reduce(%p), replica_groups={{0},{1}}, - all_reduce_id=2, + channel_id=2, to_apply=%sum, sharding={maximal device=1} %add21 = f32[] @@ -1182,7 +1182,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %all-reduce.ar.1 = bf16[] all-reduce(%p), replica_groups={{0}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=0} %convert.1 = f32[] @@ -1197,7 +1197,7 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %all-reduce.ar.2 = bf16[] all-reduce(%constant.bf16), replica_groups={{0}}, - all_reduce_id=1, + channel_id=1, to_apply=%sum.bf16, sharding={maximal device=1} %convert.2 = f32[] @@ -1276,9 +1276,9 @@ HloModule foobar ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { %p = bf16[] parameter(0) - %all-reduce.0 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}}, + %all-reduce.0 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %all-reduce.1 = f32[] all-reduce(%p), all_reduce_id=1, replica_groups={{0,1}}, + %all-reduce.1 = f32[] all-reduce(%p), channel_id=1, replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} %all-reduce.2 = f32[] all-reduce(%all-reduce.0), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 8170d16b889..eb6692ade5b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -239,7 +239,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum, /*replica_groups=*/{}, - /*all_reduce_id=*/absl::nullopt)); + /*channel_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 6de4d26475d..087d4a63ffa 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -257,7 +257,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, /*replica_groups=*/{}, - /*all_reduce_id=*/absl::nullopt)); + /*channel_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index c16fb5a5bf0..86eb8cb240c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -211,7 +211,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { HloInstruction* all_reduce = builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, - /*replica_groups=*/{}, /*all_reduce_id=*/1)); + /*replica_groups=*/{}, /*channel_id=*/1)); HloInstruction* gte0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, all_reduce, 0)); HloInstruction* gte1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 995a1fcd676..0e184d5681d 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -177,13 +177,13 @@ class NcclComm { // * Only ops with the same opcode can communicate with each other. At the // moment we only support kAllReduce, so we don't check for this explicitly. // -// * For cross-module all-reduces (i.e. instr->all_reduce_id().has_value()), -// only ops with the same value for all_reduce_id() can communicate with each +// * For cross-module all-reduces (i.e. instr->channel_id().has_value()), +// only ops with the same value for channel_id() can communicate with each // other. // // * For cross-replica (i.e. same-module) all-reduces (i.e. -// !all_reduce_id().has_value()), only ops from the same module (as identified -// by its unique_id()) can communicate with each other. +// !channel_id().has_value()), only ops from the same module (as +// identified by its unique_id()) can communicate with each other. // struct RendezvousKey { enum AllReduceKind { @@ -196,8 +196,8 @@ struct RendezvousKey { const HloAllReduceInstruction* instr) : run_id(run_id), participating_replicas(participating_replicas) { std::tie(all_reduce_kind, op_id) = - instr->all_reduce_id().has_value() - ? std::make_pair(kCrossModule, instr->all_reduce_id().value()) + instr->channel_id().has_value() + ? std::make_pair(kCrossModule, instr->channel_id().value()) : std::make_pair( kCrossReplica, static_cast(instr->GetModule()->unique_id())); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e7c6f71a2ea..331bbcb7836 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -126,8 +126,9 @@ message HloInstructionProto { // Only present for kBatchNormTraining. int64 feature_index = 25; - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. + // Represents a unique identifier for each Send/Recv instruction pair or + // optionally for collective instructions (AllReduce, CollectivePermute, + // AllToAll). Non-positive channel_id is equivalent to no channel id. int64 channel_id = 26; // The string representation of the infeed configuration. @@ -174,7 +175,9 @@ message HloInstructionProto { // Cross replica op fields. repeated ReplicaGroup replica_groups = 49; - int64 all_reduce_id = 45; + // Deprecated, but keeping it for backward compatibility. Use channel_id. + // Non-positive all_reduce_id is equivalent to no all_reduce_id. + int64 all_reduce_id = 45 [deprecated = true]; // Whether this Send/Recv instruction transfers data to/from the host. Only // present for Send and Recv instructions and their SendDone and RecvDone diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index e9169c46ca6..b377da2e44f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -373,7 +373,7 @@ void HloComputation::ComputeInstructionPostOrder( case HloOpcode::kRecvDone: return inst->channel_id(); case HloOpcode::kAllReduce: - return inst->all_reduce_id(); + return inst->channel_id(); default: return absl::nullopt; } @@ -428,13 +428,10 @@ HloComputation::ComputeChannelDependencies() const { switch (instruction->opcode()) { case HloOpcode::kSend: case HloOpcode::kRecvDone: - channel_dependency_group[instruction->channel_id()].push_back( - instruction.get()); - break; case HloOpcode::kAllReduce: { - auto all_reduce_id = instruction->all_reduce_id(); - if (all_reduce_id) { - channel_dependency_group[all_reduce_id.value()].push_back( + auto channel_id = instruction->channel_id(); + if (channel_id) { + channel_dependency_group[channel_id.value()].push_back( instruction.get()); } break; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 45466ccc141..466720f2ba0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -688,10 +688,10 @@ add { ENTRY entry { param = f32[128] parameter(0), sharding={maximal device=0} crs0 = f32[128] all-reduce(param), - replica_groups={{0}}, all_reduce_id=1, to_apply=add, + replica_groups={{0}}, channel_id=1, to_apply=add, sharding={maximal device=0} crs1 = f32[128] all-reduce(param), - replica_groups={{0}}, all_reduce_id=1, to_apply=add, + replica_groups={{0}}, channel_id=1, to_apply=add, sharding={maximal device=1} add = f32[128] add(crs0, crs0), sharding={maximal device=0} ROOT t = (f32[128], f32[128]) tuple(add, crs1) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a627727da48..8b8cea95e60 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -383,16 +383,21 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "AllReduce should have 1 called computation but sees " << proto.called_computation_ids_size(); - absl::optional all_reduce_id; + TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0) + << "AllReduce cannot have both channel_id() and all_reduce_id()"; + absl::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } if (proto.all_reduce_id() > 0) { - all_reduce_id = proto.all_reduce_id(); + channel_id = proto.all_reduce_id(); } instruction = CreateAllReduce( shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), - /*all_reduce_id=*/all_reduce_id); + /*channel_id=*/channel_id); break; } case HloOpcode::kAllToAll: { @@ -860,9 +865,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape, const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, - const absl::optional& all_reduce_id) { + const absl::optional& channel_id) { return absl::make_unique( - shape, operands, reduce_computation, replica_groups, all_reduce_id); + shape, operands, reduce_computation, replica_groups, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateAllToAll( @@ -1279,7 +1284,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kTrace: return true; case HloOpcode::kAllReduce: - return all_reduce_id().has_value(); + return channel_id().has_value(); case HloOpcode::kCustomCall: return Cast(this) ->custom_call_has_side_effect(); @@ -2232,11 +2237,11 @@ bool HloInstruction::IsElementwiseImpl( } bool HloInstruction::IsCrossModuleAllReduce() const { - return opcode() == HloOpcode::kAllReduce && all_reduce_id(); + return opcode() == HloOpcode::kAllReduce && channel_id(); } bool HloInstruction::IsCrossReplicaAllReduce() const { - return opcode() == HloOpcode::kAllReduce && !all_reduce_id(); + return opcode() == HloOpcode::kAllReduce && !channel_id(); } string HloInstruction::ToStringWithCanonicalNameMap( @@ -3332,10 +3337,6 @@ const std::vector& HloInstruction::fft_length() const { return Cast(this)->fft_length(); } -int64 HloInstruction::channel_id() const { - return Cast(this)->channel_id(); -} - int64 HloInstruction::concatenate_dimension() const { return Cast(this)->concatenate_dimension(); } @@ -3535,13 +3536,12 @@ HloInstruction::source_target_pairs() const { return Cast(this)->source_target_pairs(); } -absl::optional HloInstruction::all_reduce_id() const { - return Cast(this)->all_reduce_id(); +absl::optional HloInstruction::channel_id() const { + return Cast(this)->channel_id(); } -void HloInstruction::set_all_reduce_id( - const absl::optional& all_reduce_id) { - return Cast(this)->set_all_reduce_id(all_reduce_id); +void HloInstruction::set_channel_id(const absl::optional& channel_id) { + return Cast(this)->set_channel_id(channel_id); } const ConvolutionDimensionNumbers& diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 55564803da7..7ec47582276 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -497,14 +497,14 @@ class HloInstruction { // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // `all_reduce_id`: for Allreduce nodes from different modules, if they have - // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will - // not be applied cross modules. + // `channel_id`: for Allreduce nodes from different modules, if + // they have the same channel_id, they will be 'Allreduce'd. If + // empty, Allreduce will not be applied cross modules. static std::unique_ptr CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, - const absl::optional& all_reduce_id); + const absl::optional& channel_id); // An all-to-all op takes N array operands of the same shape and scatters them // to N replicas. Each replica gathers the results into a tuple. @@ -952,7 +952,7 @@ class HloInstruction { return false; } - // Two AllReduces are Identical if they have the same all_reduce_id. + // Two AllReduces are Identical if they have the same channel_id. // Their operands don't have to be Identical. if (!IsCrossModuleAllReduce()) { // Use an explicit loop rather than ContainerEquals, because copying @@ -1428,8 +1428,9 @@ class HloInstruction { // Delegates to HloFftInstruction::fft_length. const std::vector& fft_length() const; - // Delegates to HloSendRecvInstruction::channel_id. - int64 channel_id() const; + // Delegates to HloChannelInstruction::channel_id. + absl::optional channel_id() const; + void set_channel_id(const absl::optional& channel_id); // Returns the dimension sizes or numbers associated with this instruction. virtual const std::vector& dimensions() const { @@ -1571,10 +1572,6 @@ class HloInstruction { // Delegates to HloCollectivePermuteInstruction::source_target_pairs. const std::vector>& source_target_pairs() const; - // Delegates to HloAllReduceInstruction::all_reduce_id. - absl::optional all_reduce_id() const; - void set_all_reduce_id(const absl::optional& all_reduce_id); - // Returns data on the window in a windowed operation such as // convolution. virtual const Window& window() const { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index cd43a2c41de..2a4809a572d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -361,25 +361,60 @@ HloCholeskyInstruction::CloneWithNewOperandsImpl( cholesky_options()); } +HloChannelInstruction::HloChannelInstruction( + HloOpcode opcode, const Shape& shape, + const absl::optional& channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +void HloChannelInstruction::set_channel_id( + const absl::optional& channel_id) { + channel_id_ = channel_id; +} + +HloInstructionProto HloChannelInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (channel_id_) { + CHECK_GT(channel_id_.value(), 0) + << "Non-positive channel id is equivalent to no channel id"; + proto.set_channel_id(*channel_id_); + } + return proto; +} + +std::vector HloChannelInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + if (channel_id_) { + result.push_back(StrCat("channel_id=", *channel_id_)); + } + return result; +} + +bool HloChannelInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = static_cast(other); + return channel_id() == casted_other.channel_id(); +} + HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, int64 channel_id, bool is_host_transfer) - : HloInstruction(opcode, shape), - channel_id_(channel_id), + : HloChannelInstruction(opcode, shape, channel_id), is_host_transfer_(is_host_transfer) {} HloInstructionProto HloSendRecvInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_channel_id(channel_id_); + HloInstructionProto proto = HloChannelInstruction::ToProto(); proto.set_is_host_transfer(is_host_transfer_); return proto; } std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - std::vector attrs; - attrs.push_back(StrCat("channel_id=", channel_id_)); + std::vector attrs = + HloChannelInstruction::ExtraAttributesToStringImpl(options); if (is_host_transfer()) { attrs.push_back("is_host_transfer=true"); } @@ -413,13 +448,13 @@ std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - new_operands[0], new_operands[1], channel_id(), is_host_transfer()); + new_operands[0], new_operands[1], *channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, bool is_host_transfer) : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), - CHECK_NOTNULL(operand)->channel_id(), + CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) { AppendOperand(operand); } @@ -450,7 +485,7 @@ std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return absl::make_unique( - ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], *channel_id(), is_host_transfer()); } @@ -461,7 +496,7 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(operand->shape(), 0), ShapeUtil::MakeTokenShape()}), - CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) { + CHECK_NOTNULL(operand)->channel_id().value(), is_host_transfer) { AppendOperand(operand); } @@ -477,32 +512,39 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCollectiveInstruction::HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, absl::Span operands, - const std::vector& replica_groups) - : HloInstruction(opcode, shape), replica_groups_(replica_groups) { + const std::vector& replica_groups, + const absl::optional& channel_id) + : HloChannelInstruction(opcode, shape, channel_id), + replica_groups_({replica_groups.begin(), replica_groups.end()}) { for (auto operand : operands) { AppendOperand(operand); } } HloInstructionProto HloCollectiveInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); + HloInstructionProto proto = HloChannelInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; return proto; } std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { - return {StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))}; + const HloPrintOptions& options) const { + std::vector result = + HloChannelInstruction::ExtraAttributesToStringImpl(options); + result.push_back( + StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))); + return result; } bool HloCollectiveInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - /*eq_computations*/) const { + eq_computations) const { const auto& casted_other = static_cast(other); - return absl::c_equal(replica_groups(), casted_other.replica_groups(), + return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + absl::c_equal(replica_groups(), casted_other.replica_groups(), [](const ReplicaGroup& a, const ReplicaGroup& b) { return absl::c_equal(a.replica_ids(), b.replica_ids()); }); @@ -512,44 +554,19 @@ HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, - const absl::optional& all_reduce_id) + const absl::optional& channel_id) : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, - replica_groups), - all_reduce_id_(all_reduce_id) { + replica_groups, channel_id) { AppendComputation(reduce_computation); } -void HloAllReduceInstruction::set_all_reduce_id( - const absl::optional& all_reduce_id) { - all_reduce_id_ = all_reduce_id; -} - -HloInstructionProto HloAllReduceInstruction::ToProto() const { - HloInstructionProto proto = HloCollectiveInstruction::ToProto(); - // Proto3 is so sad. - if (all_reduce_id_) { - proto.set_all_reduce_id(*all_reduce_id_); - } - return proto; -} - bool HloAllReduceInstruction::IsNoop() const { for (auto replica_group : replica_groups()) { if (replica_group.replica_ids().size() != 1) { return false; } } - return !all_reduce_id(); -} - -std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector result = - HloCollectiveInstruction::ExtraAttributesToStringImpl(options); - if (all_reduce_id_) { - result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); - } - return result; + return !channel_id(); } bool HloAllReduceInstruction::IdenticalSlowPath( @@ -558,8 +575,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast(other); return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && - eq_computations(to_apply(), casted_other.to_apply()) && - all_reduce_id() == casted_other.all_reduce_id(); + eq_computations(to_apply(), casted_other.to_apply()); } std::unique_ptr @@ -567,14 +583,14 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, to_apply(), replica_groups(), all_reduce_id()); + shape, new_operands, to_apply(), replica_groups(), channel_id()); } HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, absl::Span operands, const std::vector& replica_groups) : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, - replica_groups) {} + replica_groups, absl::nullopt) {} std::unique_ptr HloAllToAllInstruction::CloneWithNewOperandsImpl( @@ -587,13 +603,14 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs) - : HloInstruction(HloOpcode::kCollectivePermute, shape), + : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, + absl::nullopt), source_target_pairs_(source_target_pairs) { AppendOperand(operand); } HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); + HloInstructionProto proto = HloChannelInstruction::ToProto(); for (const auto& pair : source_target_pairs()) { auto* proto_pair = proto.add_source_target_pairs(); proto_pair->set_source(pair.first); @@ -604,8 +621,9 @@ HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { std::vector HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { - std::vector result; + const HloPrintOptions& options) const { + std::vector result = + HloChannelInstruction::ExtraAttributesToStringImpl(options); std::vector strs; for (const auto& pair : source_target_pairs()) { strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); @@ -617,10 +635,11 @@ HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( bool HloCollectivePermuteInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - /*eq_computations*/) const { + eq_computations) const { const auto& casted_other = static_cast(other); - return absl::c_equal(source_target_pairs(), + return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && + absl::c_equal(source_target_pairs(), casted_other.source_target_pairs(), [](const std::pair& a, const std::pair& b) { return a == b; }); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index f7f0eb4536b..8ab1995c622 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -206,13 +206,37 @@ class HloCholeskyInstruction : public HloInstruction { CholeskyOptions cholesky_options_; }; -class HloSendRecvInstruction : public HloInstruction { +// Class that represents instructions that synchronize and transfer data between +// partitioned devices. Send/Recv and collective instructions (AllReduce, +// AllToAll, CollectivePermute) belong to this instruction type. A group of +// instructions (of the same opcode) with the same channel_id communicate during +// execution. +class HloChannelInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - int64 channel_id() const { return channel_id_; } + // shared between each Send/Recv pair or a group of collective instructions + // and is globally unique to identify each channel. + absl::optional channel_id() const { return channel_id_; } + void set_channel_id(const absl::optional& channel_id); + protected: + explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape, + const absl::optional& channel_id); + + HloInstructionProto ToProto() const override; + + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + absl::optional channel_id_; +}; + +class HloSendRecvInstruction : public HloChannelInstruction { + public: // Returns whether this send/recv instruction sends data to/from the host. bool is_host_transfer() const { return is_host_transfer_; } @@ -230,9 +254,6 @@ class HloSendRecvInstruction : public HloInstruction { const HloInstruction& other, const std::function& eq_computations) const override; - // Represents a unique identifier for each Send/Recv instruction pair. - int64 channel_id_; - // Whether this send/recv instruction sends data to/from the host. bool is_host_transfer_; }; @@ -285,7 +306,7 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { HloCloneContext* context) const override; }; -class HloCollectiveInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloChannelInstruction { public: const std::vector& replica_groups() const { return replica_groups_; @@ -295,7 +316,8 @@ class HloCollectiveInstruction : public HloInstruction { explicit HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, absl::Span operands, - const std::vector& replica_groups); + const std::vector& replica_groups, + const absl::optional& channel_id); HloInstructionProto ToProto() const override; @@ -315,21 +337,13 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, - const absl::optional& all_reduce_id); - - absl::optional all_reduce_id() const { return all_reduce_id_; } - void set_all_reduce_id(const absl::optional& all_reduce_id); - - // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const override; + const absl::optional& channel_id); // Returns true if the AllReduce does no communication, so it's equivalent // to a mem copy. bool IsNoop() const; private: - std::vector ExtraAttributesToStringImpl( - const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -339,11 +353,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - - // For Allreduce nodes from different modules, if they have the same - // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross modules. - absl::optional all_reduce_id_; }; class HloAllToAllInstruction : public HloCollectiveInstruction { @@ -359,7 +368,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { HloCloneContext* context) const override; }; -class HloCollectivePermuteInstruction : public HloInstruction { +class HloCollectivePermuteInstruction : public HloChannelInstruction { public: explicit HloCollectivePermuteInstruction( const Shape& shape, HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index e0cb473907e..281ff764ef6 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -155,11 +155,11 @@ ENTRY entry { %p = f32[1000, 1000] parameter(0) %token.0 = token[] after-all() %send = (f32[1000, 1000], token[]) send(%p, %token.0), - channel_id=0, is_host_transfer=true + channel_id=1, is_host_transfer=true %n1 = f32[1000, 1000] negate(%p) %n2 = f32[1000, 1000] negate(%n1) %n3 = f32[1000, 1000] negate(%n2) - %send-done = token[] send-done(%send), channel_id=0, is_host_transfer=true + %send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true } )"; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index c326db1573e..92e4ebc5512 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -83,7 +83,7 @@ Status HloModuleGroupMetadata::Build() { if (IsChannelInstruction(hlo)) { peers.push_back(PeerComputation(hlo)); } else if (hlo->IsCrossModuleAllReduce()) { - for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) { if (instr == hlo) { continue; } @@ -235,7 +235,7 @@ bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const { HloComputation* HloModuleGroupMetadata::PeerComputation( const HloInstruction* instruction) const { CHECK(IsChannelInstruction(instruction)); - const Channel& channel = GetChannel(instruction->channel_id()); + const Channel& channel = GetChannel(*instruction->channel_id()); switch (instruction->opcode()) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -249,8 +249,8 @@ HloComputation* HloModuleGroupMetadata::PeerComputation( } const std::vector& HloModuleGroupMetadata::GetAllReduceGroup( - int64 all_reduce_id) const { - auto it = all_reduce_map_.find(all_reduce_id); + int64 channel_id) const { + auto it = all_reduce_map_.find(channel_id); CHECK(it != all_reduce_map_.end()); return it->second; } @@ -330,14 +330,14 @@ Status HloModuleGroupMetadata::RecordInstructions() { TrackedInstruction(hlo, ComputationKind::kCallFunction); } - // Group cross module all-reduce instructions by the all_reduce id. + // Group cross module all-reduce instructions by the channel id. if (hlo->IsCrossModuleAllReduce()) { - TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) == + TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) - << "all_reduce_id " << *hlo->all_reduce_id() + << "channel_id " << *hlo->channel_id() << " is already used by a send/recv instruction"; - all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo); - max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id()); + all_reduce_map_[*hlo->channel_id()].push_back(hlo); + max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id()); return Status::OK(); } @@ -345,41 +345,41 @@ Status HloModuleGroupMetadata::RecordInstructions() { return Status::OK(); } - TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) == + TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) == all_reduce_map_.end()) - << "channel id " << hlo->channel_id() + << "channel id " << *hlo->channel_id() << " is already used by an all-reduce instruction"; // Add a new channel if needed. - if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { + if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) { channels_.emplace_back(); - channels_.back().id = hlo->channel_id(); - channel_id_map_[hlo->channel_id()] = channels_.size() - 1; - max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); + channels_.back().id = *hlo->channel_id(); + channel_id_map_[*hlo->channel_id()] = channels_.size() - 1; + max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id()); } - Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; + Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]]; if (hlo->opcode() == HloOpcode::kSend) { TF_RET_CHECK(channel.send == nullptr) - << "channel id " << hlo->channel_id() + << "channel id " << *hlo->channel_id() << " is used by multiple send instructions"; channel.send = hlo; } if (hlo->opcode() == HloOpcode::kRecv) { TF_RET_CHECK(channel.recv == nullptr) - << "channel id " << hlo->channel_id() + << "channel id " << *hlo->channel_id() << " is used by multiple recv instructions"; channel.recv = hlo; } if (hlo->opcode() == HloOpcode::kSendDone) { TF_RET_CHECK(channel.send_done == nullptr) - << "channel id " << hlo->channel_id() + << "channel id " << *hlo->channel_id() << " is used by multiple send-done instructions"; channel.send_done = hlo; } if (hlo->opcode() == HloOpcode::kRecvDone) { TF_RET_CHECK(channel.recv_done == nullptr) - << "channel id " << hlo->channel_id() + << "channel id " << *hlo->channel_id() << " is used by multiple recv-done instructions"; channel.recv_done = hlo; } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index ba6639bd6e7..4db0afeab0a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -137,9 +137,8 @@ class HloModuleGroupMetadata { // Returns if the given channel id exists in metadata. bool HasChannel(int64 channel_id) const; - // Returns the all-reduce instructions with the same all_reduce_id. - const std::vector& GetAllReduceGroup( - int64 all_reduce_id) const; + // Returns the all-reduce instructions with the same channel_id. + const std::vector& GetAllReduceGroup(int64 channel_id) const; // Returns the computation that contains the peer channel instructions for // the given instruction. @@ -205,7 +204,7 @@ class HloModuleGroupMetadata { // Returns all channels in the module group. const std::vector& channels() const { return channels_; } - // Returns the maximum channel id or all_reduce_id used in the module group. + // Returns the maximum channel id used in the module group. int64 max_channel_id() const { return max_channel_id_; } HloAliasAnalysis* alias_analysis(HloModule* module) const { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9b7f54c5c6f..98e0357802c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -62,7 +62,7 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( } if (predecessor->IsCrossModuleAllReduce()) { for (HloInstruction* instr : - metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + metadata_.GetAllReduceGroup(*predecessor->channel_id())) { if (unique.insert(instr).second) { predecessors.push_back(instr); } @@ -82,8 +82,7 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( instruction_group.push_back(companion); } } else if (instruction->IsCrossModuleAllReduce()) { - instruction_group = - metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); + instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id()); } else { instruction_group.push_back(instruction); } @@ -99,14 +98,15 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( if (instruction->opcode() == HloOpcode::kRecvDone && !DynCast(instruction)->is_host_transfer()) { // Send is a remote predecessor of RecvDone. - HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + HloInstruction* send = + metadata_.GetChannel(*instruction->channel_id()).send; add_unique_predecessor(send); } if (instruction->opcode() == HloOpcode::kSend && !DynCast(instruction)->is_host_transfer()) { // Recv is a remote predecessor of Send. HloInstruction* recv_done = - metadata_.GetChannel(instruction->channel_id()).recv_done; + metadata_.GetChannel(*instruction->channel_id()).recv_done; CHECK(recv_done->opcode() == HloOpcode::kRecvDone); CHECK_EQ(recv_done->operand_count(), 1); HloInstruction* recv = recv_done->mutable_operand(0); @@ -139,7 +139,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( } if (successor->IsCrossModuleAllReduce()) { for (HloInstruction* instr : - metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + metadata_.GetAllReduceGroup(*successor->channel_id())) { if (unique.insert(instr).second) { successors.push_back(instr); } @@ -160,8 +160,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( instruction_group.push_back(companion); } } else if (instruction->IsCrossModuleAllReduce()) { - instruction_group = - metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); + instruction_group = metadata_.GetAllReduceGroup(*instruction->channel_id()); } else { instruction_group.push_back(instruction); } @@ -179,14 +178,15 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( // Send is a remote successor of Recv. const HloInstruction* recv_done = instruction->users().front(); CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + HloInstruction* send = + metadata_.GetChannel(*instruction->channel_id()).send; add_unique_successor(send); } if (instruction->opcode() == HloOpcode::kSend && !DynCast(instruction)->is_host_transfer()) { // RecvDone is a remote successor of Send. HloInstruction* recv_done = - metadata_.GetChannel(instruction->channel_id()).recv_done; + metadata_.GetChannel(*instruction->channel_id()).recv_done; add_unique_successor(recv_done); } return successors; @@ -256,7 +256,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( instruction_group.push_back(companion); } } else if (hlo->IsCrossModuleAllReduce()) { - instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id()); + instruction_group = metadata_.GetAllReduceGroup(*hlo->channel_id()); } else { instruction_group.push_back(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 30b2037374c..06f66505827 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -836,13 +836,12 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional>> tmp_groups; optional to_apply; optional> replica_group_ids; - optional all_reduce_id; + optional channel_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; attrs["replica_groups"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &tmp_groups}; - attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, - &all_reduce_id}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -851,7 +850,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, replica_groups = CreateReplicaGroups(*tmp_groups); } instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( - shape, operands, *to_apply, replica_groups, all_reduce_id)); + shape, operands, *to_apply, replica_groups, channel_id)); break; } case HloOpcode::kAllToAll: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d282b868b5b..7d093cc8ac0 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1416,8 +1416,8 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add - ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add + ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add } )" diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index af07eb83a5c..b0911d5f3cb 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -85,8 +85,8 @@ std::unique_ptr HloReachabilityMap::Build( std::vector inputs; const auto add_input = [&channel_group, &inputs](HloInstruction* input) { inputs.push_back(input); - if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { - auto it = channel_group.find(*input->all_reduce_id()); + if (input->opcode() == HloOpcode::kAllReduce && input->channel_id()) { + auto it = channel_group.find(*input->channel_id()); if (it != channel_group.end()) { inputs.insert(inputs.end(), it->second.begin(), it->second.end()); } @@ -106,7 +106,7 @@ std::unique_ptr HloReachabilityMap::Build( switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - auto it = channel_group.find(hlo->channel_id()); + auto it = channel_group.find(*hlo->channel_id()); if (it != channel_group.end()) { for (HloInstruction* channel : it->second) { if (channel->opcode() == HloOpcode::kSend) { @@ -117,9 +117,9 @@ std::unique_ptr HloReachabilityMap::Build( break; } case HloOpcode::kAllReduce: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_group.find(all_reduce_id.value()); + auto channel_id = hlo->channel_id(); + if (channel_id) { + auto it = channel_group.find(channel_id.value()); if (it != channel_group.end()) { for (HloInstruction* all_reduce : it->second) { add_dependencies(all_reduce); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 89a3f8e6b02..efe0ce1af2a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1210,8 +1210,8 @@ Status CheckSameChannel(const HloInstruction* instr1, return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%d), %s (%d)", - instr1->ToString(), instr1->channel_id(), instr2->ToString(), - instr2->channel_id()); + instr1->ToString(), *instr1->channel_id(), instr2->ToString(), + *instr2->channel_id()); } return Status::OK(); } @@ -1282,14 +1282,14 @@ Status VerifySendsAndRecvs(const HloModule& module) { DynCast(instruction); if (sendrecv->is_host_transfer()) { auto it_inserted = - host_channels.insert({sendrecv->channel_id(), sendrecv}); + host_channels.insert({*sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( "Channel %d is used for multiple host send/recv instructions: " "%s " "and " "%s", - sendrecv->channel_id(), sendrecv->ToString(), + *sendrecv->channel_id(), sendrecv->ToString(), it_inserted.first->second->ToString()); } } @@ -1574,9 +1574,9 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } Status HandleAllReduce(HloInstruction* crs) override { - if (crs->all_reduce_id().has_value()) { - TF_RET_CHECK(crs->all_reduce_id().value() > 0) - << "All reduce id must be greater than 0 for " + if (crs->channel_id().has_value()) { + TF_RET_CHECK(crs->channel_id().value() > 0) + << "All reduce channel id must be greater than 0 for " << crs->ToShortString(); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 864a9ac2069..d209d4d8d6b 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -262,7 +262,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) token0 = token[] after-all() - send = f32[4,3]{1,0} send(log, token0), channel_id=0 + send = f32[4,3]{1,0} send(log, token0), channel_id=1 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -294,7 +294,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p1) log = f32[4,3]{1,0} log(p0) token0 = token[] after-all() - send = f32[4,3]{1,0} send(log, token0), channel_id=0 + send = f32[4,3]{1,0} send(log, token0), channel_id=1 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -329,7 +329,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add2 = f32[4,3]{1,0} add(add1, p1) log = f32[4,3]{1,0} log(add2) token0 = token[] after-all() - send = f32[4,3]{1,0} send(log, token0), channel_id=0 + send = f32[4,3]{1,0} send(log, token0), channel_id=1 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 285fe3b3581..c2372aa0c8a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -408,7 +408,7 @@ Status LayoutAssignment::BuildHostChannelConstraints( TF_RET_CHECK(data_shape.IsArray()); TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( - send_recv_instr->channel_id(), data_shape.layout()); + *send_recv_instr->channel_id(), data_shape.layout()); TF_RET_CHECK(prev_layout == nullptr) << "Cannot constrain host transfer layout as it was set to " << LayoutUtil::HumanString(*prev_layout) << ": " @@ -480,7 +480,7 @@ Status LayoutAssignment::AddMandatoryConstraints( instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; - int64 channel_id = instruction->channel_id(); + int64 channel_id = *instruction->channel_id(); if (!get_channel_constraints(instruction) ->IsChannelConstrained(channel_id)) { continue; @@ -492,7 +492,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(send_buffer_shape, - instruction->channel_id()); + *instruction->channel_id()); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( new_buffer_shape, instruction->operand(0))); } else { @@ -503,18 +503,19 @@ Status LayoutAssignment::AddMandatoryConstraints( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, {0})); - Shape new_shape = get_channel_constraints(instruction) - ->LayoutShapeForChannel( - recv_buffer_shape, instruction->channel_id()); + Shape new_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(recv_buffer_shape, + *instruction->channel_id()); TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } } else if (instruction->IsCrossModuleAllReduce()) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; - int64 all_reduce_id = instruction->all_reduce_id().value(); + int64 channel_id = instruction->channel_id().value(); if (!get_channel_constraints(instruction) - ->IsChannelConstrained(all_reduce_id)) { + ->IsChannelConstrained(channel_id)) { continue; } // TODO(b/68493863): Change to use SetOperandLayout(). @@ -522,7 +523,7 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RET_CHECK(buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) - ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + ->LayoutShapeForChannel(buffer_shape, channel_id); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(new_buffer_shape, instruction)); } @@ -1833,7 +1834,7 @@ Status LayoutAssignment::ConstrainChannelLayouts( const Layout* layout = get_channel_constraints(instruction) ->ConstrainChannel( - instruction->channel_id(), + *instruction->channel_id(), ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); TF_RET_CHECK(layout == nullptr) << instruction->ToString() @@ -1848,7 +1849,7 @@ Status LayoutAssignment::ConstrainChannelLayouts( if (instruction->opcode() == HloOpcode::kSend) { HloInstruction* operand = instruction->mutable_operand(0); const Layout* layout = get_channel_constraints(instruction) - ->ConstrainChannel(instruction->channel_id(), + ->ConstrainChannel(*instruction->channel_id(), operand->shape().layout()); if (layout != nullptr) { // We found an already constrained layout which does not match the one @@ -1873,7 +1874,7 @@ Status LayoutAssignment::ConstrainChannelLayouts( } else if (instruction->IsCrossModuleAllReduce()) { const Layout* layout = get_channel_constraints(instruction) - ->ConstrainChannel(instruction->all_reduce_id().value(), + ->ConstrainChannel(instruction->channel_id().value(), instruction->shape().layout()); if (layout != nullptr) { // We found an already constrained layout which does not match the one diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 5597afc15a3..046ffde7616 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -891,11 +891,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] all-reduce(gte), - all_reduce_id=1, replica_groups={{0}}, to_apply=add, + channel_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant({{0,1},{2,3}}) ROOT ar.1 = f32[2,2] all-reduce(const), - all_reduce_id=1, replica_groups={{0}}, to_apply=add, + channel_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m,