[XLA:SPMD] Fix vector indices sharding in GatherIndexSharding.

PiperOrigin-RevId: 326781481
Change-Id: I95945e894c3ecac0d74fa0a4f51570d70b43ebf3
This commit is contained in:
A. Unique TensorFlower 2020-08-14 21:54:17 -07:00 committed by TensorFlower Gardener
parent 9e416e14c0
commit 71ff216576
4 changed files with 31 additions and 8 deletions

View File

@@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
}
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
const HloInstruction* hlo,
const int64 index_rank) {
const HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kGather);
if (output_sharding.IsTileMaximal()) {
return output_sharding;
}
@@ -371,10 +371,12 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
output_sharding.tile_assignment().dim(i));
}
}
int64 index_rank = hlo->operand(1)->shape().rank();
// Vector indices sharding is not supported yet.
for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
index_tile_assignment_dims.push_back(1);
if (index_rank > index_tile_assignment_dims.size()) {
index_tile_assignment_dims.insert(
index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
}
Array<int64> new_tile_assignment = output_sharding.tile_assignment();

View File

@@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
const HloInstruction* hlo,
const int64 index_rank);
const HloInstruction* hlo);
// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured that

View File

@@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
}
case HloOpcode::kGather: {
if (&instruction == user.operand(1)) {
return hlo_sharding_util::GatherIndexSharding(
user.sharding(), &user, instruction.shape().rank());
return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
}
if (is_spmd) {
return hlo_sharding_util::GatherDataOperandShardingFromOutput(

View File

@@ -1839,6 +1839,29 @@ ENTRY entry {
op::Sharding("{devices=[1,2,1]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherToIndex3) {
const char* hlo_string = R"(
HloModule module
ENTRY entry {
%input = bf16[2,4819,4] parameter(0), sharding={replicated}
%p1 = s32[2,2,1000] parameter(1)
%indices = s32[2,2,1000] copy(%p1)
ROOT %gather = bf16[2,1000,4]
gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
offset_dims={2}, collapsed_slice_dims={0,1},
start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
sharding={devices=[1,2,1]0,1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
ShardingPropagation().Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "indices"),
op::Sharding("{devices=[1,1,2]0,1}"));
}
TEST_F(ShardingPropagationTest, GatherToDataOperand) {
const char* hlo_string = R"(
HloModule module