[XLA:SPMD] Fix bug in partial replicate resharding.
PiperOrigin-RevId: 327906446 Change-Id: Ifb39e840f9a8b66a52160f971007f6672f59f679
This commit is contained in:
parent
83c96f6b8c
commit
41f429e7b5
@ -879,7 +879,6 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
|
||||
// Add another dimension in tiling_dim_factors if target is partial replicate.
|
||||
if (target.ReplicateOnLastTileDim()) {
|
||||
tiling_dim_factors.emplace_back(
|
||||
sharding().tile_assignment().dimensions().back() /
|
||||
target.tile_assignment().dimensions().back());
|
||||
}
|
||||
|
||||
|
@ -4938,6 +4938,33 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, tiled);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest,
|
||||
PartialReplicateToPartialReplicateReshard_DynamicSlice2) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[8,8] parameter(0)
|
||||
%copy = f32[8,8] copy(%param0),
|
||||
sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
ROOT %copy0 = f32[8,8] copy(%copy),
|
||||
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/8));
|
||||
VLOG(1) << module->ToString();
|
||||
auto partially_replicated =
|
||||
AllOf(op::Shape("f32[8,8]"),
|
||||
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
|
||||
op::Constant())));
|
||||
auto tiled = AllOf(op::Shape("f32[4,4]"),
|
||||
op::Copy(op::DynamicSlice(partially_replicated,
|
||||
op::Reshape(), op::Reshape())));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, tiled);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest,
|
||||
PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
|
||||
const char* const hlo_string = R"(
|
||||
|
@ -342,14 +342,14 @@ absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
if (target_is_partial_replicate) {
|
||||
reshape_dimensions.back() = num_replication / num_target_replication;
|
||||
} else {
|
||||
reshape_dimensions.pop_back();
|
||||
}
|
||||
|
||||
reshape_dimensions.pop_back();
|
||||
reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
|
||||
expand_tile_sizes.end());
|
||||
|
||||
if (target_is_partial_replicate) {
|
||||
reshape_dimensions.push_back(num_target_replication);
|
||||
}
|
||||
|
||||
auto reshape_tile_assignment = partial_sharding.tile_assignment();
|
||||
reshape_tile_assignment.Reshape(reshape_dimensions);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user