[XLA:SPMD] Fix vector indices sharding in GatherIndexSharding.
PiperOrigin-RevId: 326781481 Change-Id: I95945e894c3ecac0d74fa0a4f51570d70b43ebf3
This commit is contained in:
parent
9e416e14c0
commit
71ff216576
@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
|
|||||||
}
|
}
|
||||||
|
|
||||||
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
|
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* hlo) {
|
||||||
const int64 index_rank) {
|
CHECK(hlo->opcode() == HloOpcode::kGather);
|
||||||
if (output_sharding.IsTileMaximal()) {
|
if (output_sharding.IsTileMaximal()) {
|
||||||
return output_sharding;
|
return output_sharding;
|
||||||
}
|
}
|
||||||
@ -371,10 +371,12 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
|
|||||||
output_sharding.tile_assignment().dim(i));
|
output_sharding.tile_assignment().dim(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
int64 index_rank = hlo->operand(1)->shape().rank();
|
||||||
|
|
||||||
// Vector indices sharding is not supported yet.
|
// Vector indices sharding is not supported yet.
|
||||||
for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
|
if (index_rank > index_tile_assignment_dims.size()) {
|
||||||
index_tile_assignment_dims.push_back(1);
|
index_tile_assignment_dims.insert(
|
||||||
|
index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
|
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
|
||||||
|
@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
|
|||||||
// Returns the preferred index sharding for a gather op based on the sharding
|
// Returns the preferred index sharding for a gather op based on the sharding
|
||||||
// of the output.
|
// of the output.
|
||||||
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
|
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
|
||||||
const HloInstruction* hlo,
|
const HloInstruction* hlo);
|
||||||
const int64 index_rank);
|
|
||||||
|
|
||||||
// Returns a new HloSharding for a gather op so that only non offset dimensions
|
// Returns a new HloSharding for a gather op so that only non offset dimensions
|
||||||
// are sharded. Assume "result" is returned by this function. It is ensured that
|
// are sharded. Assume "result" is returned by this function. It is ensured that
|
||||||
|
@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
|
|||||||
}
|
}
|
||||||
case HloOpcode::kGather: {
|
case HloOpcode::kGather: {
|
||||||
if (&instruction == user.operand(1)) {
|
if (&instruction == user.operand(1)) {
|
||||||
return hlo_sharding_util::GatherIndexSharding(
|
return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
|
||||||
user.sharding(), &user, instruction.shape().rank());
|
|
||||||
}
|
}
|
||||||
if (is_spmd) {
|
if (is_spmd) {
|
||||||
return hlo_sharding_util::GatherDataOperandShardingFromOutput(
|
return hlo_sharding_util::GatherDataOperandShardingFromOutput(
|
||||||
|
@ -1839,6 +1839,29 @@ ENTRY entry {
|
|||||||
op::Sharding("{devices=[1,2,1]0,1}"));
|
op::Sharding("{devices=[1,2,1]0,1}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ShardingPropagationTest, GatherToIndex3) {
|
||||||
|
const char* hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
%input = bf16[2,4819,4] parameter(0), sharding={replicated}
|
||||||
|
%p1 = s32[2,2,1000] parameter(1)
|
||||||
|
%indices = s32[2,2,1000] copy(%p1)
|
||||||
|
ROOT %gather = bf16[2,1000,4]
|
||||||
|
gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
|
||||||
|
offset_dims={2}, collapsed_slice_dims={0,1},
|
||||||
|
start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
|
||||||
|
sharding={devices=[1,2,1]0,1}
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||||
|
ShardingPropagation().Run(module.get()));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
EXPECT_THAT(FindInstruction(module.get(), "indices"),
|
||||||
|
op::Sharding("{devices=[1,1,2]0,1}"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ShardingPropagationTest, GatherToDataOperand) {
|
TEST_F(ShardingPropagationTest, GatherToDataOperand) {
|
||||||
const char* hlo_string = R"(
|
const char* hlo_string = R"(
|
||||||
HloModule module
|
HloModule module
|
||||||
|
Loading…
Reference in New Issue
Block a user