From 71ff2165766a85de539ada60ab7c3f7e3dfc4dc6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 14 Aug 2020 21:54:17 -0700
Subject: [PATCH] [XLA:SPMD] Fix vector indices sharding in GatherIndexSharding.

PiperOrigin-RevId: 326781481
Change-Id: I95945e894c3ecac0d74fa0a4f51570d70b43ebf3
---
 .../compiler/xla/service/hlo_sharding_util.cc | 10 ++++----
 .../compiler/xla/service/hlo_sharding_util.h  |  3 +--
 .../xla/service/sharding_propagation.cc       |  3 +--
 .../xla/service/sharding_propagation_test.cc  | 23 +++++++++++++++++++
 4 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index 2aa2087b2f3..e1e506b2892 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank) {
+                                const HloInstruction* hlo) {
+  CHECK(hlo->opcode() == HloOpcode::kGather);
   if (output_sharding.IsTileMaximal()) {
     return output_sharding;
   }
@@ -371,10 +371,12 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
           output_sharding.tile_assignment().dim(i));
     }
   }
+  int64 index_rank = hlo->operand(1)->shape().rank();
 
   // Vector indices sharding is not supported yet.
-  for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
-    index_tile_assignment_dims.push_back(1);
+  if (index_rank > index_tile_assignment_dims.size()) {
+    index_tile_assignment_dims.insert(
+        index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
   }
 
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h
index 95465be24b3..0de01fcab7e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h
@@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 // Returns the preferred index sharding for a gather op based on the sharding
 // of the output.
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank);
+                                const HloInstruction* hlo);
 
 // Returns a new HloSharding for a gather op so that only non offset dimensions
 // are sharded. Assume "result" is returned by this function. It is ensured that
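Reviewer note (illustration only, not part of the patch): the behavioral change in GatherIndexSharding is easiest to see in isolation. The old code appended a 1 for every dimension from index_vector_dim up to the index rank, which both over-pads the tile assignment and leaves the padding at the end instead of at the unshardable index-vector dimension; the new code inserts exactly one 1 at index_vector_dim. A minimal standalone sketch with plain std::vector (the helper name is hypothetical):

    #include <cstdint>
    #include <vector>

    // Pads the tile-assignment dims derived from the gather output sharding
    // so they line up with the indices operand's shape. The index-vector
    // dimension cannot be sharded yet, so it always gets a tile count of 1.
    std::vector<int64_t> PadIndexTileDims(std::vector<int64_t> dims,
                                          int64_t index_vector_dim,
                                          int64_t index_rank) {
      if (index_rank > static_cast<int64_t>(dims.size())) {
        // Insert the 1 at index_vector_dim rather than appending it, so the
        // remaining tile counts keep their alignment with the indices dims.
        dims.insert(dims.begin() + index_vector_dim, 1);
      }
      return dims;
    }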
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 408fdfb7612..7aae3e4dc0e 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
     }
     case HloOpcode::kGather: {
       if (&instruction == user.operand(1)) {
-        return hlo_sharding_util::GatherIndexSharding(
-            user.sharding(), &user, instruction.shape().rank());
+        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
       }
       if (is_spmd) {
         return hlo_sharding_util::GatherDataOperandShardingFromOutput(
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index e41bda32537..03c77c2038c 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -1839,6 +1839,29 @@ ENTRY entry {
               op::Sharding("{devices=[1,2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex3) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = bf16[2,4819,4] parameter(0), sharding={replicated}
+  %p1 = s32[2,2,1000] parameter(1)
+  %indices = s32[2,2,1000] copy(%p1)
+  ROOT %gather = bf16[2,1000,4]
+    gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
+    offset_dims={2}, collapsed_slice_dims={0,1},
+    start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
+    sharding={devices=[1,2,1]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[1,1,2]0,1}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToDataOperand) {
   const char* hlo_string = R"(
 HloModule module
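Reviewer note (illustration only, not part of the patch): the sharding the new GatherToIndex3 test expects can be checked by hand. A self-contained sketch of the dimension arithmetic:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // From the test: the gather output is bf16[2,1000,4] with sharding
      // {devices=[1,2,1]0,1}. Output dims 0 and 1 are not offset dims, so
      // they contribute tile counts {1, 2}; offset dim 2 is skipped. The
      // indices operand is s32[2,2,1000], i.e. rank 3, with
      // index_vector_dim=1.
      std::vector<int64_t> dims = {1, 2};
      const int64_t index_vector_dim = 1;
      const int64_t index_rank = 3;
      // The fixed GatherIndexSharding inserts a single 1 at index_vector_dim:
      if (index_rank > static_cast<int64_t>(dims.size())) {
        dims.insert(dims.begin() + index_vector_dim, 1);
      }
      // {1, 1, 2} is exactly the sharding the test asserts on %indices,
      // {devices=[1,1,2]0,1}. The old append-at-the-end loop would have
      // produced {1, 2, 1, 1} instead.
      assert((dims == std::vector<int64_t>{1, 1, 2}));
      return 0;
    }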