Allow more sharding propagation from broadcast operands.

MaybeImproveInstructionSharding() already checks that the proposed sharding is compatible with any existing tiled sharding, so the extra guard against overriding tiled shardings is unnecessary.

PiperOrigin-RevId: 356375771
Change-Id: Ic5e65e00cea0a25cec2cf93027632c4c7783a20e
This commit is contained in:
parent
324ab95fb0
commit
18f2dd4262
tensorflow/compiler/xla/service
@@ -739,11 +739,6 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       if (aggressiveness < 3) {
         return false;
       }
-      // Do not override existing tile sharding. This is likely from users.
-      if (IsSpatiallyPartitioned(instruction) &&
-          !instruction->sharding().IsTileMaximal()) {
-        return false;
-      }
       const HloInstruction* op = instruction->operand(0);
       if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
         return false;
|
@@ -392,6 +392,37 @@ ENTRY %broadcast {
 }
 }
 
+TEST_P(ParameterizedMetadataTest, BroadcastMerge) {
+  const char* const hlo_string = R"(
+HloModule module
+ENTRY %broadcast {
+  %param0 = f32[3,2048]parameter(0),
+    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate metadata={op_name="a"}}
+  %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1}
+  ROOT %copy = f32[3,2048,3] copy(%broadcast),
+    sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate metadata={op_name="b"}}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  if (GetParam().clear_metadata) {
+    ClearMetadata(module.get());
+  }
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
+          .Run(module.get()));
+  EXPECT_TRUE(changed);
+  auto* instruction = FindInstruction(module.get(), "broadcast");
+  ASSERT_NE(instruction, nullptr);
+  EXPECT_THAT(instruction, op::Sharding("{devices=[1,2,2]0,1,2,3}"));
+  if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
+    EXPECT_THAT(instruction->sharding(),
+                ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")}));
+  } else {
+    EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
+  }
+}
+
 TEST_P(ParameterizedMetadataTest, BroadcastUser) {
   const char* const hlo_string = R"(
 HloModule module
Loading…
Reference in New Issue
Block a user