From 629a241b5506cd270cf8c97615bb4b96521dbfda Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 17 Dec 2020 14:12:56 -0800
Subject: [PATCH] [XLA:SPMD] Apply all-to-all when resharding from/to partial
 replicate in some special cases.

PiperOrigin-RevId: 348091566
Change-Id: I178f5dc9fe83cb9bcf45a87998cb755691935b4f
---
 .../xla/service/spmd/spmd_partitioner.cc      | 94 +++++++++++++++++++
 .../xla/service/spmd/spmd_partitioner.h       |  4 +
 .../xla/service/spmd/spmd_partitioner_test.cc | 49 ++++++++++
 3 files changed, 147 insertions(+)

diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index 4064308aa40..b36c646c737 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -417,6 +417,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
@@ -424,6 +428,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
     if (try_reshard.has_value()) {
       return try_reshard.value();
     }
+    try_reshard = ReshardPartialReplicateWithAllToAll(target);
+    if (try_reshard.has_value()) {
+      return try_reshard.value();
+    }
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -1216,6 +1224,92 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
       .ReshardWithAllToAll(target, remaining_source_target_dims);
 }
 
+absl::optional<PartitionedHlo>
+PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
+  bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
+  const auto& partial_replicate_sharding =
+      source_is_partial_replicate ? sharding() : target;
+  // If neither the source nor the target is partial replicate, return null.
+  if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
+    return absl::nullopt;
+  }
+  const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
+  // If both source and target are partial replicate, this case is already
+  // supported by Reshard with AllToAll.
+  if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
+    return absl::nullopt;
+  }
+
+  // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
+  // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
+  // the last tile dim is replicated first, before the all-to-all;
+  // or resharding from
+  // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  // to sharding={devices=[2,3]0,1,2,3,4,5}, where
+  // the last tile dim is sharded after the all-to-all.
+  const int num_replicas =
+      partial_replicate_sharding.tile_assignment().dimensions().back();
+  if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
+       partial_replicate_sharding.tile_assignment().num_dimensions()) ||
+      (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
+    return absl::nullopt;
+  }
+  int to_replicate_dim = -1;
+  for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
+       --i) {
+    if (tile_sharding.tile_assignment().dim(i) > 1 &&
+        (to_replicate_dim == -1)) {
+      if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
+        return absl::nullopt;
+      }
+      to_replicate_dim = i;
+    }
+
+    if (tile_sharding.tile_assignment().dim(i) !=
+        partial_replicate_sharding.tile_assignment().dim(i + 1)) {
+      return absl::nullopt;
+    }
+  }
+
+  if (to_replicate_dim == -1) {
+    return absl::nullopt;
+  }
+
+  // Check if the core assignments for the source and the target are the same.
+  auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
+  reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
+  if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
+    return absl::nullopt;
+  }
+
+  auto tmp_tile_assignment = tile_sharding.tile_assignment();
+  auto tmp_tile_assignment_dimensions =
+      tile_sharding.tile_assignment().dimensions();
+  tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
+  tmp_tile_assignment_dimensions.push_back(num_replicas);
+  tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
+  auto tmp_partial_replicate_sharding =
+      HloSharding::PartialTile(tmp_tile_assignment);
+
+  if (source_is_partial_replicate) {
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            sharding(), tmp_partial_replicate_sharding)) {
+      auto partitioned_hlo =
+          ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
+      return partitioned_hlo.Reshard(target);
+    }
+  } else {
+    auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
+
+    if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
+            partitioned_hlo.sharding(), target)) {
+      return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
+    }
+  }
+
+  return absl::nullopt;
+}
+
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
     const HloSharding& target) const {
   CHECK(CanReshardWithCollectivePermute(sharding(), target))
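
Note: the dimension-matching loop above is the heart of the special case: the
trailing replication factor of the partial-replicate sharding must equal the
last tile dimension of size greater than one, and every other tile dimension
must line up one position over. The following standalone sketch mirrors that
loop using plain std::vector in place of XLA's tile-assignment API;
FindToReplicateDim is a hypothetical name, for illustration only.

  #include <cstdint>
  #include <iostream>
  #include <vector>

  // Returns the tile dimension whose sharding can be exchanged with the
  // trailing replication factor of `partial_dims`, or -1 if the shapes do
  // not match the special case handled by the patch.
  int FindToReplicateDim(const std::vector<int64_t>& tile_dims,
                         const std::vector<int64_t>& partial_dims) {
    // `partial_dims` must have exactly one extra (trailing) dimension and a
    // leading dimension of 1, e.g. tile [2,3] vs. partial [1,2,3].
    if (tile_dims.size() + 1 != partial_dims.size() || partial_dims[0] != 1) {
      return -1;
    }
    const int64_t num_replicas = partial_dims.back();
    int to_replicate_dim = -1;
    for (int i = static_cast<int>(tile_dims.size()) - 1; i >= 0; --i) {
      // The last dimension of size > 1 must match the replication factor.
      if (tile_dims[i] > 1 && to_replicate_dim == -1) {
        if (tile_dims[i] != num_replicas) return -1;
        to_replicate_dim = i;
      }
      // Every tile dimension must line up with the partial sharding,
      // shifted by the leading 1.
      if (tile_dims[i] != partial_dims[i + 1]) return -1;
    }
    return to_replicate_dim;
  }

  int main() {
    // {devices=[2,3]} vs. {devices=[1,2,3] last_tile_dim_replicate}:
    // dimension 1 (size 3) matches the replication factor 3.
    std::cout << FindToReplicateDim({2, 3}, {1, 2, 3}) << "\n";  // prints 1
    // [2,4] vs. [1,2,3]: 4 != 3, so the special case does not apply.
    std::cout << FindToReplicateDim({2, 4}, {1, 2, 3}) << "\n";  // prints -1
    return 0;
  }

When the check fails, ReshardNoCache simply falls through to the remaining
strategies, including the replicate-then-reshard fallback mentioned in the
first hunk's context.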
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index d5a2efd9fc0..d77fd7eb2d3 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -338,6 +338,10 @@ class PartitionedHlo {
   absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
       const HloSharding& target);
 
+  // Helper function to reshard from partial replicate using AllToAll.
+  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
+      const HloSharding& target);
+
   // SPMD instruction.
   HloInstruction* hlo_;
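
Note: tracing the helper declared above by hand on the shapes used in the
tests below may help. Resharding f32[8,8] from
{devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate} to
{devices=[2,3]0,1,2,3,4,5}, the helper picks to_replicate_dim = 1 (the last
tile dimension of size > 1, which equals the replication factor 3), builds the
intermediate sharding {devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate},
uses an all-to-all to move the dim-1 shards onto dim 0, and then reshards to
the target, which shards dim 1 across the three formerly replicated devices
via dynamic-slice. The tile-to-partial direction runs the same decomposition
in reverse: it first reshards to the intermediate partial-replicate sharding
(the AllReduce/DynamicUpdateSlice all-gather pattern in the tests below) and
then applies the all-to-all.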
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index cac7694de65..52bd709525e 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -5542,6 +5542,55 @@ ENTRY entry {
   EXPECT_THAT(root, partially_replicated);
 }
 
+TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
+  auto partially_replicated = AllOf(
+      op::Shape("f32[8,4]"),
+      op::Copy(op::Reshape(
+          op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
+              op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, partially_replicated);
+}
+
+TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8] parameter(0),
+    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
+  ROOT %copy0 = f32[8,8] copy(%param0),
+    sharding={devices=[2,3]0,1,2,3,4,5}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/6));
+  VLOG(1) << module->ToString();
+  auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
+  auto tiled = AllOf(
+      op::Shape("f32[4,3]"),
+      op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
+                                            op::Reshape(partial_replicated)))),
+                                        _),
+                                _, _)));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, tiled);
+}
+
 TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
   const char* const hlo_string = R"(
 HloModule module
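
Note: the shard shapes the two new tests expect follow from ceiling division
with padding; the op::Pad/op::DynamicSlice pair in the second test exists
because 8 does not divide evenly by 3. A minimal sketch of that arithmetic
follows (the ShardSize helper is hypothetical, not part of the patch):

  #include <cstdint>
  #include <iostream>

  // Ceiling division: the per-shard extent when `dim_size` elements are
  // split across `num_shards` devices, padding the last shard if needed.
  int64_t ShardSize(int64_t dim_size, int64_t num_shards) {
    return (dim_size + num_shards - 1) / num_shards;
  }

  int main() {
    // {devices=[2,3]}: f32[8,8] -> per-device f32[4,3] (8/2 = 4,
    // ceil(8/3) = 3; the 3-way dimension is padded from 8 up to 9).
    std::cout << ShardSize(8, 2) << "," << ShardSize(8, 3) << "\n";  // 4,3
    // {devices=[1,2,3] last_tile_dim_replicate}: f32[8,8] -> f32[8,4].
    std::cout << ShardSize(8, 1) << "," << ShardSize(8, 2) << "\n";  // 8,4
    return 0;
  }

This matches the op::Shape("f32[4,3]") and op::Shape("f32[8,4]") expectations
in both tests.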