[XLA:SPMD] Fix vector indices sharding in GatherIndexSharding.
PiperOrigin-RevId: 326781481
Change-Id: I95945e894c3ecac0d74fa0a4f51570d70b43ebf3
parent 9e416e14c0
commit 71ff216576
tensorflow/compiler/xla/service
@@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank) {
+                                const HloInstruction* hlo) {
   CHECK(hlo->opcode() == HloOpcode::kGather);
   if (output_sharding.IsTileMaximal()) {
     return output_sharding;
   }
@@ -371,10 +371,12 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
           output_sharding.tile_assignment().dim(i));
     }
   }
+  int64 index_rank = hlo->operand(1)->shape().rank();
+
   // Vector indices sharding is not supported yet.
-  for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
-    index_tile_assignment_dims.push_back(1);
+  if (index_rank > index_tile_assignment_dims.size()) {
+    index_tile_assignment_dims.insert(
+        index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
   }
 
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
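For orientation, the sketch below illustrates, outside of XLA, what the fixed index tile-assignment construction computes. It is a minimal, hypothetical example using plain std::vector instead of XLA's Array<int64>; the concrete values ({1, 2} as the mapped output tiles, index_vector_dim = 1, index rank 3) are borrowed from the GatherToIndex3 test added further down, and the variable names mirror the diff only for readability.

// Standalone illustration (not XLA code) of the fixed behavior: a single
// tile dimension of size 1 is inserted at index_vector_dim, rather than
// padding with 1s at the end as the removed loop did.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Tiles mapped from the gather output sharding {devices=[1,2,1]0,1} onto
  // the non-vector index dimensions (values borrowed from GatherToIndex3).
  std::vector<int64_t> index_tile_assignment_dims = {1, 2};
  const int64_t index_vector_dim = 1;  // gather attribute in the test below
  const int64_t index_rank = 3;        // indices have shape s32[2,2,1000]

  if (index_rank > static_cast<int64_t>(index_tile_assignment_dims.size())) {
    index_tile_assignment_dims.insert(
        index_tile_assignment_dims.begin() + index_vector_dim, 1);
  }

  for (int64_t d : index_tile_assignment_dims) std::cout << d << ' ';
  std::cout << '\n';  // prints "1 1 2", matching {devices=[1,1,2]0,1}
  return 0;
}

The design change visible in the diff is that index_rank is now computed inside GatherIndexSharding from hlo->operand(1), so callers no longer pass it in; the header and call-site hunks below follow from that.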
@@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 // Returns the preferred index sharding for a gather op based on the sharding
 // of the output.
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank);
+                                const HloInstruction* hlo);
 
 // Returns a new HloSharding for a gather op so that only non offset dimensions
 // are sharded. Assume "result" is returned by this function. It is ensured that
@@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
     }
     case HloOpcode::kGather: {
       if (&instruction == user.operand(1)) {
-        return hlo_sharding_util::GatherIndexSharding(
-            user.sharding(), &user, instruction.shape().rank());
+        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
       }
       if (is_spmd) {
         return hlo_sharding_util::GatherDataOperandShardingFromOutput(
@@ -1839,6 +1839,29 @@ ENTRY entry {
               op::Sharding("{devices=[1,2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex3) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = bf16[2,4819,4] parameter(0), sharding={replicated}
+  %p1 = s32[2,2,1000] parameter(1)
+  %indices = s32[2,2,1000] copy(%p1)
+  ROOT %gather = bf16[2,1000,4]
+    gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
+    offset_dims={2}, collapsed_slice_dims={0,1},
+    start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
+    sharding={devices=[1,2,1]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[1,1,2]0,1}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToDataOperand) {
   const char* hlo_string = R"(
 HloModule module
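As a rough sanity check on the new test's expectation (this is a sketch of the reasoning, not text from the commit): the gather output bf16[2,1000,4] carries sharding {devices=[1,2,1]0,1}, and with offset_dims={2} its dimensions 0 and 1 are batch dimensions corresponding to index dimensions 0 and 2 of s32[2,2,1000], since dimension 1 is the index_vector_dim. Mapping the output tiles onto those index dimensions gives tile dims [1, 2]; the fixed code then inserts a 1 at index_vector_dim = 1, producing [1, 1, 2], which matches the asserted sharding {devices=[1,1,2]0,1} on %indices.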