From 47917266f3115b6c92088985df33f52fc7cb018b Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu <yuanzx@google.com>
Date: Tue, 9 Feb 2021 14:17:13 -0800
Subject: [PATCH] [XLA:SPMD] Use per-mesh-dimension allreduce/allgather

PiperOrigin-RevId: 356589391
Change-Id: I0649b6e9395f11a02ba1890102991e7566443b4f
---
 .../compiler/xla/service/spmd/dot_handler.cc  | 57 +++++++---------
 .../service/spmd/gather_scatter_handler.cc    | 34 +++++-----
 .../xla/service/spmd/spmd_partitioner.cc      | 65 +++++++++++++------
 .../xla/service/spmd/spmd_partitioner.h       | 12 +++-
 .../xla/service/spmd/spmd_partitioner_test.cc | 22 +++----
 5 files changed, 106 insertions(+), 84 deletions(-)

diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index c9cae573563..2c586400da7 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -92,17 +92,6 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
 
 namespace {
 
-std::vector<int64> GetAllDevicesInOrder(const HloSharding& sharding) {
-  CHECK(!sharding.IsTileMaximal());
-  std::vector<int64> results;
-  results.reserve(sharding.tile_assignment().num_elements());
-  sharding.tile_assignment().Each(
-      [&](absl::Span<const int64> /* indices */, int64 device) {
-        results.push_back(device);
-      });
-  return results;
-}
-
 StatusOr<HloInstruction*> PartitionBaseCase(
     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
@@ -1018,11 +1007,15 @@ StatusOr<HloInstruction*> PartitionBaseCase(
     }
     TF_ASSIGN_OR_RETURN(
         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
-    auto ar =
-        lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
-            b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-            {GetAllDevicesInOrder(lhs.sharding())},
-            (*lhs.state().next_channel_id)++);
+    std::vector<int64> lhs_contracting_dims;
+    lhs_contracting_dims.reserve(lhs.base_shape().rank());
+    for (const auto& cd : dims_mapping.contracting_dims) {
+      lhs_contracting_dims.push_back(cd.lhs);
+    }
+    auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
+        b, dot, lhs.sharding(), lhs.state().next_channel_id,
+        lhs_contracting_dims, lhs.state().collective_ops_creator,
+        MakeBinaryAdd(output_base_shape.element_type(), module));
     ar->set_sharding(HloSharding::Replicate());
     return PartitionedHlo(ar, output_base_shape, lhs.state())
         .Reshard(output_sharding)
@@ -1123,10 +1116,16 @@ StatusOr<HloInstruction*> PartitionBaseCase(
     }
     TF_ASSIGN_OR_RETURN(
        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
-    return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
-        b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-        {GetAllDevicesInOrder(lhs.sharding())},
-        (*lhs.state().next_channel_id)++);
+
+    std::vector<int64> lhs_contracting_dims;
+    lhs_contracting_dims.reserve(lhs.base_shape().rank());
+    for (const auto& cd : dims_mapping.contracting_dims) {
+      lhs_contracting_dims.push_back(cd.lhs);
+    }
+    return lhs.state().partitioner->AllReduceAlongShardingDims(
+        b, dot, lhs.sharding(), lhs.state().next_channel_id,
+        lhs_contracting_dims, lhs.state().collective_ops_creator,
+        MakeBinaryAdd(output_base_shape.element_type(), module));
   }
   return nullptr;
 }
@@ -1679,20 +1678,10 @@ StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
   if (!dot) {
     return nullptr;
   }
-  std::vector<int64> other_lhs_dims;
-  for (int64 i = 0; i < lhs_sharding.tile_assignment().num_dimensions(); ++i) {
-    if (!absl::c_linear_search(lhs_dims, i)) {
-      other_lhs_dims.push_back(i);
-    }
-  }
-  auto inverse_grouped = GroupShardingOnDims(lhs_sharding, other_lhs_dims);
-  auto ar =
-      CreatePerGroupPartitioningState(lhs.state(),
-                                      inverse_grouped.device_groups, b)
-          .collective_ops_creator.create_cross_partition_all_reduce(
-              b, dot, MakeBinaryAdd(output_base_shape.element_type(), module),
-              {GetAllDevicesInOrder(inverse_grouped.sharding)},
-              (*lhs.state().next_channel_id)++);
+  auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
+      b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
+      lhs.state().collective_ops_creator,
+      MakeBinaryAdd(output_base_shape.element_type(), module));
   ar->set_sharding(outer_output_tmp_sharding);
   return PartitionedHlo(ar, output_base_shape, lhs.state())
       .Reshard(output_sharding)
diff --git a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
index ca1b0385ad1..013d1555774 100644
--- a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/algorithm/container.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
@@ -285,17 +286,15 @@ StatusOr<HloInstruction*> ParititonTrivialIndexedOperandDimension(
     replicated_dim.push_back(
         operand.sharding().tile_assignment().num_dimensions() - 1);
   }
-  auto sharding_grouped =
-      GroupShardingOnDims(operand.sharding(), replicated_dim);
-  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-      operand.state(), sharding_grouped.device_groups, b);
-  auto collective_ops_creator =
-      per_group_partitioner_state.collective_ops_creator;
-  auto ar = collective_ops_creator.create_cross_partition_all_reduce(
-      b, filtered,
+  // All-reduce along all dims in operand sharding -- this is OK because the
+  // operand is sharded only on trivially sliced dimensions.
+  std::vector<int64> all_dims(operand.base_shape().rank());
+  absl::c_iota(all_dims, 0);
+  auto ar = operand.state().partitioner->AllReduceAlongShardingDims(
+      b, filtered, operand.sharding(), operand.state().next_channel_id,
+      all_dims, operand.state().collective_ops_creator,
       MakeBinaryAdd(filtered->shape().element_type(),
-                    per_group_partitioner_state.module),
-      {}, visitor->NewChannel());
+                    operand.state().module));
   VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand "
              "batch_dim slice";
   ar->set_sharding(HloSharding::Replicate());
@@ -574,8 +573,7 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
                                                  update_dim_to_index_dim);
     CHECK(new_updates_sharding.has_value());
     updates = updates.Reshard(*new_updates_sharding);
-    // Update collective_ops_creator and partition_id for partial replicate.
-    auto collective_ops_creator = collective_ops_creator_;
+    // Update partition_id for partial replicate.
     auto partition_id = partition_id_;
     if (indices.sharding().ReplicateOnLastTileDim()) {
       auto sharding_grouped = GroupShardingOnDims(
@@ -583,8 +581,6 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
           {indices.sharding().tile_assignment().num_dimensions() - 1});
       auto per_group_partitioner_state = CreatePerGroupPartitioningState(
           indices.state(), sharding_grouped.device_groups, &b_);
-      collective_ops_creator =
-          per_group_partitioner_state.collective_ops_creator;
       partition_id = per_group_partitioner_state.partition_id;
     }
     // To avoid accumulating the initial operand multiple times during
@@ -600,9 +596,13 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
                                        identity, operand.Replicate().hlo()));
     auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands(
         scatter->shape(), {select_operand, indices.hlo(), updates.hlo()}));
-    auto all_reduce =
-        collective_ops_creator.create_cross_partition_all_reduce(
-            &b_, pscatter, scatter->to_apply(), {}, NewChannel());
+    // All-reduce along all dims in operand sharding -- this is OK because the
+    // operand is not sharded on index_vector_dim.
+    std::vector<int64> all_dims(indices.base_shape().rank());
+    absl::c_iota(all_dims, 0);
+    auto all_reduce = operand.state().partitioner->AllReduceAlongShardingDims(
+        &b_, pscatter, indices.sharding(), indices.state().next_channel_id,
+        all_dims, collective_ops_creator_, scatter->to_apply());
     all_reduce->set_sharding(HloSharding::Replicate());
     SetPartitionedHlo(hlo, [&]() {
       return PartitionedHlo(all_reduce, hlo->shape(), MakePartitioningState())
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
index f2286bd1eb5..3ed7d8e1aea 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -877,7 +877,7 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
   HloInstruction* result = nullptr;
   if (state_.collective_ops_creator.create_cross_partition_all_gather) {
     result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
-                                                 NewChannel(), dims,
+                                                 state_.next_channel_id, dims,
                                                  state_.collective_ops_creator);
   }
   if (result == nullptr) {
@@ -892,12 +892,9 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
         padded_target_shape, zero_bcast, hlo_, offsets));
     HloComputation* reduction =
         MakeBinaryAdd(shard_shape.element_type(), state_.module);
-
-    auto all_reduce =
-        state_.collective_ops_creator.create_cross_partition_all_reduce(
-            state_.b, dus, reduction,
-            GetPartitionGroupsForReplication(sharding(), dims), NewChannel());
-    result = all_reduce;
+    result = state_.partitioner->AllReduceAlongShardingDims(
+        state_.b, dus, sharding(), state_.next_channel_id, dims,
+        state_.collective_ops_creator, reduction);
   }
   if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
     std::vector<int64> start_indices(target_shape.rank(), 0);
@@ -2765,14 +2762,15 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
     if (inputs[0].sharding().ReplicateOnLastTileDim()) {
       preserved_dims.push_back(inputs[0].base_shape().rank());
     }
-    auto grouped = GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
-    auto grouped_state = CreatePerGroupPartitioningState(
-        inputs[0].state(), grouped.device_groups, &b_);
     if (local_reduce->shape().IsArray()) {
-      reduce = grouped_state.collective_ops_creator
-                   .create_cross_partition_all_reduce(
-                       &b_, local_reduce, hlo->to_apply(), {}, NewChannel());
+      reduce = partitioner_->AllReduceAlongShardingDims(
+          &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
+          hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
     } else {
+      auto grouped =
+          GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
+      auto grouped_state = CreatePerGroupPartitioningState(
+          inputs[0].state(), grouped.device_groups, &b_);
       std::vector<HloInstruction*> all_gathered_partial_results(input_count);
       for (int64 i = 0; i < input_count; ++i) {
         auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -3500,8 +3498,11 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
 
 HloInstruction* SpmdPartitioner::AllGatherShards(
     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
-    int64 channel_id, absl::Span<const int64> selected_dims,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator) {
+  if (selected_dims.empty()) {
+    return operand;
+  }
   CHECK(!sharding.IsTileMaximal());
   // Add one leading dimension to gather all partitions.
   std::vector<int64> shape;
@@ -3511,12 +3512,18 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   }
   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
-  auto partition_subgroups =
-      GetPartitionGroupsForReplication(sharding, selected_dims);
-  shape[0] = partition_subgroups[0].size();
-  auto result = collectives_creator.create_cross_partition_all_gather(
-      b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
-      partition_subgroups, channel_id, /*all_gather_dimension=*/0);
+  HloInstruction* result = reshape;
+  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+    if (sharding.tile_assignment().dim(*it) == 1) {
+      continue;
+    }
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, {*it});
+    shape[0] *= partition_subgroups[0].size();
+    result = collectives_creator.create_cross_partition_all_gather(
+        b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
+        partition_subgroups, (*next_channel_id)++, /*all_gather_dimension=*/0);
+  }
   // If n > 1 dimensions are partitioned, split the leading dimension to n.
   std::vector<int64> tiled_dims;
   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
@@ -3567,6 +3574,24 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   return result;
 }
 
+HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator,
+    HloComputation* reduction) {
+  auto result = operand;
+  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+    if (sharding.tile_assignment().dim(*it) == 1) {
+      continue;
+    }
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, {*it});
+    result = collectives_creator.create_cross_partition_all_reduce(
+        b, result, reduction, partition_subgroups, (*next_channel_id)++);
+  }
+  return result;
+}
+
 StatusOr<bool> SpmdPartitioner::PartitionComputation(
     HloComputation* computation, const HloSharding& root_sharding,
     int64* next_channel_id, SpmdLogger* logger) {
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
index c666269dc2e..5c9001a8e35 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -197,7 +197,7 @@ class SpmdPartitioner : public HloModulePass {
                                       int64* next_channel_id,
                                       SpmdLogger* logger);
 
-  // Creates all-gather based on HloSharding. Can be overridden to customize.
+  // Creates all-gather(s) based on HloSharding. Can be overridden to customize.
   // The default uses a single all-gather even if there are multiple sharded
   // dimensions, and adds potential reshapes and transposes to achieve that.
   // If it returns false, the partitioner will fall back to all-reduce.
@@ -206,9 +206,17 @@ class SpmdPartitioner : public HloModulePass {
   //     all-gather.
   virtual HloInstruction* AllGatherShards(
       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
-      int64 channel_id, absl::Span<const int64> selected_dims,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
       const SPMDCollectiveOpsCreator& collectives_creator);
 
+  // Creates all-reduce(s) across devices along selected_dims in sharding. Can
+  // be overridden to customize.
+  virtual HloInstruction* AllReduceAlongShardingDims(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator,
+      HloComputation* reduction);
+
   const SpmdPartitionerOptions& options() { return options_; }
 
  protected:
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 908303a65d6..f44da51ab33 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3705,7 +3705,7 @@ ENTRY entry {
       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
                                 op::Constant(), op::Reshape(), op::Reshape())),
       op::Shape("f32[32,39296,32,64]"));
-  EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)),
+  EXPECT_THAT(root, AllOf(op::AllReduce(op::AllReduce(op::Dot(lhs, rhs))),
                           op::Shape("f32[32,24,39296]")));
 }
 
@@ -4662,10 +4662,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
 
@@ -4698,10 +4698,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::Reshape())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
 
@@ -4732,10 +4732,10 @@ ENTRY entry {
   HloInstruction* root = module->entry_computation()->root_instruction();
   EXPECT_THAT(
       root,
-      AllOf(op::AllReduce(op::Scatter(
+      AllOf(op::AllReduce(op::AllReduce(op::Scatter(
                 op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                            op::Broadcast(op::Constant()), op::Parameter(0)),
-                op::Parameter(1), op::Parameter(2))),
+                op::Parameter(1), op::Parameter(2)))),
             op::Shape("f32[2,9,8]")));
 }
 
@@ -6733,8 +6733,8 @@ ENTRY %module {
   auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
   auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
   auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices));
-  EXPECT_THAT(root,
-              op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
+  EXPECT_THAT(root, op::AllReduce(op::AllReduce(
+                        op::DynamicUpdateSlice(_, gather, _, _, _, _))));
 }
 
 TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
@@ -6887,8 +6887,8 @@ ENTRY %module {
   auto operand = AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute());
   auto indices = AllOf(op::Shape("s32[2,4,1]"), op::Subtract());
   auto gather = AllOf(op::Shape("s32[4,1,2,2]"), op::Gather(operand, indices));
-  EXPECT_THAT(root,
-              op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
+  EXPECT_THAT(root, op::AllReduce(op::AllReduce(
+                        op::DynamicUpdateSlice(_, gather, _, _, _, _))));
 }
 
 TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
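-- 
Editorial note (not part of the patch): the change above replaces one big
all-reduce/all-gather over every device in the tile assignment with a chain of
smaller collectives, one per selected mesh dimension, each using the groups
from GetPartitionGroupsForReplication(sharding, {dim}). The standalone C++
sketch below illustrates how those per-dimension groups fall out of a
row-major device mesh; GroupsAlongMeshDim and mesh_dims are hypothetical names
for illustration only, not XLA APIs.

#include <cstdint>
#include <iostream>
#include <vector>

// For a row-major device mesh (devices numbered like entries of
// HloSharding::tile_assignment()), return one replication group per
// combination of the other mesh coordinates; devices inside a group differ
// only in their coordinate along `dim`.
std::vector<std::vector<int64_t>> GroupsAlongMeshDim(
    const std::vector<int64_t>& mesh_dims, int64_t dim) {
  const int64_t ndims = static_cast<int64_t>(mesh_dims.size());
  int64_t inner = 1;  // row-major stride of `dim`
  for (int64_t i = dim + 1; i < ndims; ++i) inner *= mesh_dims[i];
  int64_t outer = 1;  // number of distinct coordinate prefixes before `dim`
  for (int64_t i = 0; i < dim; ++i) outer *= mesh_dims[i];
  std::vector<std::vector<int64_t>> groups;
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t in = 0; in < inner; ++in) {
      std::vector<int64_t> group;
      for (int64_t k = 0; k < mesh_dims[dim]; ++k) {
        group.push_back((o * mesh_dims[dim] + k) * inner + in);
      }
      groups.push_back(group);
    }
  }
  return groups;
}

int main() {
  // 2x4 mesh: dim 0 yields four groups of two ({0,4} {1,5} {2,6} {3,7});
  // dim 1 yields two groups of four ({0,1,2,3} {4,5,6,7}). Chaining one
  // all-reduce per mesh dimension still reduces across all eight devices,
  // but each step spans a single mesh dimension, mirroring the loop over
  // selected_dims in AllReduceAlongShardingDims.
  for (int64_t dim : {0, 1}) {
    for (const auto& group : GroupsAlongMeshDim({2, 4}, dim)) {
      for (int64_t device : group) std::cout << device << ' ';
      std::cout << '\n';
    }
  }
  return 0;
}

The same per-dimension grouping drives the new AllGatherShards loop, which is
why the updated tests expect nested op::AllReduce matchers on meshes that are
sharded along more than one dimension.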