[XLA] Make sure to return replicated sharding when all tiles are partially replicated

Propagating a fully replicated sharding as a partially replicated one whose tiles
all lie along the replication dimension can cause spooky action at a distance in
other parts of the code.

Removing test that used the invalid format of the sharding.

PiperOrigin-RevId: 354430571
Change-Id: If8492f2e946d5f1f23a5f8bb7015f6f4cf530ede
This commit is contained in:
Marcello Maggioni 2021-01-28 17:17:46 -08:00 committed by TensorFlower Gardener
parent dc7d2f1db4
commit c1202353c7
3 changed files with 41 additions and 29 deletions

View File

@ -67,7 +67,9 @@ HloSharding HloSharding::PartialTile(
HloSharding HloSharding::PartialTile(
const Array<int64>& tile_assignment_last_dim_replicate,
absl::Span<const OpMetadata> metadata) {
if (tile_assignment_last_dim_replicate.num_dimensions() == 1) {
if (tile_assignment_last_dim_replicate.num_dimensions() == 1 ||
tile_assignment_last_dim_replicate.dimensions().back() ==
tile_assignment_last_dim_replicate.num_elements()) {
return Replicate(metadata);
}
if (tile_assignment_last_dim_replicate.dimensions().back() == 1) {

View File

@ -4732,5 +4732,43 @@ ENTRY %module {
}
}
// Runs SPMD sharding propagation on a module whose gather carries a tiled
// sharding (devices=[1,1,2,1,1]) and checks that the gather's index operand
// ends up with a replicated sharding. The sharding metadata is expected to be
// present only when metadata propagation is enabled and the metadata was not
// cleared before the pass ran.
TEST_P(ParameterizedMetadataTest, CorrectlyReplicateGatherIndex) {
  const absl::string_view hlo_text = R"(
HloModule module
ENTRY %module {
%parameter.0 = bf16[1,2,2,2,8]{4,3,2,1,0} parameter(0)
%parameter.1 = s32[1,2,2]{2,1,0} parameter(1)
%index = s32[1,2,2]{2,1,0} copy(%parameter.1)
%gather = bf16[1,2,2,2,8]{4,3,2,1,0} gather(
bf16[1,2,2,2,8]{4,3,2,1,0} %parameter.0, s32[1,2,2]{2,1,0} %index),
offset_dims={2,3,4}, collapsed_slice_dims={0,1}, start_index_map={0,1},
index_vector_dim=2, slice_sizes={1,1,2,2,8},
sharding={devices=[1,1,2,1,1]0,1 metadata={op_name="a"}}
ROOT %copy = bf16[1,2,2,2,8]{4,3,2,1,0} copy(%gather)
})";
  const auto& param = GetParam();

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Optionally strip the op metadata before the pass runs.
  if (param.clear_metadata) {
    ClearMetadata(hlo_module.get());
  }

  TF_ASSERT_OK_AND_ASSIGN(
      bool sharding_changed,
      ShardingPropagation(/*is_spmd=*/true, param.propagate_metadata)
          .Run(hlo_module.get()));
  EXPECT_TRUE(sharding_changed);

  const HloInstruction* gather_index =
      FindInstruction(hlo_module.get(), "index");
  ASSERT_NE(gather_index, nullptr);
  EXPECT_THAT(gather_index, op::Sharding("{replicated}"));

  // Metadata "a" is expected only if it was propagated and never cleared;
  // in every other configuration the sharding must carry no metadata.
  if (!param.propagate_metadata || param.clear_metadata) {
    EXPECT_THAT(gather_index->sharding(), ShardingMetadata({}));
  } else {
    EXPECT_THAT(gather_index->sharding(),
                ShardingMetadata({CreateMetadata("a")}));
  }
}
} // namespace
} // namespace xla

View File

@ -5730,34 +5730,6 @@ ENTRY entry {
EXPECT_THAT(root, tiled);
}
// Partitions, across 8 devices, a module in which %copy is sharded as
// devices=[1,1,8] last_tile_dim_replicate and the root copy as
// devices=[2,2,2] last_tile_dim_replicate. The root is expected to be an
// f32[4,4] copy of a dynamic-slice (offsets produced by subtracts) taken from
// an f32[8,8] copy of a dynamic-slice (constant offsets) of parameter 0.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
  const absl::string_view hlo_text = R"(
HloModule module
ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(%param0),
sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  // First stage: the source copy, still at the full f32[8,8] shape.
  auto source_slice =
      op::DynamicSlice(op::Parameter(0), op::Constant(), op::Constant());
  auto source_copy = AllOf(op::Shape("f32[8,8]"), op::Copy(source_slice));

  // Second stage: the resharded result, sliced down to f32[4,4].
  auto resharded_slice =
      op::DynamicSlice(source_copy, op::Subtract(), op::Subtract());
  auto resharded_copy =
      AllOf(op::Shape("f32[4,4]"), op::Copy(resharded_slice));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              resharded_copy);
}
TEST_F(SpmdPartitioningTest,
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
absl::string_view hlo_string = R"(