Allow more sharding propagation from broadcast operands
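MaybeImproveInstructionSharding() already verifies that a newly inferred sharding is compatible with the existing tiled sharding (if any), so the broadcast case no longer needs its own early return for instructions that already have a tiled sharding.
PiperOrigin-RevId: 356375771
Change-Id: Ic5e65e00cea0a25cec2cf93027632c4c7783a20e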
This commit is contained in:
parent
324ab95fb0
commit
18f2dd4262
@ -739,11 +739,6 @@ bool InferShardingFromOperands(HloInstruction* instruction,
|
||||
if (aggressiveness < 3) {
|
||||
return false;
|
||||
}
|
||||
// Do not override existing tile sharding. This is likely from users.
|
||||
if (IsSpatiallyPartitioned(instruction) &&
|
||||
!instruction->sharding().IsTileMaximal()) {
|
||||
return false;
|
||||
}
|
||||
const HloInstruction* op = instruction->operand(0);
|
||||
if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
|
||||
return false;
|
||||
|
@ -392,6 +392,37 @@ ENTRY %broadcast {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs sharding propagation on a module where a broadcast sits between an
// operand (%param0, metadata "a") and a user (%copy, metadata "b") that both
// carry shardings marked last_tile_dim_replicate, and checks that the
// broadcast ends up with the sharding {devices=[1,2,2]0,1,2,3}.
// When metadata is propagated and not cleared, the broadcast's sharding is
// expected to carry both "a" and "b"; otherwise it is expected to carry none.
TEST_P(ParameterizedMetadataTest, BroadcastMerge) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
ENTRY %broadcast {
|
||||
%param0 = f32[3,2048]parameter(0),
|
||||
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}}
|
||||
%broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1}
|
||||
ROOT %copy = f32[3,2048,3] copy(%broadcast),
|
||||
sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}}
|
||||
})";
|
||||
// Parse the HLO text above; the ASSERT macro aborts the test on failure.
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
// Parameterized variant: presumably strips the op_name metadata from the
// parsed module before the pass runs (ClearMetadata is defined elsewhere).
if (GetParam().clear_metadata) {
|
||||
ClearMetadata(module.get());
|
||||
}
|
||||
// Run the pass in SPMD mode; metadata propagation is controlled by the test
// parameter.
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed,
|
||||
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
|
||||
.Run(module.get()));
|
||||
// The broadcast starts without a sharding, so the pass must report a change.
EXPECT_TRUE(changed);
|
||||
auto* instruction = FindInstruction(module.get(), "broadcast");
|
||||
ASSERT_NE(instruction, nullptr);
|
||||
// Expected result tiles dimension 1 (as on %param0) and dimension 2 (as on
// %copy) two ways each, i.e. the operand and user shardings are merged.
EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
|
||||
// Metadata from both sources is expected only when it was kept in the input
// and the pass was asked to propagate it.
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
||||
EXPECT_THAT(instruction->sharding(),
|
||||
ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
|
||||
} else {
|
||||
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedMetadataTest, BroadcastUser) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
Loading…
Reference in New Issue
Block a user