[XLA:SPMD] Apply all-to-all when resharding from/to partial replicate in
some special cases.

PiperOrigin-RevId: 348091566
Change-Id: I178f5dc9fe83cb9bcf45a87998cb755691935b4f
Author: A. Unique TensorFlower
Date: 2020-12-17 14:12:56 -08:00
Committed by: TensorFlower Gardener
Parent: 8a894275a4
Commit: 629a241b55
3 changed files with 147 additions and 0 deletions
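For context, the special case this change targets is resharding between a plainly tiled sharding and a partially replicated one over the same device order. Below is a minimal sketch, not part of this commit, that constructs the two shardings involved; it assumes the xla::Array and xla::HloSharding::Tile/PartialTile APIs and the tensorflow/compiler/xla include paths of this era, and the device ids match the tests added in this change.

#include <iostream>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

int main() {
  // Tiled sharding over a 2x3 device grid: {devices=[2,3]0,1,2,3,4,5}.
  xla::Array<xla::int64> tiled_assignment({{0, 1, 2}, {3, 4, 5}});
  const xla::HloSharding tiled = xla::HloSharding::Tile(tiled_assignment);

  // Partially replicated sharding whose last tile dimension (size 3) is the
  // replication dimension:
  // {devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}.
  xla::Array<xla::int64> partial_assignment({{{0, 1, 2}, {3, 4, 5}}});
  const xla::HloSharding partially_replicated =
      xla::HloSharding::PartialTile(partial_assignment);

  // The new ReshardPartialReplicateWithAllToAll path covers resharding
  // between these two shardings in either direction.
  std::cout << tiled.ToString() << "\n"
            << partially_replicated.ToString() << "\n";
  return 0;
}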

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

@@ -417,6 +417,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
if (try_reshard.has_value()) {
return try_reshard.value();
}
try_reshard = ReshardPartialReplicateWithAllToAll(target);
if (try_reshard.has_value()) {
return try_reshard.value();
}
}
if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
@@ -424,6 +428,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
if (try_reshard.has_value()) {
return try_reshard.value();
}
try_reshard = ReshardPartialReplicateWithAllToAll(target);
if (try_reshard.has_value()) {
return try_reshard.value();
}
}
// If not replicated yet, first replicate and then reshard to use one of the
@@ -1216,6 +1224,92 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
.ReshardWithAllToAll(target, remaining_source_target_dims);
}

absl::optional<PartitionedHlo>
PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
const auto& partial_replicate_sharding =
source_is_partial_replicate ? sharding() : target;
// If neither the source nor the target is partial replicate, return null.
if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
return absl::nullopt;
}
const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
// If both source and target are partially replicated, this case should
// already be handled by Reshard with AllToAll. Also bail if the other
// sharding is tile maximal.
if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
return absl::nullopt;
}
// Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
// to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
// the last tile dim is replicated first, before the all-to-all,
// or resharding from
// sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
// to sharding={devices=[2,3]0,1,2,3,4,5}, where
// the last tile dim is sharded after the all-to-all.
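// Illustrative summary of the two directions handled below: both go through
// an intermediate partially replicated sharding derived from the tiled
// sharding by shrinking one tiled dim to size 1 and moving its partitions to
// the last (replication) dim. When the source is partially replicated, the
// all-to-all to that intermediate sharding runs first and the result is then
// resharded to the tiled target; when the source is tiled, the input is first
// resharded to the intermediate sharding and the all-to-all runs last.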
const int num_replicas =
partial_replicate_sharding.tile_assignment().dimensions().back();
if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
partial_replicate_sharding.tile_assignment().num_dimensions()) ||
(partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
return absl::nullopt;
}
int to_replicate_dim = -1;
for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
--i) {
if (tile_sharding.tile_assignment().dim(i) > 1 &&
(to_replicate_dim == -1)) {
if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
return absl::nullopt;
}
to_replicate_dim = i;
}
if (tile_sharding.tile_assignment().dim(i) !=
partial_replicate_sharding.tile_assignment().dim(i + 1)) {
return absl::nullopt;
}
}
if (to_replicate_dim == -1) {
return absl::nullopt;
}
// Check if core assignments for source and the target are the same.
auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
return absl::nullopt;
}
auto tmp_tile_assignment = tile_sharding.tile_assignment();
auto tmp_tile_assignment_dimensions =
tile_sharding.tile_assignment().dimensions();
tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
tmp_tile_assignment_dimensions.push_back(num_replicas);
tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
auto tmp_partial_replicate_sharding =
HloSharding::PartialTile(tmp_tile_assignment);
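// For the example in the comment above, tmp_tile_assignment has dimensions
// [2,1,3] with devices 0,1,2,3,4,5 in the same linear order, so
// tmp_partial_replicate_sharding is
// {devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}.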
if (source_is_partial_replicate) {
if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
sharding(), tmp_partial_replicate_sharding)) {
auto partitioned_hlo =
ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
return partitioned_hlo.Reshard(target);
}
} else {
auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
partitioned_hlo.sharding(), target)) {
return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
}
}
return absl::nullopt;
}

PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
const HloSharding& target) const {
CHECK(CanReshardWithCollectivePermute(sharding(), target))

tensorflow/compiler/xla/service/spmd/spmd_partitioner.h

@@ -338,6 +338,10 @@ class PartitionedHlo {
absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
const HloSharding& target);
// Helper function to reshard from or to partial replicate using AllToAll.
absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
const HloSharding& target);
// SPMD instruction.
HloInstruction* hlo_;

tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc

@@ -5542,6 +5542,55 @@ ENTRY entry {
EXPECT_THAT(root, partially_replicated);
}

TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[8,8] parameter(0),
sharding={devices=[2,3]0,1,2,3,4,5}
ROOT %copy0 = f32[8,8] copy(%param0),
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/6));
VLOG(1) << module->ToString();
auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
auto partially_replicated = AllOf(
op::Shape("f32[8,4]"),
op::Copy(op::Reshape(
op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, partially_replicated);
}

TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%param0 = f32[8,8] parameter(0),
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%param0),
sharding={devices=[2,3]0,1,2,3,4,5}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/6));
VLOG(1) << module->ToString();
auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
auto tiled = AllOf(
op::Shape("f32[4,3]"),
op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
op::Reshape(partial_replicated)))),
_),
_, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, tiled);
}

TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
const char* const hlo_string = R"(
HloModule module