[XLA] Use CollectiveOpGroup mode to drive some of the HLO verification

- Use the group mode to check the size of non-empty replica groups for various
  XLA collective operations.
- In several cases, the replica_count and num_partitions is not initialized and has a
  default value of 1. Skip some of these checks in those cases to not break existing
  tests. This can be improved further by either making these optional in the HloModule
  config or using 0 for the default/uninitialized value.
- Fix several tests that hit verification failures due to this.

PiperOrigin-RevId: 359073636
Change-Id: I9c6bac147fec2bb2f407d0d1e83698a91de4d852
This commit is contained in:
Rahul Joshi 2021-02-23 10:08:04 -08:00 committed by TensorFlower Gardener
parent 8f803592a0
commit 53124e2193
9 changed files with 255 additions and 119 deletions

View File

@ -3787,6 +3787,7 @@ cc_library(
srcs = ["hlo_verifier.cc"],
hdrs = ["hlo_verifier.h"],
deps = [
":collective_ops_utils",
":hlo",
":hlo_casting_utils",
":hlo_pass",
@ -3799,6 +3800,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],

View File

@ -1809,7 +1809,8 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2,
/*num_partitions=*/4));
ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2,
/*spmd_partition=*/true);
auto changed = combiner.Run(module.get()).ValueOrDie();

View File

@ -107,7 +107,7 @@ absl::string_view CollectiveOpGroupModeToString(
case CollectiveOpGroupMode::kCrossPartition:
return "kCrossPartition";
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
return "kCrossReplicAndPartition";
return "kCrossReplicaAndPartition";
case CollectiveOpGroupMode::kFlattenedID:
return "kFlattenedID";
}

View File

@ -632,11 +632,9 @@ ENTRY entry {
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
module_str, /*replica_count=*/2));
auto config = module->config();
config.set_num_partitions(2);
module->set_config(config);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2,
/*num_partitions=*/2));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<HloReplicationAnalysis> replica_analysis,
HloReplicationAnalysis::Run(module.get(),

View File

@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include <set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -183,65 +183,87 @@ Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
// Checks that `hlo`'s set of ReplicaGroups:
//
// - names each replica 0 through n-1 exactly once, and
// - names each replica 0 through n-1 exactly once (where n is either number of
// replicas, or number of partitions, or their product)
// - does not contain any empty ReplicaGroups.
//
// Note that although none of the groups may be empty, `hlo` is allowed to have
// 0 groups. That just means it has one big group.
// empty groups when group mode is not kFlattenedID. That just means it has one
// big group.
//
// This is just a minimal set of checks; some instructions may have additional
// requirements. For example, all-to-all requires that all ReplicaGroups have
// the same number of replicas, but that isn't checked here.
// In general, if replica groups is not empty, all replica groups should be of
// the same size. The exception is all-reduce, where non-uniform replica groups
// are allowed. This is controlled by `uniform_replica_group_size`.
static Status CheckReplicaGroups(HloInstruction* hlo,
bool use_global_device_ids) {
std::set<int64> replicas_seen;
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids().empty()) {
return InternalError("Instruction cannot have an empty replica group: %s",
hlo->ToString());
}
for (int64 i : g.replica_ids()) {
if (!replicas_seen.insert(i).second) {
CollectiveOpGroupMode group_mode,
bool uniform_replica_group_size = true) {
if (!hlo->replica_groups().empty()) {
absl::flat_hash_set<int64> replicas_seen;
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids().empty()) {
return InternalError(
"Replica %d is repeated in instruction's replica-groups: %s", i,
"Instruction cannot have an empty replica group: %s",
hlo->ToString());
}
for (int64 i : g.replica_ids()) {
if (!replicas_seen.insert(i).second) {
return InternalError(
"Replica %d is repeated in instruction's replica-groups: %s", i,
hlo->ToString());
}
}
}
size_t n = replicas_seen.size();
for (int64 i = 0; i < n; ++i) {
if (!replicas_seen.count(i)) {
return InternalError(
"Replica %d is not named in instruction's replica-groups: %s", i,
hlo->ToString());
}
}
}
for (int64 i = 0; i < replicas_seen.size(); ++i) {
if (!replicas_seen.count(i)) {
return InternalError(
"Replica %d is not named in instruction's replica-groups: %s", i,
hlo->ToString());
}
}
// If use_global_device_ids() is set, replica_groups cannot be empty.
// When the channel_id() or use_global_device_ids() is set, device ids in
// ReplicaGroup config no longer only mean replica ids. So we skip the check
// on the replica count.
if (use_global_device_ids) {
if (hlo->replica_groups().empty()) {
return InternalError(
"Replica group must be specified when use_global_device_ids is true");
// replica-groups have numbers [0, n). This n should be either replica or
// partition count, or their product. In some cases, replica and/or
// partition count is not set in the HloModule config and has a default
// value of 1. For those cases, skip this part of the verification.
int64 replica_count = hlo->GetModule()->config().replica_count();
int64 num_partitions = hlo->GetModule()->config().num_partitions();
switch (group_mode) {
case CollectiveOpGroupMode::kCrossReplica:
case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
TF_RET_CHECK(replica_count == 1 || n == replica_count)
<< "In " << CollectiveOpGroupModeToString(group_mode)
<< " mode, replica groups should contain " << replica_count
<< " replicas, but found " << n << ": " << hlo->ToString();
break;
}
case CollectiveOpGroupMode::kCrossPartition: {
TF_RET_CHECK(num_partitions == 1 || n == num_partitions)
<< "In " << CollectiveOpGroupModeToString(group_mode)
<< " mode, replica groups should contain " << num_partitions
<< " partitions, but found " << n << ": " << hlo->ToString();
break;
}
case CollectiveOpGroupMode::kFlattenedID: {
const int64 num_flattened_ids = replica_count * num_partitions;
TF_RET_CHECK(num_flattened_ids == 1 || n == num_flattened_ids)
<< "In " << CollectiveOpGroupModeToString(group_mode)
<< " mode, replica groups should contain " << num_flattened_ids
<< " flattened IDs, but found " << n << ": " << hlo->ToString();
break;
}
}
// No need to check replica_count.
return Status::OK();
}
if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
if (channel_instr->channel_id()) {
return Status::OK();
if (uniform_replica_group_size) {
int64 size = hlo->replica_groups()[0].replica_ids_size();
for (const ReplicaGroup& g : hlo->replica_groups()) {
TF_RET_CHECK(size == g.replica_ids_size())
<< "Replica groups expected to be of uniform size";
}
}
}
int64 replica_count = hlo->GetModule()->config().replica_count();
if (replica_count != 1 && !replicas_seen.empty() &&
replicas_seen.size() != replica_count) {
return InternalError(
"Replica count in HloModuleConfig is %d, but ReplicaGroup config "
"contains %d replicas: %s",
replica_count, replicas_seen.size(), hlo->ToString());
} else {
TF_RET_CHECK(group_mode != CollectiveOpGroupMode::kFlattenedID)
<< "Replica groups must be specified in flattened-id mode";
}
return Status::OK();
@ -249,7 +271,10 @@ static Status CheckReplicaGroups(HloInstruction* hlo,
Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
auto ag = Cast<HloAllGatherInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ag->channel_id().has_value(),
ag->use_global_device_ids()));
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
@ -257,21 +282,36 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
int64 shard_count = CeilOfRatio(
ag->shape().dimensions(ag->all_gather_dimension()),
ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
if (ag->channel_id().has_value()) {
if (ag->use_global_device_ids()) {
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
} else {
if (ag->replica_groups().empty() ||
ag->replica_groups()[0].replica_ids_size() != 1) {
const HloModuleConfig& config = hlo->GetModule()->config();
// empty replica groups imply all replicas form a single group.
int64 replica_subgroup_size =
ag->replica_groups().empty() ? config.replica_count()
: ag->replica_groups()[0].replica_ids_size();
auto get_subgroup_size = [&]() -> StatusOr<int64> {
switch (group_mode) {
case CollectiveOpGroupMode::kCrossReplica:
case CollectiveOpGroupMode::kFlattenedID:
return replica_subgroup_size;
case CollectiveOpGroupMode::kCrossReplicaAndPartition:
// Replicas from all partitions participate.
return replica_subgroup_size * config.num_partitions();
case CollectiveOpGroupMode::kCrossPartition:
return InternalError(
"Replica group size must be 1 when use_global_device_ids is "
"false if the all-gather is also cross-partition");
}
"kCrossPartition group mode not expected for all-gather");
}
} else if (!ag->replica_groups().empty()) {
// Cross-replica all-gather: shard count is subgroup size.
TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
}
};
// If replica and partition count is not explicitly set, it will have a
// default value of 1, in which case the subgroup_size will be 1 as well. Skip
// these verification checks in that case.
TF_ASSIGN_OR_RETURN(int64 subgroup_size, get_subgroup_size());
TF_RET_CHECK(subgroup_size == 1 || shard_count == subgroup_size)
<< "shard_count = " << shard_count
<< ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();
return CheckShape(ag, ShapeInference::InferAllGatherShape(
ag->operand(0)->shape(), ag->all_gather_dimension(),
shard_count));
@ -279,7 +319,11 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
auto ar = Cast<HloAllReduceInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(ar->channel_id().has_value(),
ar->use_global_device_ids()));
TF_RETURN_IF_ERROR(
CheckReplicaGroups(ar, group_mode, /*uniform_replica_group_size=*/false));
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : hlo->operands()) {
@ -290,7 +334,11 @@ Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(
all_to_all->channel_id().has_value(), absl::nullopt));
TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode));
TF_RET_CHECK(all_to_all != nullptr);
if (all_to_all->split_dimension()) {
@ -300,21 +348,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
}
}
// The size of each replica group must be the same (the split count of the
// operaion). In case the default replica group is used (empty replica group,
// must not be an array all-to-all, as checked above), infer from the number
// of operands.
// The size of each replica group must be the same (checked in
// CheckReplicaGroups). This is the split count of the operation). In case the
// empty replica group is used must not be an array all-to-all, as checked
// above), infer from the number of operands.
const int64 split_count = hlo->replica_groups().empty()
? hlo->operand_count()
: hlo->replica_groups()[0].replica_ids_size();
for (const ReplicaGroup& g : hlo->replica_groups()) {
if (g.replica_ids_size() != split_count) {
return InternalError(
"Replica group has size %d, but all replica groups in an all-to-all "
"must have size N: %s",
g.replica_ids_size(), hlo->ToString());
}
}
if (all_to_all->split_dimension()) {
TF_RET_CHECK(hlo->operand_count() == 1);

View File

@ -886,9 +886,29 @@ int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
return replica_count;
}
StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
std::vector<std::vector<int64>> replica_groups,
absl::optional<int64> replica_count, absl::optional<int64> num_partitions,
absl::string_view other_attributes, absl::string_view template_str) {
HloModuleConfig config;
config.set_replica_count(
replica_count.value_or(ReplicaCount(replica_groups)));
config.set_num_partitions(num_partitions.value_or(1));
return ParseAndReturnUnverifiedModule(
absl::StrReplaceAll(
template_str,
{{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
{"OTHER_ATTRIBUTES", other_attributes.empty()
? ""
: absl::StrCat(",", other_attributes)}}),
config);
}
StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
std::vector<std::vector<int64>> replica_groups,
absl::optional<int64> replica_count = absl::nullopt) {
absl::optional<int64> replica_count = absl::nullopt,
absl::optional<int64> num_partitions = absl::nullopt,
absl::string_view other_attributes = "") {
const char* kTemplate = R"(
HloModule test
add {
@ -899,18 +919,11 @@ StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
ENTRY entry {
p = f32[128]{0} parameter(0)
crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
OTHER_ATTRIBUTES
})";
HloModuleConfig config;
if (replica_count) {
config.set_replica_count(*replica_count);
} else {
config.set_replica_count(ReplicaCount(replica_groups));
}
return ParseAndReturnUnverifiedModule(
absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
config);
return MakeCollectiveCommOpComputation(replica_groups, replica_count,
num_partitions, other_attributes,
kTemplate);
}
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
@ -947,33 +960,70 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("Replica count in HloModuleConfig is 8, but "
"ReplicaGroup config contains 2 replicas"));
HasSubstr("In kCrossReplica mode, replica groups should contain "
"8 replicas, but found 2"));
}
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
TF_ASSERT_OK_AND_ASSIGN(auto module,
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("Replica count in HloModuleConfig is 2, but "
"ReplicaGroup config contains 4 replicas"));
HasSubstr("In kCrossReplica mode, replica groups should contain "
"2 replicas, but found 4"));
}
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
EXPECT_THAT(
verifier().Run(module.get()).status().error_message(),
HasSubstr(
"In kCrossReplicaAndPartition mode, replica groups should contain "
"2 replicas, but found 4"));
}
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
TF_ASSERT_OK(verifier().Run(module.get()).status());
}
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
"channel_id=1, use_global_device_ids=true"));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("In kFlattenedID mode, replica groups should contain "
"2 flattened IDs, but found 4"));
}
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
"channel_id=1, use_global_device_ids=true"));
TF_ASSERT_OK(verifier().Run(module.get()).status());
}
StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
std::vector<std::vector<int64>> replica_groups) {
std::vector<std::vector<int64>> replica_groups,
absl::optional<int64> replica_count = absl::nullopt,
absl::optional<int64> num_partitions = absl::nullopt,
absl::string_view other_attributes = "") {
const char* kTemplate = R"(
HloModule test
ENTRY entry {
p0 = f32[128]{0} parameter(0)
p1 = f32[128]{0} parameter(1)
a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
OTHER_ATTRIBUTES
})";
HloModuleConfig config;
config.set_replica_count(ReplicaCount(replica_groups));
return ParseAndReturnUnverifiedModule(
absl::StrReplaceAll(
kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}),
config);
return MakeCollectiveCommOpComputation(replica_groups, replica_count,
num_partitions, other_attributes,
kTemplate);
}
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
@ -984,7 +1034,7 @@ TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation({{0, 1}, {}}));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("empty replica group"));
HasSubstr("cannot have an empty replica group"));
}
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
@ -1001,11 +1051,27 @@ TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
HasSubstr("Replica 4 is not named"));
}
TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
TF_ASSERT_OK_AND_ASSIGN(auto module,
MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("Replica group has size 1"));
HasSubstr("Replica groups expected to be of uniform size"));
}
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
HasSubstr("In kCrossPartition mode, replica groups should "
"contain 2 partitions, but found 4"));
}
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
TF_ASSERT_OK_AND_ASSIGN(
auto module,
MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
TF_ASSERT_OK(verifier().Run(module.get()).status());
}
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
@ -1316,6 +1382,30 @@ TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
use_global_device_ids=true, to_apply=add
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.error_message(),
HasSubstr("Replica groups must be specified in flattened-id mode"));
}
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
const char* const hlo_string = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
@ -1326,9 +1416,10 @@ TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
HasSubstr("Replica group must be specified when "
"use_global_device_ids is true"));
EXPECT_THAT(
status.error_message(),
HasSubstr(
"Invalid combination of has_channel_id and use_global_device_ids"));
}
} // namespace

View File

@ -112,9 +112,9 @@ HloModule module
ENTRY entry {
param0 = s32[1,8]{1,0} parameter(0)
ag1 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}}, dimensions={1},
channel_id=0, use_global_device_ids=true
channel_id=0
ag2 = s32[1,16]{1,0} all-gather(param0), replica_groups={{0,1}},
dimensions={1}, use_global_device_ids=true
dimensions={1}
ROOT tuple = (s32[1,16]{1,0}, s32[1,16]{1,0}) tuple(ag1, ag2)
})";
auto module_status = RunPass(hlo_string);

View File

@ -126,9 +126,10 @@ std::unique_ptr<VerifiedHloModule> HloTestBase::CreateNewVerifiedModule(
StatusOr<std::unique_ptr<VerifiedHloModule>>
HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
int64 replica_count) {
return ParseAndReturnVerifiedModule(hlo_text,
GetModuleConfigForTest(replica_count));
int64 replica_count,
int64_t num_partitions) {
return ParseAndReturnVerifiedModule(
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
}
StatusOr<std::unique_ptr<VerifiedHloModule>>

View File

@ -89,7 +89,8 @@ class HloTestBase : public ManifestCheckingTest {
// Parses the given string and returns module as a VerifiedHloModule.
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text, int64 replica_count = 1);
absl::string_view hlo_text, int64 replica_count = 1,
int64 num_partitions = 1);
StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config);
@ -135,10 +136,12 @@ class HloTestBase : public ManifestCheckingTest {
virtual DebugOptions GetDebugOptionsForTest();
// Gets an HloModuleConfig with options appropriate for tests.
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) {
HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1,
int64 num_partitions = 1) {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
config.set_replica_count(replica_count);
config.set_num_partitions(num_partitions);
return config;
}