[XLA] Make replication analysis consider AllGather and global device IDs.
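
For all-reduce and all-gather instructions that carry a channel_id and set
use_global_device_ids=true, each id in a replica group is now decomposed as
replica = id / num_partitions and partition = id % num_partitions; the result
is treated as replicated across partitions (or replicas) only if every group
covers all partitions (or all replicas).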

PiperOrigin-RevId: 315756511
Change-Id: I80862fee5915a4077526287946a097c06a1b7057
Yuanzhong Xu 2020-06-10 13:28:06 -07:00 committed by TensorFlower Gardener
parent e47873165f
commit 31862b7bbe
3 changed files with 88 additions and 10 deletions

tensorflow/compiler/xla/service/BUILD

@@ -3010,6 +3010,7 @@ cc_library(
     hdrs = ["hlo_replication_analysis.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",

tensorflow/compiler/xla/service/hlo_replication_analysis.cc

@@ -16,13 +16,17 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"
 
 #include <memory>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -51,25 +55,54 @@ bool DetermineHloInstructionIsReplicated(
     return true;
   };
-  if (hlo->opcode() == HloOpcode::kAllReduce) {
-    // All-reduce returns same values across partitions/replicas as long as its
-    // operands are replicated.
+  if (hlo->opcode() == HloOpcode::kAllReduce ||
+      hlo->opcode() == HloOpcode::kAllGather) {
+    // All-reduce/all-gather returns same values across partitions/replicas as
+    // long as its operands are replicated.
     if (all_operands_replicated(hlo)) {
       return true;
     }
-    if (hlo->IsCrossReplicaAllReduce()) {
+    if (!hlo->channel_id().has_value()) {
+      // This is cross-replica-only.
       if (cross_partition_spmd) {
         return false;
       }
-      // Only all-reduce across all cores are replicated, which means there
-      // is only one subgroup.
+      // Only all-reduce/all-gather across all cores are replicated, which means
+      // there is only one subgroup.
       return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
     } else {
-      CHECK(hlo->IsCrossModuleAllReduce());
-      if (cross_partition_spmd) {
-        return true;
+      bool global_id;
+      if (hlo->opcode() == HloOpcode::kAllReduce) {
+        global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
+      } else {
+        global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
       }
-      return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
+      if (global_id) {
+        bool replicated_across_partitions = true;
+        bool replicated_across_replicas = true;
+        const int64 num_partitions =
+            hlo->GetModule()->config().num_partitions();
+        for (const auto& group : hlo->replica_groups()) {
+          absl::flat_hash_set<int64> visited_partitions;
+          absl::flat_hash_set<int64> visited_replicas;
+          for (int64 id : group.replica_ids()) {
+            int64 rid = id / num_partitions;
+            int64 pid = id % num_partitions;
+            visited_partitions.insert(pid);
+            visited_replicas.insert(rid);
+          }
+          replicated_across_partitions &=
+              visited_partitions.size() == num_partitions;
+          replicated_across_replicas &=
+              visited_replicas.size() ==
+              hlo->GetModule()->config().replica_count();
+        }
+        return cross_partition_spmd ? replicated_across_partitions
+                                    : replicated_across_replicas;
+      }
+      return cross_partition_spmd ? true
+                                  : hlo->replica_groups().empty() ||
+                                        hlo->replica_groups().size() == 1;
     }
   }
 
   if (hlo->HasSideEffectNoRecurse()) {
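
As an illustration of the group check above, here is a minimal standalone C++
sketch (not part of this change; the main() harness and names are made up for
illustration) that decomposes the global ids of one replica group, assuming
num_partitions = 2 and replica_count = 2 as in the test below:

// Standalone sketch: decompose flattened global device ids into
// (replica, partition) pairs and check whether a single group spans all
// partitions and/or all replicas, mirroring the loop in the diff above.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  const std::size_t num_partitions = 2;  // assumed module config
  const std::size_t num_replicas = 2;
  // One group of ag2 from the test below: global ids {0, 2}.
  const std::vector<int64_t> group = {0, 2};

  std::unordered_set<int64_t> visited_partitions;
  std::unordered_set<int64_t> visited_replicas;
  for (int64_t id : group) {
    // Global id layout: id = replica_id * num_partitions + partition_id.
    visited_replicas.insert(id / static_cast<int64_t>(num_partitions));
    visited_partitions.insert(id % static_cast<int64_t>(num_partitions));
  }
  // Prints "partitions: 0, replicas: 1": {0, 2} covers both replicas but only
  // partition 0, so the group is replicated across replicas, not partitions.
  std::cout << "partitions: " << (visited_partitions.size() == num_partitions)
            << ", replicas: " << (visited_replicas.size() == num_replicas)
            << "\n";
  return 0;
}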

tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc

@@ -586,5 +586,49 @@ ENTRY entry {
       FindInstruction(module.get(), "ar1"), {}));
 }
 
+TEST_F(HloReplicationAnalysisTest, GlobalIdAllGather) {
+  const string module_str = R"(
+HloModule GlobalIdAllGather
+
+ENTRY entry {
+  param = f32[1] parameter(0)
+  ag1 = f32[2] all-gather(param), replica_groups={{0,1},{2,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=1
+  ag2 = f32[2] all-gather(param), replica_groups={{0,2},{1,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=2
+  ag3 = f32[4] all-gather(param), replica_groups={{0,1,2,3}}, dimensions={0},
+    use_global_device_ids=true, channel_id=3
+  ROOT tuple = (f32[2], f32[2], f32[4]) tuple(ag1, ag2, ag3)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
+                                           module_str, /*replica_count=*/2));
+  auto config = module->config();
+  config.set_num_partitions(2);
+  module->set_config(config);
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloReplicationAnalysis> replica_analysis,
+      HloReplicationAnalysis::Run(module.get(),
+                                  /*cross_partition_spmd=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloReplicationAnalysis> partition_analysis,
+      HloReplicationAnalysis::Run(module.get(),
+                                  /*cross_partition_spmd=*/true));
+  EXPECT_FALSE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag1"), {}));
+  EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag2"), {}));
+  EXPECT_TRUE(replica_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag3"), {}));
+  EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag1"), {}));
+  EXPECT_FALSE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag2"), {}));
+  EXPECT_TRUE(partition_analysis->HloInstructionIsReplicatedAt(
+      FindInstruction(module.get(), "ag3"), {}));
+}
+
 }  // namespace
 }  // namespace xla
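
With replica_count=2 and num_partitions=2 in this test, global device ids 0..3
map to (replica, partition) pairs (0,0), (0,1), (1,0), (1,1). ag1's groups
{0,1} and {2,3} each stay within one replica but span both partitions, so it is
replicated across partitions only; ag2's groups {0,2} and {1,3} span both
replicas but only one partition each, so it is replicated across replicas only;
ag3's single group covers every device, so both analyses report it replicated.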