[XLA:SPMD] Apply all-to-all when resharding from/to partial replicate in some special cases.

PiperOrigin-RevId: 348091566
Change-Id: I178f5dc9fe83cb9bcf45a87998cb755691935b4f
parent 8a894275a4
commit 629a241b55
@@ -417,6 +417,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
@@ -424,6 +428,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
    try_reshard = ReshardPartialReplicateWithAllToAll(target);
    if (try_reshard.has_value()) {
      return try_reshard.value();
    }
  }

  // If not replicated yet, first replicate and then reshard to use one of the
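For orientation, the two hunks above add the new helper as an extra fallback inside ReshardNoCache: each candidate reshard strategy returns a value only when it applies, and the caller takes the first one that succeeds. A minimal standalone sketch of that pattern, using std::optional and std::string in place of absl::optional<PartitionedHlo> (illustrative only, not the XLA types):

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Sketch of the fallback chain: try each reshard strategy in order and
// return the first result that has a value.
std::optional<std::string> FirstApplicable(
    const std::vector<std::function<std::optional<std::string>()>>& strategies) {
  for (const auto& try_reshard : strategies) {
    if (auto result = try_reshard(); result.has_value()) {
      return result;
    }
  }
  return std::nullopt;
}

int main() {
  auto result = FirstApplicable({
      []() -> std::optional<std::string> { return std::nullopt; },    // e.g. dynamic-slice path does not apply
      []() -> std::optional<std::string> { return "all-to-all"; },    // e.g. all-to-all path applies
  });
  std::cout << result.value_or("none") << "\n";  // prints: all-to-all
}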
@@ -1216,6 +1224,92 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
      .ReshardWithAllToAll(target, remaining_source_target_dims);
}

absl::optional<PartitionedHlo>
PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
  bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
  const auto& partial_replicate_sharding =
      source_is_partial_replicate ? sharding() : target;
  // If neither the source nor the target is partial replicate, return null.
  if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
    return absl::nullopt;
  }
  const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
  // If both source and target are partial replicate, this should already be
  // handled by Reshard with AllToAll.
  if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
    return absl::nullopt;
  }

  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
  // the last tile dim will be replicated first before the all-to-all,
  // or resharding from
  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
  // the last tile dim will be sharded after the all-to-all.
  const int num_replicas =
      partial_replicate_sharding.tile_assignment().dimensions().back();
  if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
       partial_replicate_sharding.tile_assignment().num_dimensions()) ||
      (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
    return absl::nullopt;
  }
  int to_replicate_dim = -1;
  for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
       --i) {
    if (tile_sharding.tile_assignment().dim(i) > 1 &&
        (to_replicate_dim == -1)) {
      if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
        return absl::nullopt;
      }
      to_replicate_dim = i;
    }

    if (tile_sharding.tile_assignment().dim(i) !=
        partial_replicate_sharding.tile_assignment().dim(i + 1)) {
      return absl::nullopt;
    }
  }

  if (to_replicate_dim == -1) {
    return absl::nullopt;
  }

  // Check if the core assignments for the source and the target are the same.
  auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
  reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
  if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
    return absl::nullopt;
  }

  auto tmp_tile_assignment = tile_sharding.tile_assignment();
  auto tmp_tile_assignment_dimensions =
      tile_sharding.tile_assignment().dimensions();
  tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
  tmp_tile_assignment_dimensions.push_back(num_replicas);
  tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
  auto tmp_partial_replicate_sharding =
      HloSharding::PartialTile(tmp_tile_assignment);

  if (source_is_partial_replicate) {
    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
            sharding(), tmp_partial_replicate_sharding)) {
      auto partitioned_hlo =
          ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
      return partitioned_hlo.Reshard(target);
    }
  } else {
    auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);

    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
            partitioned_hlo.sharding(), target)) {
      return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
    }
  }

  return absl::nullopt;
}

PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
    const HloSharding& target) const {
  CHECK(CanReshardWithCollectivePermute(sharding(), target))
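The core of the new helper is the tile-assignment rewrite near the end of the hunk above: the tile dimension whose size equals the replication factor is collapsed to 1, and the replication factor is appended as a trailing partial-replicate dimension, producing the intermediate sharding that the all-to-all moves data to or from. A minimal standalone sketch of that dimension rewrite with plain std::vector (illustrative only; the real code operates on HloSharding tile assignments):

#include <cstdint>
#include <iostream>
#include <vector>

// Collapse the dimension being replicated to 1 and append the replication
// factor as the trailing "last_tile_dim_replicate" dimension.
std::vector<int64_t> ToTempPartialReplicateDims(std::vector<int64_t> tile_dims,
                                                int64_t num_replicas,
                                                int to_replicate_dim) {
  tile_dims[to_replicate_dim] = 1;    // e.g. [2,3] -> [2,1]
  tile_dims.push_back(num_replicas);  // e.g. [2,1] -> [2,1,3]
  return tile_dims;
}

int main() {
  // Mirrors the case in the comment above: tile sharding {devices=[2,3]...},
  // partial replicate sharding {devices=[1,2,3]... last_tile_dim_replicate}.
  auto dims = ToTempPartialReplicateDims({2, 3}, /*num_replicas=*/3,
                                         /*to_replicate_dim=*/1);
  for (int64_t d : dims) std::cout << d << ' ';  // prints: 2 1 3
  std::cout << '\n';
}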
@@ -338,6 +338,10 @@ class PartitionedHlo {
  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;
@@ -5542,6 +5542,55 @@ ENTRY entry {
  EXPECT_THAT(root, partially_replicated);
}

TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[2,3]0,1,2,3,4,5}
  ROOT %copy0 = f32[8,8] copy(%param0),
    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/6));
  VLOG(1) << module->ToString();
  auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
  auto partially_replicated = AllOf(
      op::Shape("f32[8,4]"),
      op::Copy(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
              op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, partially_replicated);
}

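The expected shapes in this test follow from ceil-division of the f32[8,8] operand by the tile counts: 8 is not divisible by 3, so the [2,3] tiling is uneven, which is why the matcher expects Slice/AllReduce/DynamicUpdateSlice (and the next test expects Pad/DynamicSlice) around the all-to-all. A small sketch of that shard-shape arithmetic (illustrative only, not XLA's shape helpers):

#include <cstdint>
#include <iostream>

// Per-device extent of a dimension split into num_tiles parts, rounded up;
// uneven splits are padded/sliced by the partitioner.
int64_t ShardExtent(int64_t dim_size, int64_t num_tiles) {
  return (dim_size + num_tiles - 1) / num_tiles;
}

int main() {
  // f32[8,8] with sharding={devices=[2,3]...}: per-device f32[4,3].
  std::cout << ShardExtent(8, 2) << "," << ShardExtent(8, 3) << "\n";  // 4,3
  // f32[8,8] with sharding={devices=[1,2,3]... last_tile_dim_replicate}:
  // only the first two dims tile the data, so per-device f32[8,4].
  std::cout << ShardExtent(8, 1) << "," << ShardExtent(8, 2) << "\n";  // 8,4
}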
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%param0),
    sharding={devices=[2,3]0,1,2,3,4,5}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/6));
  VLOG(1) << module->ToString();
  auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
  auto tiled = AllOf(
      op::Shape("f32[4,3]"),
      op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
                                            op::Reshape(partial_replicated)))),
                                        _),
                                _, _)));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, tiled);
}

TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
  const char* const hlo_string = R"(
HloModule module