Allow more sharding propagation from broadcast operands

MaybeImproveInstructionSharding() already checks that it needs to be compatible with existing tiled sharding (if any).

PiperOrigin-RevId: 356375771
Change-Id: Ic5e65e00cea0a25cec2cf93027632c4c7783a20e
This commit is contained in:
Yuanzhong Xu 2021-02-08 16:08:37 -08:00 committed by TensorFlower Gardener
parent 324ab95fb0
commit 18f2dd4262
2 changed files with 31 additions and 5 deletions

View File

@ -739,11 +739,6 @@ bool InferShardingFromOperands(HloInstruction* instruction,
if (aggressiveness < 3) {
return false;
}
// Do not override existing tile sharding. This is likely from users.
if (IsSpatiallyPartitioned(instruction) &&
!instruction->sharding().IsTileMaximal()) {
return false;
}
const HloInstruction* op = instruction->operand(0);
if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
return false;

View File

@ -392,6 +392,37 @@ ENTRY %broadcast {
}
}
TEST_P(ParameterizedMetadataTest, BroadcastMerge) {
const char* const hlo_string = R"(
HloModule module
ENTRY %broadcast {
%param0 = f32[3,2048]parameter(0),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}}
%broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1}
ROOT %copy = f32[3,2048,3] copy(%broadcast),
sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
.Run(module.get()));
EXPECT_TRUE(changed);
auto* instruction = FindInstruction(module.get(), "broadcast");
ASSERT_NE(instruction, nullptr);
EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(instruction->sharding(),
ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
} else {
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
}
}
TEST_P(ParameterizedMetadataTest, BroadcastUser) {
const char* const hlo_string = R"(
HloModule module