[XLA] Make sure to return replicated sharding when all tiles are partially replicated
Propagating a fully replicated sharding as a partially replicated sharding whose replication dimension covers all the tiles can cause spooky action at a distance in other parts of the code. This change also removes the test that used the invalid sharding format. PiperOrigin-RevId: 354430571 Change-Id: If8492f2e946d5f1f23a5f8bb7015f6f4cf530ede
This commit is contained in:
parent
dc7d2f1db4
commit
c1202353c7
@ -67,7 +67,9 @@ HloSharding HloSharding::PartialTile(
|
||||
HloSharding HloSharding::PartialTile(
|
||||
const Array<int64>& tile_assignment_last_dim_replicate,
|
||||
absl::Span<const OpMetadata> metadata) {
|
||||
if (tile_assignment_last_dim_replicate.num_dimensions() == 1) {
|
||||
if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
|
||||
tile_assignment_last_dim_replicate.dimensions().back() ==
|
||||
tile_assignment_last_dim_replicate.num_elements()) {
|
||||
return Replicate(metadata);
|
||||
}
|
||||
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
|
||||
|
@ -4732,5 +4732,43 @@ ENTRY %module {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs sharding propagation (is_spmd=true) over a module in which only the
// %gather instruction carries a sharding annotation
// (devices=[1,1,2,1,1] with metadata op_name="a"), then checks that the
// gather's %index operand ends up with a plain {replicated} sharding.
// The test is parameterized on whether metadata is cleared before the pass
// and whether the pass is asked to propagate metadata.
TEST_P(ParameterizedMetadataTest, CorrectlyReplicateGatherIndex) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY %module {
|
||||
%parameter.0 = bf16[1,2,2,2,8]{4,3,2,1,0} parameter(0)
|
||||
%parameter.1 = s32[1,2,2]{2,1,0} parameter(1)
|
||||
%index = s32[1,2,2]{2,1,0} copy(%parameter.1)
|
||||
%gather = bf16[1,2,2,2,8]{4,3,2,1,0} gather(
|
||||
bf16[1,2,2,2,8]{4,3,2,1,0} %parameter.0, s32[1,2,2]{2,1,0} %index),
|
||||
offset_dims={2,3,4}, collapsed_slice_dims={0,1}, start_index_map={0,1},
|
||||
index_vector_dim=2, slice_sizes={1,1,2,2,8},
|
||||
sharding={devices=[1,1,2,1,1]0,1 metadata={op_name="a"}}
|
||||
ROOT %copy = bf16[1,2,2,2,8]{4,3,2,1,0} copy(%gather)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
// Optionally strip the metadata from the parsed module before running the
// pass, depending on the test parameter.
if (GetParam().clear_metadata) {
|
||||
ClearMetadata(module.get());
|
||||
}
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed,
|
||||
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
|
||||
.Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
|
||||
const HloInstruction* index = FindInstruction(module.get(), "index");
|
||||
|
||||
ASSERT_NE(index, nullptr);
|
||||
|
||||
// The index operand must be reported as plain {replicated}.
EXPECT_THAT(index, op::Sharding("{replicated}"));
|
||||
// The "a" metadata is expected on the index sharding only when the pass
// propagates metadata and the metadata was not cleared beforehand;
// otherwise the sharding is expected to carry no metadata.
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
||||
EXPECT_THAT(index->sharding(), ShardingMetadata({CreateMetadata("a")}));
|
||||
} else {
|
||||
EXPECT_THAT(index->sharding(), ShardingMetadata({}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -5730,34 +5730,6 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, tiled);
|
||||
}
|
||||
|
||||
// Partitions, for 8 devices, a module that reshards an f32[8,8] copy from
// devices=[1,1,8] last_tile_dim_replicate to devices=[2,2,2]
// last_tile_dim_replicate, and expects the root to be an f32[4,4] copy of a
// dynamic-slice taken from an f32[8,8] copy of a dynamic-slice of parameter 0.
// NOTE(review): the source sharding [1,1,8] has tile-dimension sizes 1, 1
// and 8 for 8 devices; with last_tile_dim_replicate this looks like the
// "invalid format" the commit message refers to, and the hunk header
// (34 lines -> 6 lines) suggests this is the test being removed — confirm
// against the actual diff markers.
TEST_F(SpmdPartitioningTest,
|
||||
PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[8,8] parameter(0)
|
||||
%copy = f32[8,8] copy(%param0),
|
||||
sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
ROOT %copy0 = f32[8,8] copy(%copy),
|
||||
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/8));
|
||||
VLOG(1) << module->ToString();
|
||||
// Matcher for the inner value: an f32[8,8] copy of a dynamic-slice of
// parameter 0 with constant offsets.
auto partially_replicated =
|
||||
AllOf(op::Shape("f32[8,8]"),
|
||||
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
|
||||
op::Constant())));
|
||||
// Matcher for the root: an f32[4,4] copy of a dynamic-slice of the value
// above, with offsets produced by subtract instructions.
auto tiled =
|
||||
AllOf(op::Shape("f32[4,4]"),
|
||||
op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
|
||||
op::Subtract())));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, tiled);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest,
|
||||
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
|
||||
absl::string_view hlo_string = R"(
|
||||
|
Loading…
x
Reference in New Issue
Block a user