diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index c7ece50d202..49caa2ee347 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -739,11 +739,6 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (aggressiveness < 3) { return false; } - // Do not override existing tile sharding. This is likely from users. - if (IsSpatiallyPartitioned(instruction) && - !instruction->sharding().IsTileMaximal()) { - return false; - } const HloInstruction* op = instruction->operand(0); if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { return false; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index eed6de9b5b7..f91ecec8b9f 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -392,6 +392,37 @@ ENTRY %broadcast { } } +TEST_P(ParameterizedMetadataTest, BroadcastMerge) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048]parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}} + %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1} + ROOT %copy = f32[3,2048,3] copy(%broadcast), + sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + EXPECT_TRUE(changed); + auto* instruction = FindInstruction(module.get(), "broadcast"); + ASSERT_NE(instruction, nullptr); + EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}")); + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(instruction->sharding(), + ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + } else { + EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); + } +} + TEST_P(ParameterizedMetadataTest, BroadcastUser) { const char* const hlo_string = R"( HloModule module