diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index f765ee5ecc2..a346d8778d6 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -289,6 +289,15 @@ StatusOr<HloInstruction*> PartitionBaseCase(
           to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
               LiteralUtil::Zero(output_base_shape.element_type()))));
     }
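+    // Pad the operands with zeros when they are sharded along contracting
+    // dimensions, so that elements in the padded region do not contribute to
+    // the accumulated dot result.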
+    if (operands_sharded_at_contracting_dims) {
+      auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(output_base_shape.element_type())));
+      lhs = lhs.PadWithValue(zero);
+      rhs = rhs.PadWithValue(zero);
+    }
     auto result_buffer = CreateZero(padded_result_buffer_shape, b);
     auto iteration = b->AddInstruction(
         HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
@@ -333,57 +342,30 @@ StatusOr<HloInstruction*> PartitionBaseCase(
     if (windowed_at_contracting_dims || windowed_at_batch_dims ||
         operands_sharded_at_contracting_dims) {
       // Slice the matching operand according to the partitioned dimensions on
-      // the windowed operand.
+      // the windowed operand or the output.
       auto slice_operand = matching_operand == 0 ? l : r;
-      HloInstruction* slice;
+      // We do this by treating the matching operand as replicated, and
+      // resharding it to match the windowed operand or the output.
+      slice_operand->set_sharding(HloSharding::Replicate());
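+      // Reshard within the loop body, using data_partition_id so that each
+      // iteration selects the shard of data this partition should consume.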
+      auto state = lhs.state();
+      state.b = &body_b;
+      state.partition_id = data_partition_id;
+      const HloSharding* slice_sharding;
       if (operands_sharded_at_contracting_dims) {
-        CHECK_NE(output_sharding_dim, -1);
-        int64 output_sharding_dim_size =
-            o->shape().dimensions(output_sharding_dim);
-        int64 slice_dim = matching_operand == 0
-                              ? output_to_lhs_indices[output_sharding_dim]
-                              : output_to_rhs_indices[output_sharding_dim];
-        auto slice_shape = slice_operand->shape();
-        slice_shape.set_dimensions(slice_dim, output_sharding_dim_size);
-        std::vector<HloInstruction*> slice_offsets(slice_shape.rank());
-        for (int64 i = 0; i < slice_offsets.size(); ++i) {
-          if (i != slice_dim) {
-            slice_offsets[i] =
-                body_b.AddInstruction(HloInstruction::CreateConstant(
-                    LiteralUtil::CreateR0<uint32>(0)));
-          } else {
-            auto stride = body_b.AddInstruction(HloInstruction::CreateConstant(
-                LiteralUtil::CreateR0<uint32>(output_sharding_dim_size)));
-            slice_offsets[i] =
-                body_b.AddInstruction(HloInstruction::CreateBinary(
-                    data_partition_id->shape(), HloOpcode::kMultiply,
-                    data_partition_id, stride));
-          }
-        }
-        auto padded_shape = slice_operand->shape();
-        padded_shape.set_dimensions(
-            slice_dim,
-            o->shape().dimensions(output_sharding_dim) * num_partitions);
-        auto padded_slice_operand =
-            PadToShape(slice_operand, padded_shape, &body_b);
-        slice = body_b.AddInstruction(HloInstruction::CreateDynamicSlice(
-            slice_shape, padded_slice_operand, slice_offsets,
-            slice_shape.dimensions()));
+        slice_sharding = windowing_operand == 0
+                             ? &*output_sharding_transposed_to_match_rhs
+                             : &*output_sharding_transposed_to_match_lhs;
       } else {
-        // For windowed operand that partitioned along contracting dimensions,
-        // we do this by treating the matching operand as replicated, and
-        // resharding it to match the windowed operand.
-        slice_operand->set_sharding(HloSharding::Replicate());
-        auto state = lhs.state();
-        state.b = &body_b;
-        state.partition_id = data_partition_id;
-        slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
-                    .Reshard(windowing_operand == 0
-                                 ? *lhs_sharding_transposed_to_match_rhs
-                                 : *rhs_sharding_transposed_to_match_lhs)
-                    .hlo();
-        slice_operand->clear_sharding();
+        slice_sharding = windowing_operand == 0
+                             ? &*lhs_sharding_transposed_to_match_rhs
+                             : &*rhs_sharding_transposed_to_match_lhs;
       }
+      auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
+                       .Reshard(*slice_sharding)
+                       .hlo();
+      slice_operand->clear_sharding();
       if (matching_operand == 0) {
         dot_lhs = slice;
       } else {
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 91a0c44b51a..e4bd272e361 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3818,7 +3818,9 @@ ENTRY entry {
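+  // The slice offset now comes from resharding, which produces a reshape of
+  // the partition ordinal rather than an explicit multiply.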
   auto ds =
       AllOf(op::DynamicSlice(
                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
-                op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
+                op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
             op::Shape("f32[320,7,16,128]"));
   auto partial_output =
       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
@@ -3909,7 +3911,7 @@ ENTRY entry {
   auto ds =
       AllOf(op::DynamicSlice(
                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
-                op::Constant(), op::Multiply(), op::Constant()),
+                op::Constant(), op::Reshape(), op::Constant()),
             op::Shape("f32[4096,17,128]"));
   auto partial_output =
       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),