From 31862b7bbef086a6a54f1482f2e634fb1006edf3 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 10 Jun 2020 13:28:06 -0700 Subject: [PATCH] [XLA] Make replication analysis consider AllGather and global device IDs. PiperOrigin-RevId: 315756511 Change-Id: I80862fee5915a4077526287946a097c06a1b7057 --- tensorflow/compiler/xla/service/BUILD | 1 + .../xla/service/hlo_replication_analysis.cc | 53 +++++++++++++++---- .../service/hlo_replication_analysis_test.cc | 44 +++++++++++++++ 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4fc052ce17a..4ad21d1d765 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3010,6 +3010,7 @@ cc_library( hdrs = ["hlo_replication_analysis.h"], deps = [ ":hlo", + ":hlo_casting_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc index ab5c98ab48a..dec119d8aba 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -16,13 +16,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_replication_analysis.h" #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -51,25 +55,54 @@ bool DetermineHloInstructionIsReplicated( return true; }; - if (hlo->opcode() == HloOpcode::kAllReduce) { - // All-reduce returns same values across partitions/replicas as long as its - // operands are replicated. + if (hlo->opcode() == HloOpcode::kAllReduce || + hlo->opcode() == HloOpcode::kAllGather) { + // All-reduce/all-gather returns same values across partitions/replicas as + // long as its operands are replicated. if (all_operands_replicated(hlo)) { return true; } - if (hlo->IsCrossReplicaAllReduce()) { + if (!hlo->channel_id().has_value()) { + // This is cross-replica-only. if (cross_partition_spmd) { return false; } - // Only all-reduce across all cores are replicated, which means there - // is only one subgroup. + // Only all-reduce/all-gather across all cores are replicated, which means + // there is only one subgroup. return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; } else { - CHECK(hlo->IsCrossModuleAllReduce()); - if (cross_partition_spmd) { - return true; + bool global_id; + if (hlo->opcode() == HloOpcode::kAllReduce) { + global_id = Cast(hlo)->use_global_device_ids(); + } else { + global_id = Cast(hlo)->use_global_device_ids(); } - return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1; + if (global_id) { + bool replicated_across_partitions = true; + bool replicated_across_replicas = true; + const int64 num_partitions = + hlo->GetModule()->config().num_partitions(); + for (const auto& group : hlo->replica_groups()) { + absl::flat_hash_set visited_partitions; + absl::flat_hash_set visited_replicas; + for (int64 id : group.replica_ids()) { + int64 rid = id / num_partitions; + int64 pid = id % num_partitions; + visited_partitions.insert(pid); + visited_replicas.insert(rid); + } + replicated_across_partitions &= + visited_partitions.size() == num_partitions; + replicated_across_replicas &= + visited_replicas.size() == + hlo->GetModule()->config().replica_count(); + } + return cross_partition_spmd ? replicated_across_partitions + : replicated_across_replicas; + } + return cross_partition_spmd ? true + : hlo->replica_groups().empty() || + hlo->replica_groups().size() == 1; } } if (hlo->HasSideEffectNoRecurse()) { diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index d858d6aa1c7..c2d86e808c2 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -586,5 +586,49 @@ ENTRY entry { FindInstruction(module.get(), "ar1"), {})); } +TEST_F(HloReplicationAnalysisTest, GlobalIdAllGather) { + const string module_str = R"( +HloModule GlobalIdAllGather + +ENTRY entry { + param = f32[1] parameter(0) + ag1 = f32[2] all-gather(param), replica_groups={{0,1},{2,3}}, dimensions={0}, + use_global_device_ids=true, channel_id=1 + ag2 = f32[2] all-gather(param), replica_groups={{0,2},{1,3}}, dimensions={0}, + use_global_device_ids=true, channel_id=2 + ag3 = f32[4] all-gather(param), replica_groups={{0,1,2,3}}, dimensions={0}, + use_global_device_ids=true, channel_id=3 + ROOT tuple = (f32[], f32[], f32[]) tuple(ag1, ag2, ag3) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/2)); + auto config = module->config(); + config.set_num_partitions(2); + module->set_config(config); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr replica_analysis, + HloReplicationAnalysis::Run(module.get(), + /*cross_partition_spmd=*/false)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr partition_analysis, + HloReplicationAnalysis::Run(module.get(), + /*cross_partition_spmd=*/true)); + EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag1"), {})); + EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag2"), {})); + EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag3"), {})); + + EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag1"), {})); + EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag2"), {})); + EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "ag3"), {})); +} + } // namespace } // namespace xla