[XLA:SPMD] Apply all-to-all when resharding from/to partial replicate in some
special cases.

PiperOrigin-RevId: 348091566
Change-Id: I178f5dc9fe83cb9bcf45a87998cb755691935b4f
commit 629a241b55
parent 8a894275a4
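
For reference, the two reshard patterns this change supports (taken from the comments and tests below) map an f32[8,8] between sharding={devices=[2,3]0,1,2,3,4,5} and sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate} on six devices. The sketch below is a standalone illustration, not part of the commit: plain C++ with a hypothetical ShardShape helper (not an XLA API) that computes the per-device shapes the tests assert, f32[4,3] for the tiled side and f32[8,4] for the partially replicated side.

  // Sketch: per-device shard shapes for the two sharding layouts involved.
  // Standalone C++, independent of XLA; shard extents use ceiling division,
  // matching the padded shapes seen in the tests below.
  #include <cstdio>
  #include <vector>

  // Hypothetical helper (not an XLA API): ceiling-divide each data dimension
  // by its tile count. A trailing replication entry in `tiles` does not split
  // data, so it is skipped when last_tile_dim_replicate is true.
  std::vector<int> ShardShape(const std::vector<int>& shape,
                              const std::vector<int>& tiles,
                              bool last_tile_dim_replicate) {
    std::vector<int> shard(shape);
    size_t data_dims = tiles.size() - (last_tile_dim_replicate ? 1 : 0);
    for (size_t i = 0; i < data_dims; ++i) {
      shard[i] = (shape[i] + tiles[i] - 1) / tiles[i];  // ceiling division
    }
    return shard;
  }

  int main() {
    // devices=[2,3]: a 2x3 tiling of f32[8,8]; each device holds f32[4,3]
    // (8 is not divisible by 3, so the minor dimension is padded to 9).
    auto tiled = ShardShape({8, 8}, {2, 3}, /*last_tile_dim_replicate=*/false);
    // devices=[1,2,3] last_tile_dim_replicate: 2-way tiling on dimension 1,
    // 3-way replication; each device holds f32[8,4].
    auto partial = ShardShape({8, 8}, {1, 2, 3}, /*last_tile_dim_replicate=*/true);
    std::printf("tiled:   f32[%d,%d]\n", tiled[0], tiled[1]);      // f32[4,3]
    std::printf("partial: f32[%d,%d]\n", partial[0], partial[1]);  // f32[8,4]
  }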
@@ -417,6 +417,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
@@ -424,6 +428,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
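
Both ReshardNoCache call sites follow the same fallback pattern: the existing partial-replicate rewrite is tried first, and the new ReshardPartialReplicateWithAllToAll helper (added below) is consulted only when that attempt returns absl::nullopt.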
@@ -1216,6 +1224,92 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
       .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
+absl::optional<PartitionedHlo>
+PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
+  bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
+  const auto& partial_replicate_sharding =
+      source_is_partial_replicate ? sharding() : target;
+  // If neither the source nor the target is partial replicate, return null.
+  if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
+    return absl::nullopt;
+  }
+  const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
+  // If both source and target are partial replicate, should be supported in
+  // Reshard with AllToAll already.
+  if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
+    return absl::nullopt;
+  }
+
+  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
+  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
+  // the last tile dim will be replicate first before all-to-all.
+  // Or resharding from
+  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
+  // the last tile dim will be sharded after all-to-all.
+  const int num_replicas =
+      partial_replicate_sharding.tile_assignment().dimensions().back();
+  if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
+       partial_replicate_sharding.tile_assignment().num_dimensions()) ||
+      (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
+    return absl::nullopt;
+  }
+  int to_replicate_dim = -1;
+  for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
+       --i) {
+    if (tile_sharding.tile_assignment().dim(i) > 1 &&
+        (to_replicate_dim == -1)) {
+      if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
+        return absl::nullopt;
+      }
+      to_replicate_dim = i;
+    }
+
+    if (tile_sharding.tile_assignment().dim(i) !=
+        partial_replicate_sharding.tile_assignment().dim(i + 1)) {
+      return absl::nullopt;
+    }
+  }
+
+  if (to_replicate_dim == -1) {
+    return absl::nullopt;
+  }
+
+  // Check if core assignments for source and the target are the same.
+  auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
+  reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
+  if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
+    return absl::nullopt;
+  }
+
+  auto tmp_tile_assignment = tile_sharding.tile_assignment();
+  auto tmp_tile_assignment_dimensions =
+      tile_sharding.tile_assignment().dimensions();
+  tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
+  tmp_tile_assignment_dimensions.push_back(num_replicas);
+  tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
+  auto tmp_partial_replicate_sharding =
+      HloSharding::PartialTile(tmp_tile_assignment);
+
+  if (source_is_partial_replicate) {
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            sharding(), tmp_partial_replicate_sharding)) {
+      auto partitioned_hlo =
+          ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
+      return partitioned_hlo.Reshard(target);
+    }
+  } else {
+    auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
+
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            partitioned_hlo.sharding(), target)) {
+      return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
+    }
+  }
+
+  return absl::nullopt;
+}
+
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
     const HloSharding& target) const {
   CHECK(CanReshardWithCollectivePermute(sharding(), target))
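
To make the tile-assignment bookkeeping above concrete, here is a standalone trace of the tiled-to-partial-replicate direction with the shardings from the comment (devices=[2,3] to devices=[1,2,3] last_tile_dim_replicate). Plain vectors stand in for XLA's tile-assignment arrays, and the asserts encode the function's early-outs; this is a sketch of the dimension logic, not the real code path.

  // Trace of the dimension bookkeeping for resharding devices=[2,3] ->
  // devices=[1,2,3] last_tile_dim_replicate. Plain vectors stand in for
  // XLA's tile-assignment arrays; asserts mirror the absl::nullopt early-outs.
  #include <cassert>
  #include <vector>

  int main() {
    std::vector<int> tile = {2, 3};        // tile_sharding dims (tiled side)
    std::vector<int> partial = {1, 2, 3};  // partial_replicate_sharding dims
    const int num_replicas = partial.back();  // 3

    // Shape preconditions: exactly one extra dim, and a leading dim of 1.
    assert(tile.size() + 1 == partial.size());
    assert(partial[0] == 1);

    // Find the single tile dim whose count equals the replication factor,
    // scanning from the minor-most dim; every dim must also match the
    // partial-replicate dims shifted by one.
    int to_replicate_dim = -1;
    for (int i = static_cast<int>(tile.size()) - 1; i >= 0; --i) {
      if (tile[i] > 1 && to_replicate_dim == -1) {
        assert(tile[i] == num_replicas);  // otherwise: absl::nullopt
        to_replicate_dim = i;             // finds dim 1 (count 3)
      }
      assert(tile[i] == partial[i + 1]);  // 3==3 at i=1, 2==2 at i=0
    }
    assert(to_replicate_dim == 1);

    // Intermediate sharding: clear the found dim, append the replica dim.
    std::vector<int> tmp = tile;  // {2, 3}
    tmp[to_replicate_dim] = 1;    // {2, 1}
    tmp.push_back(num_replicas);  // {2, 1, 3} + last_tile_dim_replicate
    assert((tmp == std::vector<int>{2, 1, 3}));

    // The reshard then runs in two steps: replicate dim 1 to reach
    // devices=[2,1,3] last_tile_dim_replicate, then an all-to-all between
    // dims 0 and 1 to reach the devices=[1,2,3] target.
    return 0;
  }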
@@ -338,6 +338,10 @@ class PartitionedHlo {
   absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
       const HloSharding& target);
 
+  // Helper function to reshard from partial replicate using AllToAll.
+  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
+      const HloSharding& target);
+
   // SPMD instruction.
   HloInstruction* hlo_;
 
@@ -5542,6 +5542,55 @@ ENTRY entry {
   EXPECT_THAT(root, partially_replicated);
 }
 
+TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
+  auto partially_replicated = AllOf(
+      op::Shape("f32[8,4]"),
+      op::Copy(op::Reshape(
+          op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
+              op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
+  auto tiled = AllOf(
+      op::Shape("f32[4,3]"),
+      op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
+                                            op::Reshape(partial_replicated)))),
+                                        _),
+                                _, _)));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
 TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
   const char* const hlo_string = R"(
 HloModule module
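
Reading the expected patterns in these tests: the AllReduce over a DynamicUpdateSlice into a Broadcast is the all-gather-style replication of the dimension being moved (with a Slice trimming the padding from the uneven 8/3 split), the Reshape/Transpose chain around the AllToAll is the reshard itself, and in the reverse direction the Pad plus DynamicSlice re-tiles the unevenly divisible dimension on each device.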