diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
index 2aa2087b2f3..e1e506b2892 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -357,8 +357,8 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 }
 
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank) {
+                                const HloInstruction* hlo) {
+  CHECK(hlo->opcode() == HloOpcode::kGather);
   if (output_sharding.IsTileMaximal()) {
     return output_sharding;
   }
@@ -371,10 +371,14 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
           output_sharding.tile_assignment().dim(i));
     }
   }
+  int64 index_rank = hlo->operand(1)->shape().rank();
 
   // Vector indices sharding is not supported yet.
-  for (int64 i = dnums.index_vector_dim(); i < index_rank; ++i) {
-    index_tile_assignment_dims.push_back(1);
+  // The loop above yields index_rank - 1 dims exactly when the indices have
+  // an explicit vector dimension; keep it replicated by inserting a 1 there.
+  if (index_rank > index_tile_assignment_dims.size()) {
+    index_tile_assignment_dims.insert(
+        index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
   }
 
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
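
Aside, not part of the patch: a minimal standalone sketch (plain C++ with
illustrative names, not the XLA sources) of the bookkeeping this hunk fixes.
The old loop appended one 1 for every dimension from index_vector_dim to the
end of the indices shape, which is only correct when the vector dimension is
the last one; the fix inserts a single 1 at index_vector_dim, where the vector
dimension actually lives.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Builds the tile-assignment dims for the gather indices from the output
// tiling, mirroring the patched logic above.
std::vector<int64_t> IndexTileDims(const std::vector<int64_t>& output_tile,
                                   const std::vector<int64_t>& offset_dims,
                                   int64_t index_rank,
                                   int64_t index_vector_dim) {
  std::vector<int64_t> dims;
  for (int64_t i = 0; i < static_cast<int64_t>(output_tile.size()); ++i) {
    // Non-offset output dims are the batch dims shared with the indices.
    if (std::find(offset_dims.begin(), offset_dims.end(), i) ==
        offset_dims.end()) {
      dims.push_back(output_tile[i]);
    }
  }
  // One dim is missing iff the indices have an explicit vector dimension;
  // keep that dimension replicated by inserting a tile size of 1 there.
  if (index_rank > static_cast<int64_t>(dims.size())) {
    dims.insert(dims.begin() + index_vector_dim, 1);
  }
  return dims;
}

int main() {
  // Shapes from the GatherToIndex3 test added below: output bf16[2,1000,4]
  // tiled devices=[1,2,1], offset_dims={2}, indices s32[2,2,1000],
  // index_vector_dim=1.
  assert((IndexTileDims({1, 2, 1}, {2}, /*index_rank=*/3,
                        /*index_vector_dim=*/1) ==
          std::vector<int64_t>{1, 1, 2}));  // i.e. devices=[1,1,2]
  return 0;
}
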
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h
index 95465be24b3..0de01fcab7e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h
@@ -95,8 +95,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding,
 // Returns the preferred index sharding for a gather op based on the sharding
 // of the output.
 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
-                                const HloInstruction* hlo,
-                                const int64 index_rank);
+                                const HloInstruction* hlo);
 
 // Returns a new HloSharding for a gather op so that only non offset dimensions
 // are sharded. Assume "result" is returned by this function. It is ensured that
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 408fdfb7612..7aae3e4dc0e 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -1319,8 +1319,7 @@ absl::optional<HloSharding> GetShardingFromUser(
     }
     case HloOpcode::kGather: {
       if (&instruction == user.operand(1)) {
-        return hlo_sharding_util::GatherIndexSharding(
-            user.sharding(), &user, instruction.shape().rank());
+        return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
       }
       if (is_spmd) {
         return hlo_sharding_util::GatherDataOperandShardingFromOutput(
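
Aside: dropping the rank argument is safe at this call site because the
enclosing branch has already established &instruction == user.operand(1), so
the instruction.shape().rank() the old code forwarded is exactly the
hlo->operand(1)->shape().rank() the callee now computes for itself. A
hypothetical assertion (not part of the patch) spelling out that invariant:

  DCHECK_EQ(instruction.shape().rank(),
            user.operand(1)->shape().rank());
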
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index e41bda32537..03c77c2038c 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -1839,6 +1839,29 @@ ENTRY entry {
               op::Sharding("{devices=[1,2,1]0,1}"));
 }
 
+TEST_F(ShardingPropagationTest, GatherToIndex3) {
+  const char* hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %input = bf16[2,4819,4] parameter(0), sharding={replicated}
+  %p1 = s32[2,2,1000] parameter(1)
+  %indices = s32[2,2,1000] copy(%p1)
+  ROOT %gather = bf16[2,1000,4]
+    gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices),
+    offset_dims={2}, collapsed_slice_dims={0,1},
+    start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4},
+    sharding={devices=[1,2,1]0,1}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          ShardingPropagation().Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "indices"),
+              op::Sharding("{devices=[1,1,2]0,1}"));
+}
+
 TEST_F(ShardingPropagationTest, GatherToDataOperand) {
   const char* hlo_string = R"(
 HloModule module
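
Aside, a hand trace of the new GatherToIndex3 expectation under the patched
helper (for reference only, not part of the change): the gather output
bf16[2,1000,4] is tiled {devices=[1,2,1]}, and offset_dims={2}, so the
non-offset output dims 0 and 1 contribute tile dims {1, 2}; the indices
s32[2,2,1000] have rank 3 > 2, so a 1 is inserted at index_vector_dim=1,
yielding {1, 1, 2}, i.e. the asserted {devices=[1,1,2]0,1} on %indices. The
old append-based loop would instead have produced the rank-mismatched
{1, 2, 1, 1} for this rank-3 indices shape.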