[XLA] Make sure to return replicated sharding when all tiles are partially replicated
Propagating a fully replicated sharding as a partially replicated sharding whose replication dimension spans all the tiles can cause spooky action at a distance in other parts of the code. This change also removes a test that used the invalid sharding format.

PiperOrigin-RevId: 354430571
Change-Id: If8492f2e946d5f1f23a5f8bb7015f6f4cf530ede
This commit is contained in:
parent
dc7d2f1db4
commit
c1202353c7
@ -67,7 +67,9 @@ HloSharding HloSharding::PartialTile(
|
|||||||
HloSharding HloSharding::PartialTile(
|
HloSharding HloSharding::PartialTile(
|
||||||
const Array<int64>& tile_assignment_last_dim_replicate,
|
const Array<int64>& tile_assignment_last_dim_replicate,
|
||||||
absl::Span<const OpMetadata> metadata) {
|
absl::Span<const OpMetadata> metadata) {
|
||||||
if (tile_assignment_last_dim_replicate.num_dimensions() == 1) {
|
if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
|
||||||
|
tile_assignment_last_dim_replicate.dimensions().back() ==
|
||||||
|
tile_assignment_last_dim_replicate.num_elements()) {
|
||||||
return Replicate(metadata);
|
return Replicate(metadata);
|
||||||
}
|
}
|
||||||
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
|
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {
|
||||||
|
@ -4732,5 +4732,43 @@ ENTRY %module {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(ParameterizedMetadataTest, CorrectlyReplicateGatherIndex) {
|
||||||
|
absl::string_view hlo_string = R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY %module {
|
||||||
|
%parameter.0 = bf16[1,2,2,2,8]{4,3,2,1,0} parameter(0)
|
||||||
|
%parameter.1 = s32[1,2,2]{2,1,0} parameter(1)
|
||||||
|
%index = s32[1,2,2]{2,1,0} copy(%parameter.1)
|
||||||
|
%gather = bf16[1,2,2,2,8]{4,3,2,1,0} gather(
|
||||||
|
bf16[1,2,2,2,8]{4,3,2,1,0} %parameter.0, s32[1,2,2]{2,1,0} %index),
|
||||||
|
offset_dims={2,3,4}, collapsed_slice_dims={0,1}, start_index_map={0,1},
|
||||||
|
index_vector_dim=2, slice_sizes={1,1,2,2,8},
|
||||||
|
sharding={devices=[1,1,2,1,1]0,1 metadata={op_name="a"}}
|
||||||
|
ROOT %copy = bf16[1,2,2,2,8]{4,3,2,1,0} copy(%gather)
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
if (GetParam().clear_metadata) {
|
||||||
|
ClearMetadata(module.get());
|
||||||
|
}
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
bool changed,
|
||||||
|
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
|
||||||
|
.Run(module.get()));
|
||||||
|
EXPECT_TRUE(changed);
|
||||||
|
|
||||||
|
const HloInstruction* index = FindInstruction(module.get(), "index");
|
||||||
|
|
||||||
|
ASSERT_NE(index, nullptr);
|
||||||
|
|
||||||
|
EXPECT_THAT(index, op::Sharding("{replicated}"));
|
||||||
|
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
|
||||||
|
EXPECT_THAT(index->sharding(), ShardingMetadata({CreateMetadata("a")}));
|
||||||
|
} else {
|
||||||
|
EXPECT_THAT(index->sharding(), ShardingMetadata({}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -5730,34 +5730,6 @@ ENTRY entry {
|
|||||||
EXPECT_THAT(root, tiled);
|
EXPECT_THAT(root, tiled);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest,
|
|
||||||
PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule module
|
|
||||||
|
|
||||||
ENTRY entry {
|
|
||||||
%param0 = f32[8,8] parameter(0)
|
|
||||||
%copy = f32[8,8] copy(%param0),
|
|
||||||
sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
|
||||||
ROOT %copy0 = f32[8,8] copy(%copy),
|
|
||||||
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
|
||||||
})";
|
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
|
||||||
PartitionComputation(hlo_string, /*num_devices=*/8));
|
|
||||||
VLOG(1) << module->ToString();
|
|
||||||
auto partially_replicated =
|
|
||||||
AllOf(op::Shape("f32[8,8]"),
|
|
||||||
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
|
|
||||||
op::Constant())));
|
|
||||||
auto tiled =
|
|
||||||
AllOf(op::Shape("f32[4,4]"),
|
|
||||||
op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
|
|
||||||
op::Subtract())));
|
|
||||||
auto root = module->entry_computation()->root_instruction();
|
|
||||||
EXPECT_THAT(root, tiled);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(SpmdPartitioningTest,
|
TEST_F(SpmdPartitioningTest,
|
||||||
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
|
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
|
||||||
absl::string_view hlo_string = R"(
|
absl::string_view hlo_string = R"(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user