[XLA] Use CollectiveOpGroup mode to drive some of the HLO verification
- Use the group mode to check the size of non-empty replica groups for various XLA collective operations. - In several cases, replica_count and num_partitions are not initialized and have a default value of 1. Skip some of these checks in those cases to avoid breaking existing tests. This can be improved further by either making these optional in the HloModule config or using 0 for the default/uninitialized value. - Fix several tests that hit verification failures due to this. PiperOrigin-RevId: 359073636 Change-Id: I9c6bac147fec2bb2f407d0d1e83698a91de4d852
This commit is contained in:
parent
8f803592a0
commit
53124e2193
@ -3787,6 +3787,7 @@ cc_library(
|
|||||||
srcs = ["hlo_verifier.cc"],
|
srcs = ["hlo_verifier.cc"],
|
||||||
hdrs = ["hlo_verifier.h"],
|
hdrs = ["hlo_verifier.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":collective_ops_utils",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_casting_utils",
|
":hlo_casting_utils",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
@ -3799,6 +3800,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -1809,7 +1809,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
|
|||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
|
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2,
|
||||||
|
/*num_partitions=*/4));
|
||||||
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
|
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
|
||||||
/*spmd_partition=*/true);
|
/*spmd_partition=*/true);
|
||||||
auto changed = combiner.Run(module.get()).ValueOrDie();
|
auto changed = combiner.Run(module.get()).ValueOrDie();
|
||||||
|
@ -107,7 +107,7 @@ absl::string_view CollectiveOpGroupModeToString(
|
|||||||
case CollectiveOpGroupMode::kCrossPartition:
|
case CollectiveOpGroupMode::kCrossPartition:
|
||||||
return "kCrossPartition";
|
return "kCrossPartition";
|
||||||
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
|
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
|
||||||
return "kCrossReplicAndPartition";
|
return "kCrossReplicaAndPartition";
|
||||||
case CollectiveOpGroupMode::kFlattenedID:
|
case CollectiveOpGroupMode::kFlattenedID:
|
||||||
return "kFlattenedID";
|
return "kFlattenedID";
|
||||||
}
|
}
|
||||||
|
@ -632,11 +632,9 @@ ENTRY entry {
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
module_str, /*replica_count=*/2));
|
auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2,
|
||||||
auto config = module->config();
|
/*num_partitions=*/2));
|
||||||
config.set_num_partitions(2);
|
|
||||||
module->set_config(config);
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
std::unique_ptr<HloReplicationAnalysis> replica_analysis,
|
std::unique_ptr<HloReplicationAnalysis> replica_analysis,
|
||||||
HloReplicationAnalysis::Run(module.get(),
|
HloReplicationAnalysis::Run(module.get(),
|
||||||
|
@ -15,13 +15,13 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||||
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
@ -183,21 +183,26 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
|
|||||||
|
|
||||||
// Checks that `hlo`'s set of ReplicaGroups:
|
// Checks that `hlo`'s set of ReplicaGroups:
|
||||||
//
|
//
|
||||||
// - names each replica 0 through n-1 exactly once, and
|
// - names each replica 0 through n-1 exactly once (where n is either number of
|
||||||
|
// replicas, or number of partitions, or their product)
|
||||||
// - does not contain any empty ReplicaGroups.
|
// - does not contain any empty ReplicaGroups.
|
||||||
//
|
//
|
||||||
// Note that although none of the groups may be empty, `hlo` is allowed to have
|
// Note that although none of the groups may be empty, `hlo` is allowed to have
|
||||||
// 0 groups. That just means it has one big group.
|
// 0 groups when group mode is not kFlattenedID. That just means it has one
|
||||||
|
// big group.
|
||||||
//
|
//
|
||||||
// This is just a minimal set of checks; some instructions may have additional
|
// In general, if replica groups is not empty, all replica groups should be of
|
||||||
// requirements. For example, all-to-all requires that all ReplicaGroups have
|
// the same size. The exception is all-reduce, where non-uniform replica groups
|
||||||
// the same number of replicas, but that isn't checked here.
|
// are allowed. This is controlled by `uniform_replica_group_size`.
|
||||||
static Status CheckReplicaGroups(HloInstruction* hlo,
|
static Status CheckReplicaGroups(HloInstruction* hlo,
|
||||||
bool use_global_device_ids) {
|
CollectiveOpGroupMode group_mode,
|
||||||
std::set<int64> replicas_seen;
|
bool uniform_replica_group_size = true) {
|
||||||
|
if (!hlo->replica_groups().empty()) {
|
||||||
|
absl::flat_hash_set<int64> replicas_seen;
|
||||||
for (const ReplicaGroup& g : hlo->replica_groups()) {
|
for (const ReplicaGroup& g : hlo->replica_groups()) {
|
||||||
if (g.replica_ids().empty()) {
|
if (g.replica_ids().empty()) {
|
||||||
return InternalError("Instruction cannot have an empty replica group: %s",
|
return InternalError(
|
||||||
|
"Instruction cannot have an empty replica group: %s",
|
||||||
hlo->ToString());
|
hlo->ToString());
|
||||||
}
|
}
|
||||||
for (int64 i : g.replica_ids()) {
|
for (int64 i : g.replica_ids()) {
|
||||||
@ -208,7 +213,8 @@ static Status CheckReplicaGroups(HloInstruction* hlo,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int64 i = 0; i < replicas_seen.size(); ++i) {
|
size_t n = replicas_seen.size();
|
||||||
|
for (int64 i = 0; i < n; ++i) {
|
||||||
if (!replicas_seen.count(i)) {
|
if (!replicas_seen.count(i)) {
|
||||||
return InternalError(
|
return InternalError(
|
||||||
"Replica %d is not named in instruction's replica-groups: %s", i,
|
"Replica %d is not named in instruction's replica-groups: %s", i,
|
||||||
@ -216,32 +222,48 @@ static Status CheckReplicaGroups(HloInstruction* hlo,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If use_global_device_ids() is set, replica_groups cannot be empty.
|
// replica-groups have numbers [0, n). This n should be either replica or
|
||||||
// When the channel_id() or use_global_device_ids() is set, device ids in
|
// partition count, or their product. In some cases, replica and/or
|
||||||
// ReplicaGroup config no longer only mean replica ids. So we skip the check
|
// partition count is not set in the HloModule config and has a default
|
||||||
// on the replica count.
|
// value of 1. For those cases, skip this part of the verification.
|
||||||
if (use_global_device_ids) {
|
|
||||||
if (hlo->replica_groups().empty()) {
|
|
||||||
return InternalError(
|
|
||||||
"Replica group must be specified when use_global_device_ids is true");
|
|
||||||
}
|
|
||||||
// No need to check replica_count.
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
|
|
||||||
if (channel_instr->channel_id()) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 replica_count = hlo->GetModule()->config().replica_count();
|
int64 replica_count = hlo->GetModule()->config().replica_count();
|
||||||
if (replica_count != 1 && !replicas_seen.empty() &&
|
int64 num_partitions = hlo->GetModule()->config().num_partitions();
|
||||||
replicas_seen.size() != replica_count) {
|
switch (group_mode) {
|
||||||
return InternalError(
|
case CollectiveOpGroupMode::kCrossReplica:
|
||||||
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
|
case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
|
||||||
"contains %d replicas: %s",
|
TF_RET_CHECK(replica_count == 1 || n == replica_count)
|
||||||
replica_count, replicas_seen.size(), hlo->ToString());
|
<< "In " << CollectiveOpGroupModeToString(group_mode)
|
||||||
|
<< " mode, replica groups should contain " << replica_count
|
||||||
|
<< " replicas, but found " << n << ": " << hlo->ToString();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CollectiveOpGroupMode::kCrossPartition: {
|
||||||
|
TF_RET_CHECK(num_partitions == 1 || n == num_partitions)
|
||||||
|
<< "In " << CollectiveOpGroupModeToString(group_mode)
|
||||||
|
<< " mode, replica groups should contain " << num_partitions
|
||||||
|
<< " partitions, but found " << n << ": " << hlo->ToString();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CollectiveOpGroupMode::kFlattenedID: {
|
||||||
|
const int64 num_flattened_ids = replica_count * num_partitions;
|
||||||
|
TF_RET_CHECK(num_flattened_ids == 1 || n == num_flattened_ids)
|
||||||
|
<< "In " << CollectiveOpGroupModeToString(group_mode)
|
||||||
|
<< " mode, replica groups should contain " << num_flattened_ids
|
||||||
|
<< " flattened IDs, but found " << n << ": " << hlo->ToString();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (uniform_replica_group_size) {
|
||||||
|
int64 size = hlo->replica_groups()[0].replica_ids_size();
|
||||||
|
for (const ReplicaGroup& g : hlo->replica_groups()) {
|
||||||
|
TF_RET_CHECK(size == g.replica_ids_size())
|
||||||
|
<< "Replica groups expected to be of uniform size";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TF_RET_CHECK(group_mode != CollectiveOpGroupMode::kFlattenedID)
|
||||||
|
<< "Replica groups must be specified in flattened-id mode";
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -249,7 +271,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo,
|
|||||||
|
|
||||||
Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
||||||
auto ag = Cast<HloAllGatherInstruction>(hlo);
|
auto ag = Cast<HloAllGatherInstruction>(hlo);
|
||||||
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
|
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
|
||||||
|
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
|
||||||
|
ag->use_global_device_ids()));
|
||||||
|
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
|
||||||
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
|
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
|
||||||
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
|
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
|
||||||
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
|
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
|
||||||
@ -257,21 +282,36 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
|||||||
int64 shard_count = CeilOfRatio(
|
int64 shard_count = CeilOfRatio(
|
||||||
ag->shape().dimensions(ag->all_gather_dimension()),
|
ag->shape().dimensions(ag->all_gather_dimension()),
|
||||||
ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
|
ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
|
||||||
if (ag->channel_id().has_value()) {
|
const HloModuleConfig& config = hlo->GetModule()->config();
|
||||||
if (ag->use_global_device_ids()) {
|
// empty replica groups imply all replicas form a single group.
|
||||||
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
|
int64 replica_subgroup_size =
|
||||||
} else {
|
ag->replica_groups().empty() ? config.replica_count()
|
||||||
if (ag->replica_groups().empty() ||
|
: ag->replica_groups()[0].replica_ids_size();
|
||||||
ag->replica_groups()[0].replica_ids_size() != 1) {
|
|
||||||
|
auto get_subgroup_size = [&]() -> StatusOr<int64> {
|
||||||
|
switch (group_mode) {
|
||||||
|
case CollectiveOpGroupMode::kCrossReplica:
|
||||||
|
case CollectiveOpGroupMode::kFlattenedID:
|
||||||
|
return replica_subgroup_size;
|
||||||
|
|
||||||
|
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
|
||||||
|
// Replicas from all partitions participate.
|
||||||
|
return replica_subgroup_size * config.num_partitions();
|
||||||
|
|
||||||
|
case CollectiveOpGroupMode::kCrossPartition:
|
||||||
return InternalError(
|
return InternalError(
|
||||||
"Replica group size must be 1 when use_global_device_ids is "
|
"kCrossPartition group mode not expected for all-gather");
|
||||||
"false if the all-gather is also cross-partition");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (!ag->replica_groups().empty()) {
|
|
||||||
// Cross-replica all-gather: shard count is subgroup size.
|
|
||||||
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// If replica and partition count is not explicitly set, it will have a
|
||||||
|
// default value of 1, in which case the subgroup_size will be 1 as well. Skip
|
||||||
|
// these verification checks in that case.
|
||||||
|
TF_ASSIGN_OR_RETURN(int64 subgroup_size, get_subgroup_size());
|
||||||
|
TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
|
||||||
|
<< "shard_count = " << shard_count
|
||||||
|
<< ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();
|
||||||
|
|
||||||
return CheckShape(ag, ShapeInference::InferAllGatherShape(
|
return CheckShape(ag, ShapeInference::InferAllGatherShape(
|
||||||
ag->operand(0)->shape(), ag->all_gather_dimension(),
|
ag->operand(0)->shape(), ag->all_gather_dimension(),
|
||||||
shard_count));
|
shard_count));
|
||||||
@ -279,7 +319,11 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
|||||||
|
|
||||||
Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
|
Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
|
||||||
auto ar = Cast<HloAllReduceInstruction>(hlo);
|
auto ar = Cast<HloAllReduceInstruction>(hlo);
|
||||||
TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
|
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
|
||||||
|
GetCollectiveOpGroupMode(ar->channel_id().has_value(),
|
||||||
|
ar->use_global_device_ids()));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));
|
||||||
|
|
||||||
std::vector<const Shape*> operand_shapes;
|
std::vector<const Shape*> operand_shapes;
|
||||||
for (const HloInstruction* operand : hlo->operands()) {
|
for (const HloInstruction* operand : hlo->operands()) {
|
||||||
@ -290,7 +334,11 @@ Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
|
|||||||
|
|
||||||
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
|
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
|
||||||
auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
|
auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
|
||||||
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
|
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
|
||||||
|
GetCollectiveOpGroupMode(
|
||||||
|
all_to_all->channel_id().has_value(), absl::nullopt));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));
|
||||||
|
|
||||||
TF_RET_CHECK(all_to_all != nullptr);
|
TF_RET_CHECK(all_to_all != nullptr);
|
||||||
if (all_to_all->split_dimension()) {
|
if (all_to_all->split_dimension()) {
|
||||||
@ -300,21 +348,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The size of each replica group must be the same (the split count of the
|
// The size of each replica group must be the same (checked in
|
||||||
// operaion). In case the default replica group is used (empty replica group,
|
// CheckReplicaGroups). This is the split count of the operation. In case the
|
||||||
// must not be an array all-to-all, as checked above), infer from the number
|
// empty replica group is used (must not be an array all-to-all, as checked
|
||||||
// of operands.
|
// above), infer from the number of operands.
|
||||||
const int64 split_count = hlo->replica_groups().empty()
|
const int64 split_count = hlo->replica_groups().empty()
|
||||||
? hlo->operand_count()
|
? hlo->operand_count()
|
||||||
: hlo->replica_groups()[0].replica_ids_size();
|
: hlo->replica_groups()[0].replica_ids_size();
|
||||||
for (const ReplicaGroup& g : hlo->replica_groups()) {
|
|
||||||
if (g.replica_ids_size() != split_count) {
|
|
||||||
return InternalError(
|
|
||||||
"Replica group has size %d, but all replica groups in an all-to-all "
|
|
||||||
"must have size N: %s",
|
|
||||||
g.replica_ids_size(), hlo->ToString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (all_to_all->split_dimension()) {
|
if (all_to_all->split_dimension()) {
|
||||||
TF_RET_CHECK(hlo->operand_count() == 1);
|
TF_RET_CHECK(hlo->operand_count() == 1);
|
||||||
|
@ -886,9 +886,29 @@ int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
|
|||||||
return replica_count;
|
return replica_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
|
||||||
|
std::vector<std::vector<int64>> replica_groups,
|
||||||
|
absl::optional<int64> replica_count, absl::optional<int64> num_partitions,
|
||||||
|
absl::string_view other_attributes, absl::string_view template_str) {
|
||||||
|
HloModuleConfig config;
|
||||||
|
config.set_replica_count(
|
||||||
|
replica_count.value_or(ReplicaCount(replica_groups)));
|
||||||
|
config.set_num_partitions(num_partitions.value_or(1));
|
||||||
|
return ParseAndReturnUnverifiedModule(
|
||||||
|
absl::StrReplaceAll(
|
||||||
|
template_str,
|
||||||
|
{{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
|
||||||
|
{"OTHER_ATTRIBUTES", other_attributes.empty()
|
||||||
|
? ""
|
||||||
|
: absl::StrCat(",", other_attributes)}}),
|
||||||
|
config);
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
||||||
std::vector<std::vector<int64>> replica_groups,
|
std::vector<std::vector<int64>> replica_groups,
|
||||||
absl::optional<int64> replica_count = absl::nullopt) {
|
absl::optional<int64> replica_count = absl::nullopt,
|
||||||
|
absl::optional<int64> num_partitions = absl::nullopt,
|
||||||
|
absl::string_view other_attributes = "") {
|
||||||
const char* kTemplate = R"(
|
const char* kTemplate = R"(
|
||||||
HloModule test
|
HloModule test
|
||||||
add {
|
add {
|
||||||
@ -899,18 +919,11 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
|
|||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
p = f32[128]{0} parameter(0)
|
p = f32[128]{0} parameter(0)
|
||||||
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
|
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
|
||||||
|
OTHER_ATTRIBUTES
|
||||||
})";
|
})";
|
||||||
|
return MakeCollectiveCommOpComputation(replica_groups, replica_count,
|
||||||
HloModuleConfig config;
|
num_partitions, other_attributes,
|
||||||
if (replica_count) {
|
kTemplate);
|
||||||
config.set_replica_count(*replica_count);
|
|
||||||
} else {
|
|
||||||
config.set_replica_count(ReplicaCount(replica_groups));
|
|
||||||
}
|
|
||||||
return ParseAndReturnUnverifiedModule(
|
|
||||||
absl::StrReplaceAll(
|
|
||||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
|
||||||
config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
|
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
|
||||||
@ -947,33 +960,70 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
|
|||||||
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
|
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8));
|
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8));
|
||||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
HasSubstr("Replica count in HloModuleConfig is 8, but "
|
HasSubstr("In kCrossReplica mode, replica groups should contain "
|
||||||
"ReplicaGroup config contains 2 replicas"));
|
"8 replicas, but found 2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
|
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
|
||||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
HasSubstr("Replica count in HloModuleConfig is 2, but "
|
HasSubstr("In kCrossReplica mode, replica groups should contain "
|
||||||
"ReplicaGroup config contains 4 replicas"));
|
"2 replicas, but found 4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
|
||||||
|
EXPECT_THAT(
|
||||||
|
verifier().Run(module.get()).status().error_message(),
|
||||||
|
HasSubstr(
|
||||||
|
"In kCrossReplicaAndPartition mode, replica groups should contain "
|
||||||
|
"2 replicas, but found 4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
|
||||||
|
TF_ASSERT_OK(verifier().Run(module.get()).status());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
|
||||||
|
"channel_id=1, use_global_device_ids=true"));
|
||||||
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
|
HasSubstr("In kFlattenedID mode, replica groups should contain "
|
||||||
|
"2 flattened IDs, but found 4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
|
||||||
|
"channel_id=1, use_global_device_ids=true"));
|
||||||
|
TF_ASSERT_OK(verifier().Run(module.get()).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
|
||||||
std::vector<std::vector<int64>> replica_groups) {
|
std::vector<std::vector<int64>> replica_groups,
|
||||||
|
absl::optional<int64> replica_count = absl::nullopt,
|
||||||
|
absl::optional<int64> num_partitions = absl::nullopt,
|
||||||
|
absl::string_view other_attributes = "") {
|
||||||
const char* kTemplate = R"(
|
const char* kTemplate = R"(
|
||||||
HloModule test
|
HloModule test
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
p0 = f32[128]{0} parameter(0)
|
p0 = f32[128]{0} parameter(0)
|
||||||
p1 = f32[128]{0} parameter(1)
|
p1 = f32[128]{0} parameter(1)
|
||||||
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
|
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
|
||||||
|
OTHER_ATTRIBUTES
|
||||||
})";
|
})";
|
||||||
HloModuleConfig config;
|
return MakeCollectiveCommOpComputation(replica_groups, replica_count,
|
||||||
config.set_replica_count(ReplicaCount(replica_groups));
|
num_partitions, other_attributes,
|
||||||
return ParseAndReturnUnverifiedModule(
|
kTemplate);
|
||||||
absl::StrReplaceAll(
|
|
||||||
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
|
|
||||||
config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
||||||
@ -984,7 +1034,7 @@ TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
|
|||||||
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
|
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation({{0, 1}, {}}));
|
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation({{0, 1}, {}}));
|
||||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
HasSubstr("empty replica group"));
|
HasSubstr("cannot have an empty replica group"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
|
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
|
||||||
@ -1001,11 +1051,27 @@ TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
|
|||||||
HasSubstr("Replica 4 is not named"));
|
HasSubstr("Replica 4 is not named"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
|
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
|
MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
|
||||||
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
HasSubstr("Replica group has size 1"));
|
HasSubstr("Replica groups expected to be of uniform size"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
|
||||||
|
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
|
||||||
|
HasSubstr("In kCrossPartition mode, replica groups should "
|
||||||
|
"contain 2 partitions, but found 4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto module,
|
||||||
|
MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
|
||||||
|
TF_ASSERT_OK(verifier().Run(module.get()).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
|
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
|
||||||
@ -1316,6 +1382,30 @@ TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
|
|||||||
ROOT add = f32[] add(lhs, rhs)
|
ROOT add = f32[] add(lhs, rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ENTRY CRS {
|
||||||
|
input = f32[8]{0} parameter(0)
|
||||||
|
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
|
||||||
|
use_global_device_ids=true, to_apply=add
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnUnverifiedModule(hlo_string));
|
||||||
|
|
||||||
|
auto status = verifier().Run(module.get()).status();
|
||||||
|
ASSERT_FALSE(status.ok());
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.error_message(),
|
||||||
|
HasSubstr("Replica groups must be specified in flattened-id mode"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
|
||||||
|
const char* const hlo_string = R"(
|
||||||
|
HloModule Module
|
||||||
|
add {
|
||||||
|
lhs = f32[] parameter(0)
|
||||||
|
rhs = f32[] parameter(1)
|
||||||
|
ROOT add = f32[] add(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
ENTRY CRS {
|
ENTRY CRS {
|
||||||
input = f32[8]{0} parameter(0)
|
input = f32[8]{0} parameter(0)
|
||||||
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
|
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
|
||||||
@ -1326,9 +1416,10 @@ TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
|
|||||||
|
|
||||||
auto status = verifier().Run(module.get()).status();
|
auto status = verifier().Run(module.get()).status();
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.error_message(),
|
EXPECT_THAT(
|
||||||
HasSubstr("Replica group must be specified when "
|
status.error_message(),
|
||||||
"use_global_device_ids is true"));
|
HasSubstr(
|
||||||
|
"Invalid combination of has_channel_id and use_global_device_ids"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -112,9 +112,9 @@ HloModule module
|
|||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
param0 = s32[1,8]{1,0} parameter(0)
|
param0 = s32[1,8]{1,0} parameter(0)
|
||||||
ag1 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}}, dimensions={1},
|
ag1 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}}, dimensions={1},
|
||||||
channel_id=0, use_global_device_ids=true
|
channel_id=0
|
||||||
ag2 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}},
|
ag2 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}},
|
||||||
dimensions={1}, use_global_device_ids=true
|
dimensions={1}
|
||||||
ROOT tuple = (s32[1,16]{1,0}, s32[1,16]{1,0}) tuple(ag1, ag2)
|
ROOT tuple = (s32[1,16]{1,0}, s32[1,16]{1,0}) tuple(ag1, ag2)
|
||||||
})";
|
})";
|
||||||
auto module_status = RunPass(hlo_string);
|
auto module_status = RunPass(hlo_string);
|
||||||
|
@ -126,9 +126,10 @@ std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
|
|||||||
|
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
||||||
int64 replica_count) {
|
int64 replica_count,
|
||||||
return ParseAndReturnVerifiedModule(hlo_text,
|
int64 num_partitions) {
|
||||||
GetModuleConfigForTest(replica_count));
|
return ParseAndReturnVerifiedModule(
|
||||||
|
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
StatusOr<std::unique_ptr<VerifiedHloModule>>
|
||||||
|
@ -89,7 +89,8 @@ class HloTestBase : public ManifestCheckingTest {
|
|||||||
|
|
||||||
// Parses the given string and returns module as a VerifiedHloModule.
|
// Parses the given string and returns module as a VerifiedHloModule.
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
absl::string_view hlo_text, int64 replica_count = 1);
|
absl::string_view hlo_text, int64 replica_count = 1,
|
||||||
|
int64 num_partitions = 1);
|
||||||
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
|
||||||
absl::string_view hlo_text, const HloModuleConfig& config);
|
absl::string_view hlo_text, const HloModuleConfig& config);
|
||||||
|
|
||||||
@ -135,10 +136,12 @@ class HloTestBase : public ManifestCheckingTest {
|
|||||||
virtual DebugOptions GetDebugOptionsForTest();
|
virtual DebugOptions GetDebugOptionsForTest();
|
||||||
|
|
||||||
// Gets an HloModuleConfig with options appropriate for tests.
|
// Gets an HloModuleConfig with options appropriate for tests.
|
||||||
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
|
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1,
|
||||||
|
int64 num_partitions = 1) {
|
||||||
HloModuleConfig config;
|
HloModuleConfig config;
|
||||||
config.set_debug_options(GetDebugOptionsForTest());
|
config.set_debug_options(GetDebugOptionsForTest());
|
||||||
config.set_replica_count(replica_count);
|
config.set_replica_count(replica_count);
|
||||||
|
config.set_num_partitions(num_partitions);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user