diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index c7ece50d202..49caa2ee347 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -739,11 +739,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       if (aggressiveness < 3) {
         return false;
       }
-      // Do not override existing tile sharding. This is likely from users.
-      if (IsSpatiallyPartitioned(instruction) &&
-          !instruction->sharding().IsTileMaximal()) {
-        return false;
-      }
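+      // An existing tile sharding on the broadcast may be merged with the
+      // sharding inferred from the operand below.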
       const HloInstruction* op = instruction->operand(0);
       if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
         return false;
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
index eed6de9b5b7..f91ecec8b9f 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -392,6 +392,39 @@ ENTRY %broadcast {
   }
 }
 
+TEST_P(ParameterizedMetadataTest, BroadcastMerge) {
+  const char* const hlo_string = R"(
+HloModule module
+ENTRY %broadcast {
+  %param0 = f32[3,2048] parameter(0),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}}
+  %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1}
+  ROOT %copy = f32[3,2048,3] copy(%broadcast),
+    sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  if (GetParam().clear_metadata) {
+    ClearMetadata(module.get());
+  }
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
+          .Run(module.get()));
+  EXPECT_TRUE(changed);
+  auto* instruction = FindInstruction(module.get(), "broadcast");
+  ASSERT_NE(instruction, nullptr);
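+  // The partial sharding propagated from %param0 and the partial sharding
+  // propagated from the user %copy should merge into a fully tiled sharding.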
+  EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
+  if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
+    EXPECT_THAT(instruction->sharding(),
+                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
+  } else {
+    EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
+  }
+}
+
 TEST_P(ParameterizedMetadataTest, BroadcastUser) {
   const char* const hlo_string = R"(
 HloModule module