[XLA:SPMD] Fix bug in partial replicate resharding.

PiperOrigin-RevId: 327906446
Change-Id: Ifb39e840f9a8b66a52160f971007f6672f59f679
This commit is contained in:
A. Unique TensorFlower 2020-08-21 18:15:57 -07:00 committed by TensorFlower Gardener
parent 83c96f6b8c
commit 41f429e7b5
3 changed files with 33 additions and 7 deletions

View File

@ -879,7 +879,6 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
// Add another dimension in tiling_dim_factors if target is partial replicate.
if (target.ReplicateOnLastTileDim()) {
tiling_dim_factors.emplace_back(
sharding().tile_assignment().dimensions().back() /
target.tile_assignment().dimensions().back());
}

View File

@ -4938,6 +4938,33 @@ ENTRY entry {
EXPECT_THAT(root, tiled);
}
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
  // Reshard from a fully partial-replicate sharding (1x1 tiles, all 8 devices
  // replicating) to a 2x2 tiled sharding with 2-way replication on the last
  // tile dimension. The partitioner should lower this to dynamic-slices.
  const char* const kHlo = R"(
HloModule module
ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(%param0),
sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHlo, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  auto* const root = module->entry_computation()->root_instruction();

  // The replicated operand keeps the full [8,8] shape; the constant offsets
  // come from the trivial (1x1) source tiling.
  auto replicated_operand =
      AllOf(op::Shape("f32[8,8]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                      op::Constant())));
  // The result is the per-device [4,4] tile, sliced out with offsets computed
  // from the partition id (reshapes).
  auto expected_tile = AllOf(op::Shape("f32[4,4]"),
                             op::Copy(op::DynamicSlice(replicated_operand,
                                                       op::Reshape(),
                                                       op::Reshape())));
  EXPECT_THAT(root, expected_tile);
}
TEST_F(SpmdPartitioningTest,
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
const char* const hlo_string = R"(

View File

@ -342,14 +342,14 @@ absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
return absl::nullopt;
}
if (target_is_partial_replicate) {
reshape_dimensions.back() = num_replication / num_target_replication;
} else {
reshape_dimensions.pop_back();
}
reshape_dimensions.pop_back();
reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
expand_tile_sizes.end());
if (target_is_partial_replicate) {
reshape_dimensions.push_back(num_target_replication);
}
auto reshape_tile_assignment = partial_sharding.tile_assignment();
reshape_tile_assignment.Reshape(reshape_dimensions);