[XLA] Make replication analysis consider AllGather and global device IDs.
PiperOrigin-RevId: 315756511
Change-Id: I80862fee5915a4077526287946a097c06a1b7057
commit 31862b7bbe (parent e47873165f)
tensorflow/compiler/xla/service/BUILD

@@ -3010,6 +3010,7 @@ cc_library(
     hdrs = ["hlo_replication_analysis.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
tensorflow/compiler/xla/service/hlo_replication_analysis.cc

@@ -16,13 +16,17 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
 
 #include <memory>
+#include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -51,25 +55,54 @@ bool DetermineHloInstructionIsReplicated(
     return true;
   };
 
-  if (hlo->opcode() == HloOpcode::kAllReduce) {
-    // All-reduce returns same values across partitions/replicas as long as its
-    // operands are replicated.
+  if (hlo->opcode() == HloOpcode::kAllReduce ||
+      hlo->opcode() == HloOpcode::kAllGather) {
+    // All-reduce/all-gather returns same values across partitions/replicas as
+    // long as its operands are replicated.
     if (all_operands_replicated(hlo)) {
       return true;
     }
-    if (hlo->IsCrossReplicaAllReduce()) {
+    if (!hlo->channel_id().has_value()) {
+      // This is cross-replica-only.
       if (cross_partition_spmd) {
         return false;
       }
-      // Only all-reduce across all cores are replicated, which means there
-      // is only one subgroup.
+      // Only all-reduce/all-gather across all cores are replicated, which means
+      // there is only one subgroup.
       return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
     } else {
-      CHECK(hlo->IsCrossModuleAllReduce());
-      if (cross_partition_spmd) {
-        return true;
+      bool global_id;
+      if (hlo->opcode() == HloOpcode::kAllReduce) {
+        global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
+      } else {
+        global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
       }
-      return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
+      if (global_id) {
+        bool replicated_across_partitions = true;
+        bool replicated_across_replicas = true;
+        const int64 num_partitions =
+            hlo->GetModule()->config().num_partitions();
+        for (const auto& group : hlo->replica_groups()) {
+          absl::flat_hash_set<int64> visited_partitions;
+          absl::flat_hash_set<int64> visited_replicas;
+          for (int64 id : group.replica_ids()) {
+            int64 rid = id / num_partitions;
+            int64 pid = id % num_partitions;
+            visited_partitions.insert(pid);
+            visited_replicas.insert(rid);
+          }
+          replicated_across_partitions &=
+              visited_partitions.size() == num_partitions;
+          replicated_across_replicas &=
+              visited_replicas.size() ==
+              hlo->GetModule()->config().replica_count();
+        }
+        return cross_partition_spmd ? replicated_across_partitions
+                                    : replicated_across_replicas;
+      }
+      return cross_partition_spmd ? true
+                                  : hlo->replica_groups().empty() ||
+                                        hlo->replica_groups().size() == 1;
    }
  }
  if (hlo->HasSideEffectNoRecurse()) {
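For reference, the rule the new branch implements for replica_groups that use global device IDs is: a flattened ID encodes (replica, partition) as id = replica * num_partitions + partition, and the collective's result is treated as replicated across partitions (or across replicas) only if every group covers all partitions (or all replicas). Below is a minimal standalone C++ sketch of that rule; it assumes nothing from XLA, and the helper name IsReplicated and the plain standard-library containers are illustrative only, not part of the commit.

#include <cstdint>
#include <unordered_set>
#include <vector>

// Decide replication for a collective whose replica_groups hold flattened
// global device IDs (id = replica * num_partitions + partition).
bool IsReplicated(const std::vector<std::vector<int64_t>>& replica_groups,
                  int64_t num_replicas, int64_t num_partitions,
                  bool cross_partition_spmd) {
  bool replicated_across_partitions = true;
  bool replicated_across_replicas = true;
  for (const auto& group : replica_groups) {
    std::unordered_set<int64_t> visited_partitions;
    std::unordered_set<int64_t> visited_replicas;
    for (int64_t id : group) {
      // Decode the flattened global device ID.
      visited_replicas.insert(id / num_partitions);
      visited_partitions.insert(id % num_partitions);
    }
    // Every group must span all partitions (replicas) for the result to be
    // identical across partitions (replicas).
    replicated_across_partitions &=
        visited_partitions.size() == static_cast<size_t>(num_partitions);
    replicated_across_replicas &=
        visited_replicas.size() == static_cast<size_t>(num_replicas);
  }
  return cross_partition_spmd ? replicated_across_partitions
                              : replicated_across_replicas;
}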
tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc

@@ -586,5 +586,49 @@ ENTRY entry {
       FindInstruction(module.get(), "ar1"), {}));
 }
 
+TEST_F(HloReplicationAnalysisTest, GlobalIdAllGather) {
+  const string module_str = R"(
+HloModule GlobalIdAllGather
+
+ENTRY entry {
+  param = f32[1] parameter(0)
+  ag1 = f32[2] all-gather(param), replica_groups={{0,1},{2,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=1
+  ag2 = f32[2] all-gather(param), replica_groups={{0,2},{1,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=2
+  ag3 = f32[4] all-gather(param), replica_groups={{0,1,2,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=3
+  ROOT tuple = (f32[2], f32[2], f32[4]) tuple(ag1, ag2, ag3)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
+                                           module_str, /*replica_count=*/2));
+  auto config = module->config();
+  config.set_num_partitions(2);
+  module->set_config(config);
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloReplicationAnalysis> replica_analysis,
+      HloReplicationAnalysis::Run(module.get(),
+                                  /*cross_partition_spmd=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloReplicationAnalysis> partition_analysis,
+      HloReplicationAnalysis::Run(module.get(),
+                                  /*cross_partition_spmd=*/true));
+  EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag1"), {}));
+  EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag2"), {}));
+  EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag3"), {}));
+
+  EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag1"), {}));
+  EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag2"), {}));
+  EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag3"), {}));
+}
+
 }  // namespace
 }  // namespace xla
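Applied to the test's 2-replica by 2-partition configuration (global IDs 0..3), the same rule reproduces the expected verdicts: ag1's groups {{0,1},{2,3}} each span both partitions of a single replica, ag2's groups {{0,2},{1,3}} each span both replicas of a single partition, and ag3's single group spans every device. The driver below is a hypothetical check, not part of the commit; it reuses the illustrative IsReplicated sketch shown after the analysis hunk above.

#include <cassert>

int main() {
  // ag1: each group covers all partitions but only one replica.
  assert(IsReplicated({{0, 1}, {2, 3}}, 2, 2, /*cross_partition_spmd=*/true));
  assert(!IsReplicated({{0, 1}, {2, 3}}, 2, 2, /*cross_partition_spmd=*/false));
  // ag2: each group covers all replicas but only one partition.
  assert(!IsReplicated({{0, 2}, {1, 3}}, 2, 2, /*cross_partition_spmd=*/true));
  assert(IsReplicated({{0, 2}, {1, 3}}, 2, 2, /*cross_partition_spmd=*/false));
  // ag3: one group covers every device, so it is replicated both ways.
  assert(IsReplicated({{0, 1, 2, 3}}, 2, 2, /*cross_partition_spmd=*/true));
  assert(IsReplicated({{0, 1, 2, 3}}, 2, 2, /*cross_partition_spmd=*/false));
  return 0;
}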