From 0b5cc6f1b98fa6f5a3bc413cf30a87e4b3f1af8c Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Fri, 24 Jul 2020 00:38:06 -0700 Subject: [PATCH] Fix build and resubmit: [XLA:SPMD] Recursively handling more Dot cases PiperOrigin-RevId: 322949994 Change-Id: I44a8a8e958a7ba4995a667d139f793dfa3a4fe7f --- tensorflow/compiler/xla/service/spmd/BUILD | 1 + .../xla/service/spmd/convolution_handler.cc | 4 +- .../compiler/xla/service/spmd/dot_handler.cc | 717 +++++++++++++----- .../xla/service/spmd/spmd_partitioner.cc | 170 +++-- .../xla/service/spmd/spmd_partitioner.h | 57 +- .../xla/service/spmd/spmd_partitioner_test.cc | 152 +++- .../xla/service/spmd/spmd_partitioner_util.cc | 266 ++++++- .../xla/service/spmd/spmd_partitioner_util.h | 52 +- 8 files changed, 1121 insertions(+), 298 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index e41b89f6dff..a67e4cf55c5 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/core/platform:numbers", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc index 1204df59080..4caa2bbbf35 100644 --- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -226,7 +226,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( hlo->batch_group_count(), new_window, hlo->convolution_dimension_numbers(), hlo->precision_config())); auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {}, NewChannel()); ar->set_sharding(HloSharding::Replicate()); return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) @@ -605,7 +605,7 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { hlo->batch_group_count(), new_window, dnums, hlo->precision_config())); auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {}, NewChannel()); ar->set_sharding(HloSharding::Replicate()); return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index 9ecf21f5841..8fea788b1b7 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -80,12 +80,25 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { return HandleDotHelper(hlo, mapping, create_sharded_dot); } -Status SpmdPartitioningVisitor::HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, +namespace { + +StatusOr PartitionBaseCase( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { - const HloSharding& 
lhs_sharding = hlo->operand(0)->sharding(); - const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions, + int64 rhs_batch_partitions, int64 output_batch_partitions, + int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, + int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + int64 output_lhs_non_contracting_partitions, + int64 output_rhs_non_contracting_partitions, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + const HloSharding& lhs_sharding = lhs.sharding(); + const HloSharding& rhs_sharding = rhs.sharding(); // Similar to hlo_sharding_util::TransposeSharding(), but allows // removing/adding non-partitioned dimensions. @@ -132,12 +145,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper( return HloSharding::Tile(reshape_tiles); }; - std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); - std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); - std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); - std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); - std::vector output_to_lhs_indices(hlo->shape().rank(), -1); - std::vector output_to_rhs_indices(hlo->shape().rank(), -1); + std::vector lhs_to_rhs_indices(lhs.base_shape().rank(), -1); + std::vector lhs_to_output_indices(lhs.base_shape().rank(), -1); + std::vector rhs_to_lhs_indices(rhs.base_shape().rank(), -1); + std::vector rhs_to_output_indices(rhs.base_shape().rank(), -1); + std::vector output_to_lhs_indices(output_base_shape.rank(), -1); + std::vector output_to_rhs_indices(output_base_shape.rank(), -1); auto populate_indices_mapping = [&](const DotGeneralDimsMapping::DimsMapping& mapping) { if (mapping.lhs >= 0) { @@ -174,127 +187,84 @@ Status SpmdPartitioningVisitor::HandleDotHelper( auto rhs_sharding_transposed_to_match_output = transpose_sharding( rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); auto output_sharding_transposed_to_match_lhs = transpose_sharding( - hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); + output_sharding, output_to_lhs_indices, lhs_to_output_indices); auto output_sharding_transposed_to_match_rhs = transpose_sharding( - hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); + output_sharding, output_to_rhs_indices, rhs_to_output_indices); - // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
- auto get_partitions_for_dims = - [&](const HloSharding& sharding, - absl::Span dims, - int lhs_rhs_or_output) { - int64 partitions = 1; - if (sharding.IsTileMaximal()) { - return partitions; - } - for (const auto& dim : dims) { - if (lhs_rhs_or_output == 0) { - partitions *= sharding.tile_assignment().dim(dim.lhs); - } else if (lhs_rhs_or_output == 1) { - partitions *= sharding.tile_assignment().dim(dim.rhs); - } else { - CHECK_EQ(lhs_rhs_or_output, 2); - partitions *= sharding.tile_assignment().dim(dim.output); - } - } - return partitions; - }; - const int64 lhs_batch_partitions = - get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); - const int64 rhs_batch_partitions = - get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); - const int64 output_batch_partitions = - get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); - const int64 lhs_contracting_partitions = - get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); - const int64 rhs_contracting_partitions = - get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); - const int64 lhs_non_contracting_partitions = get_partitions_for_dims( - lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); - const int64 rhs_non_contracting_partitions = get_partitions_for_dims( - rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); - const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( - hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); - const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( - hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); - - auto& lhs = GetPartitionedHlo(hlo->operand(0)); - auto& rhs = GetPartitionedHlo(hlo->operand(1)); // LHS and RHS are partitioned the same way and only partitioned in batch // dimensions. if (lhs_batch_partitions == rhs_batch_partitions && - rhs_batch_partitions == num_partitions_ && + rhs_batch_partitions == num_partitions && lhs_sharding_transposed_to_match_rhs == rhs_sharding) { - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - dot->set_sharding(*lhs_sharding_transposed_to_match_output); - return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); } // Try emit batch-partitioned einsum with one operand resharded. Returns - // whether the attempt succeeds. If may_reshard_with_allreduce is false, - // reshard must be done using all-to-all; otherwise this attempt fails. + // partitioned HLO or nullptr if the attempt fails. If + // may_reshard_with_allreduce is false, reshard must be done using + // all-to-all/collective-permute; otherwise this attempt fails. auto try_emit_output_batch_partitioned_einsum_with_reshard = - [&](bool may_reshard_with_allreduce) -> StatusOr { + [&](bool may_reshard_with_allreduce) -> StatusOr { // LHS and output are batch partitioned in the same way. 
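    // In that case only the RHS needs to move: reshard it to the sharding
    // implied by the LHS batch partitioning and emit the dot directly; since
    // batch dimensions are never contracted, no cross-partition reduction is
    // needed afterwards.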
- if (lhs_batch_partitions == num_partitions_ && - output_batch_partitions == num_partitions_ && - lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (lhs_batch_partitions == num_partitions && + output_batch_partitions == num_partitions && + lhs_sharding_transposed_to_match_output == output_sharding) { if (!may_reshard_with_allreduce && + !CanReshardWithCollectivePermute( + rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) && !GetReshardAllToAllSourceTargetDims( rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) { - return false; + return nullptr; } auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return true; + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b)); + return dot; } // RHS and output are batch partitioned in the same way. - if (rhs_batch_partitions == num_partitions_ && - output_batch_partitions == num_partitions_ && - rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (rhs_batch_partitions == num_partitions && + output_batch_partitions == num_partitions && + rhs_sharding_transposed_to_match_output == output_sharding) { if (!may_reshard_with_allreduce && + !CanReshardWithCollectivePermute( + lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) && !GetReshardAllToAllSourceTargetDims( lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) { - return false; + return nullptr; } auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return true; + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b)); + return dot; } - return false; + return nullptr; }; { // Try batch-parallel by resharding one operand, and not using all-reduce. TF_ASSIGN_OR_RETURN( - bool emitted, + HloInstruction * partitioned_dot, try_emit_output_batch_partitioned_einsum_with_reshard(false)); - if (emitted) { - return Status::OK(); + if (partitioned_dot) { + return partitioned_dot; } } // Try to emit windowed DotGeneral when one operand is partitioned in the same // way as the output along non-contracting dimensions, but the other operand // is tiled in other dimensions. - auto emit_windowed_dot_general = [&](int64 matching_operand, - int64 windowing_operand, - bool windowed_at_contracting_dims, - bool windowed_at_batch_dims) { + auto emit_windowed_dot_general = + [&](int64 matching_operand, int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) -> StatusOr { CHECK_EQ(matching_operand + windowing_operand, 1); CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); auto unpadded_result_buffer_shape = - MakePartitionedShape(hlo->shape(), hlo->sharding()); + MakePartitionedShape(output_base_shape, output_sharding); auto padded_result_buffer_shape = unpadded_result_buffer_shape; // For windowing at batch/non-contracting dims, we produce the result one // partition at a time, so we need to pad the shape in case of uneven @@ -310,17 +280,17 @@ Status SpmdPartitioningVisitor::HandleDotHelper( if (windowed_at_contracting_dims) { auto& to_mask = windowing_operand == 0 ? 
lhs : rhs; to_mask = - to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type())))); + to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type())))); } - auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); - auto iteration = b_.AddInstruction( + auto result_buffer = CreateZero(padded_result_buffer_shape, b); + auto iteration = b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); // Create a while loop that computes one window per iteration. During each // iteration, each partition sends its input window to its neighbor using // collective-permute for the next iteration. - SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); + SpmdBuilder body_b("windowed_dot_general_body", original_hlo); auto param = body_b.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), @@ -335,11 +305,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper( auto i = body_b.AddInstruction( HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); - auto partition_id = collective_ops_creator_.create_partition_id(&body_b); + auto partition_id = + lhs.state().collective_ops_creator.create_partition_id(&body_b); auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( i->shape(), HloOpcode::kAdd, i, partition_id)); auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))); + LiteralUtil::CreateR0(num_partitions))); data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); auto dot_lhs = l; @@ -350,7 +321,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper( // operand as replicated, and resharding it to match the windowed operand. auto slice_operand = matching_operand == 0 ? l : r; slice_operand->set_sharding(HloSharding::Replicate()); - auto state = MakePartitioningState(); + auto state = lhs.state(); state.b = &body_b; state.partition_id = data_partition_id; auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) @@ -392,26 +363,27 @@ Status SpmdPartitioningVisitor::HandleDotHelper( auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( ShapeUtil::MakeShape(PRED, {}), i, body_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))), + LiteralUtil::CreateR0(num_partitions))), ComparisonDirection::kLt)); // Collective-permute for the next window. We don't need it for the last // iteration, so we use a conditional around the collective-permute. HloInstruction* conditional; { - SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); + SpmdBuilder cp_b("window_collective_permute", original_hlo); { auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); - std::vector> sd_pairs(num_partitions_); - for (int64 source = 0; source < num_partitions_; ++source) { + std::vector> sd_pairs(num_partitions); + for (int64 source = 0; source < num_partitions; ++source) { // 0 -> n-1, 1 -> 0, 2 -> 1, ... 
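        // Each partition sends its current window to its lower neighbor, so on
        // iteration i partition p holds the shard that originally lived on
        // partition (p + i) % num_partitions, matching the data_partition_id
        // computed above.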
sd_pairs[source] = {source, - (source - 1 + num_partitions_) % num_partitions_}; + (source - 1 + num_partitions) % num_partitions}; } - collective_ops_creator_.create_cross_partition_collective_permute( - &cp_b, p, sd_pairs, (*next_channel_id_)++); + lhs.state() + .collective_ops_creator.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++); } - SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); + SpmdBuilder ncp_b("last_iteration_noop", original_hlo); { ncp_b.AddInstruction(HloInstruction::CreateParameter( 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); @@ -419,9 +391,9 @@ Status SpmdPartitioningVisitor::HandleDotHelper( conditional = body_b.AddInstruction(HloInstruction::CreateConditional( windowing_operand == 0 ? l->shape() : r->shape(), has_more, windowing_operand == 0 ? l : r, - module_->AddEmbeddedComputation(cp_b.Build()), + module->AddEmbeddedComputation(cp_b.Build()), windowing_operand == 0 ? l : r, - module_->AddEmbeddedComputation(ncp_b.Build()))); + module->AddEmbeddedComputation(ncp_b.Build()))); } if (windowing_operand == 0) { l = conditional; @@ -430,7 +402,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper( } body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); - SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); + SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo); auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), @@ -441,56 +413,53 @@ Status SpmdPartitioningVisitor::HandleDotHelper( cond_b.AddInstruction(HloInstruction::CreateCompare( ShapeUtil::MakeShape(PRED, {}), cond_i, cond_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))), + LiteralUtil::CreateR0(num_partitions))), ComparisonDirection::kLt)); - auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( - cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), - module_->AddEmbeddedComputation(body_b.Build()), - b_.AddInstruction(HloInstruction::CreateTuple( + auto while_loop = b->AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()), + module->AddEmbeddedComputation(body_b.Build()), + b->AddInstruction(HloInstruction::CreateTuple( {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); - windowed_dot_general_loops_.push_back({while_loop, windowing_operand, + windowed_dot_general_loops->push_back({while_loop, windowing_operand, windowed_at_contracting_dims, windowed_at_batch_dims}); - SetPartitionedHlo(hlo, [&] { - auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( - result_buffer->shape(), while_loop, 2)); - if (!ShapeUtil::Compatible(padded_result_buffer_shape, - unpadded_result_buffer_shape)) { - result = b_.AddInstruction(HloInstruction::CreateSlice( - unpadded_result_buffer_shape, result, - std::vector(padded_result_buffer_shape.rank(), 0), - unpadded_result_buffer_shape.dimensions(), - std::vector(padded_result_buffer_shape.rank(), 1))); - } - return result; - }); - return Status::OK(); + auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b->AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + 
unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; }; - if (output_lhs_non_contracting_partitions == num_partitions_ && + if (output_lhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_lhs == lhs_sharding && - ShapeSizeInBytes(hlo->operand(1)->shape()) >= - options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { - if (rhs_contracting_partitions == num_partitions_) { + ShapeSizeInBytes(rhs.base_shape()) >= + threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, true, false); } - if (rhs_non_contracting_partitions == num_partitions_) { + if (rhs_non_contracting_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, false, false); } - if (rhs_batch_partitions == num_partitions_) { + if (rhs_batch_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, false, true); } } - if (output_rhs_non_contracting_partitions == num_partitions_ && + if (output_rhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_rhs == rhs_sharding && - ShapeSizeInBytes(hlo->operand(0)->shape()) >= - options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { - if (lhs_contracting_partitions == num_partitions_) { + ShapeSizeInBytes(lhs.base_shape()) >= + threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, true, false); } - if (lhs_non_contracting_partitions == num_partitions_) { + if (lhs_non_contracting_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, false, false); } - if (lhs_batch_partitions == num_partitions_) { + if (lhs_batch_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, false, true); } } @@ -498,18 +467,18 @@ Status SpmdPartitioningVisitor::HandleDotHelper( { // Try batch-parallel by resharding one operand, and allowing all-reduce. TF_ASSIGN_OR_RETURN( - bool emitted, + HloInstruction * partitioned_dot, try_emit_output_batch_partitioned_einsum_with_reshard(true)); - if (emitted) { - return Status::OK(); + if (partitioned_dot) { + return partitioned_dot; } } // LHS and RHS have the same partitioned contracting dimensions. if (lhs_contracting_partitions == rhs_contracting_partitions && - lhs_contracting_partitions == num_partitions_) { - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); + lhs_contracting_partitions == num_partitions) { + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); // Pad both sides with zero, since NaN at one side cannot be masked by zero // on the other side. 
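    // (0 * NaN = NaN, so zero-padding only one operand would still poison the
    // partial sums that are later combined by the all-reduce below.)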
if (ShapeSizeInBytes(lhs.base_shape()) < @@ -522,100 +491,91 @@ Status SpmdPartitioningVisitor::HandleDotHelper( rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + auto ar = + lhs.state().collective_ops_creator.create_cross_partition_all_reduce( + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + (*lhs.state().next_channel_id)++); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); } // LHS and output have the same partitioned non-contracting dimensions. - if (lhs_non_contracting_partitions == num_partitions_ && - output_lhs_non_contracting_partitions == num_partitions_ && - lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (lhs_non_contracting_partitions == num_partitions && + output_lhs_non_contracting_partitions == num_partitions && + lhs_sharding_transposed_to_match_output == output_sharding) { auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); + create_sharded_dot(lhs.hlo(), rhs_replicated, b)); + return dot; } // RHS and output have the same partitioned non-contracting dimensions. - if (rhs_non_contracting_partitions == num_partitions_ && - output_rhs_non_contracting_partitions == num_partitions_ && - rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (rhs_non_contracting_partitions == num_partitions && + output_rhs_non_contracting_partitions == num_partitions && + rhs_sharding_transposed_to_match_output == output_sharding) { auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); + create_sharded_dot(lhs_replicated, rhs.hlo(), b)); + return dot; } // Output is batch partitioned. - if (output_batch_partitions == num_partitions_) { + if (output_batch_partitions == num_partitions) { auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); + resharded_rhs.hlo(), b)); + return dot; } // Output is partitioned along LHS non-contracting dimensions. 
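  // Reshard the LHS to line up with the output and fully replicate the RHS;
  // every partition can then compute its output tile locally without any
  // cross-partition reduction.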
- if (output_lhs_non_contracting_partitions == num_partitions_) { + if (output_lhs_non_contracting_partitions == num_partitions) { auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); - TF_ASSIGN_OR_RETURN( - auto dot, - create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + replicated_rhs.hlo(), b)); + return dot; } // Output is partitioned along RHS non-contracting dimensions. - if (output_rhs_non_contracting_partitions == num_partitions_) { + if (output_rhs_non_contracting_partitions == num_partitions) { auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), - resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); + resharded_rhs.hlo(), b)); + return dot; } // Returns true if it is beneficial to reshard the operand at `operand_idx` // across the contracting dimension. const auto should_partition_contracting_dim = [&](int64 operand_idx) { - if (!hlo->sharding().IsReplicated()) { + if (!output_sharding.IsReplicated()) { return false; } if (operand_idx == 0) { // If LHS and output are replicated, we compare the cost of all-gather // on RHS vs all-reduce on the output. - return (rhs_contracting_partitions == num_partitions_) && + return (rhs_contracting_partitions == num_partitions) && lhs.sharding().IsReplicated() && - ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > - ShapeUtil::ElementsIn(hlo->shape()); + ShapeUtil::ElementsIn(rhs.base_shape()) > + ShapeUtil::ElementsIn(output_base_shape); } else { - return (lhs_contracting_partitions == num_partitions_) && + return (lhs_contracting_partitions == num_partitions) && rhs.sharding().IsReplicated() && - ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > - ShapeUtil::ElementsIn(hlo->shape()); + ShapeUtil::ElementsIn(lhs.base_shape()) > + ShapeUtil::ElementsIn(output_base_shape); } }; // When the output is replicated and one of the operands is partitioned along // contracting dimension, align the other operand to be partitioned along // the contracting dimensions. 
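  // Each partition then computes a partial dot over its shard of the
  // contracting dimension, and the partial results are summed with a
  // cross-partition all-reduce, which matches the replicated output sharding.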
- if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || + if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) || should_partition_contracting_dim(1))) { - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); if (should_partition_contracting_dim(0)) { lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); @@ -625,19 +585,361 @@ Status SpmdPartitioningVisitor::HandleDotHelper( rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); - }); - return Status::OK(); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + return lhs.state().collective_ops_creator.create_cross_partition_all_reduce( + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + (*lhs.state().next_channel_id)++); + } + return nullptr; +} + +StatusOr PartitionDot( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops); + +StatusOr PartitionDotGroupOnBatch( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + std::vector lhs_dims; + std::vector rhs_dims; + std::vector output_dims; + auto lhs_sharding_dims_adjusted_to_output = + lhs.sharding().tile_assignment().dimensions(); + auto rhs_sharding_dims_adjusted_to_output = + lhs.sharding().tile_assignment().dimensions(); + auto output_sharding_dims_adjusted_to_lhs = + output_sharding.tile_assignment().dimensions(); + bool lhs_rhs_dims_matching = true; + for (const auto& dim : dims_mapping.batch_dims) { + lhs_dims.push_back(dim.lhs); + rhs_dims.push_back(dim.rhs); + output_dims.push_back(dim.output); + if (lhs_sharding_dims_adjusted_to_output[dim.lhs] != + rhs_sharding_dims_adjusted_to_output[dim.rhs]) { + lhs_rhs_dims_matching = false; + } + lhs_sharding_dims_adjusted_to_output[dim.lhs] = + output_sharding.tile_assignment().dim(dim.output); + rhs_sharding_dims_adjusted_to_output[dim.rhs] = + output_sharding.tile_assignment().dim(dim.output); + output_sharding_dims_adjusted_to_lhs[dim.output] = + lhs.sharding().tile_assignment().dim(dim.lhs); + } + auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); + auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); + auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); + 
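+  // Each device group owns one shard of the batch dimensions; e.g. with the
+  // {devices=[2,2,1]0,1,2,3} sharding used in the tests below, devices {0,1}
+  // and {2,3} form two groups, each partitioning its half of the batch across
+  // two devices. If the LHS and RHS batch partitionings already match, align
+  // the groups of the smaller operand (and the output) to the larger operand;
+  // otherwise reshard both operands to follow the output's batch partitioning.
+  // The per-group dot is then partitioned recursively using
+  // num_partitions / group_count devices.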
if (lhs_rhs_dims_matching) { + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); + rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); + } else { + lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); + lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); + } + auto reshaped_output_tiling = output_sharding.tile_assignment(); + reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); + output_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), + output_dims), + lhs_grouped); + } else { + auto reshaped_lhs_tiling = lhs.sharding().tile_assignment(); + reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output); + lhs_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims), + output_grouped); + lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); + auto reshaped_rhs_tiling = rhs.sharding().tile_assignment(); + reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output); + rhs_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims), + output_grouped); + rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); + } + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + lhs.state(), lhs_grouped.device_groups, b); + lhs.hlo()->set_sharding(lhs_grouped.sharding); + rhs.hlo()->set_sharding(rhs_grouped.sharding); + CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot( + PartitionedHlo(lhs.hlo(), + GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), + per_group_partitioner_state), + PartitionedHlo(rhs.hlo(), + GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), + per_group_partitioner_state), + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / lhs_grouped.device_groups.size(), create_sharded_dot, + module, original_hlo, threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + // Reset the LHS sharding to the ungrouped one. 
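+  // (The grouped shardings set on the operands above are only meaningful
+  // inside a device group; restore the ungrouped shardings and tag the result
+  // with the ungrouped output sharding before resharding to the final target.)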
+ lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped)); + rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped)); + dot->set_sharding(UngroupSharding(output_grouped)); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +StatusOr PartitionDotGroupOnNonContracting( + bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, + int64 matching_contracting_partitions, int64 other_contracting_partitions, + int64 matching_non_contracting_partitions, + int64 other_non_contracting_partitions, + int64 output_other_non_contracting_partitions, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + const bool may_replicate_other_contracting_dims = + (other_contracting_partitions == matching_non_contracting_partitions && + other_non_contracting_partitions == + output_other_non_contracting_partitions); + const bool may_replicate_other_non_contracting_dims = + matching_non_contracting_partitions == other_non_contracting_partitions && + matching_contracting_partitions == other_contracting_partitions; + std::vector other_group_dims; + if (may_replicate_other_contracting_dims && + (!may_replicate_other_non_contracting_dims || + ShapeUtil::ByteSizeOf(other.base_shape()) <= + ShapeUtil::ByteSizeOf(output_base_shape))) { + for (const auto& dim : dims_mapping.contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else if (may_replicate_other_non_contracting_dims) { + for (const auto& dim : lhs_matching + ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else { + return nullptr; + } + auto matching_sharding_dims = + matching.sharding().tile_assignment().dimensions(); + std::vector matching_dims; + std::vector output_dims; + // Make sure the partitioning on matching's non-contracting dimensions + // defines the same device groups for both matching and output. + for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims + : dims_mapping.rhs_non_contracting_dims) { + int64 md = lhs_matching ? 
dim.lhs : dim.rhs; + matching_sharding_dims[md] = + output_sharding.tile_assignment().dim(dim.output); + matching_dims.push_back(md); + output_dims.push_back(dim.output); + } + auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); + auto reshaped_matching_tiling = matching.sharding().tile_assignment(); + reshaped_matching_tiling.Reshape(matching_sharding_dims); + auto matching_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_matching_tiling), + matching_dims), + output_grouped); + matching = matching.Reshard(UngroupSharding(matching_grouped)); + + auto other_grouped = + AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims), + output_grouped, /*ignore_group_order=*/true); + other = other.Reshard(UngroupSharding(other_grouped)); + auto partially_replicated_other = + other.ReplicatePartial(other_grouped.group_dims); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + matching.state(), matching_grouped.device_groups, b); + matching.hlo()->set_sharding(matching_grouped.sharding); + partially_replicated_other->set_sharding(other_grouped.sharding); + auto matching_p = PartitionedHlo( + matching.hlo(), + GetPerGroupBaseShape(matching_grouped, matching.base_shape()), + per_group_partitioner_state); + auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), + per_group_partitioner_state); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot(lhs_matching ? matching_p : other_p, + lhs_matching ? other_p : matching_p, + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / matching_grouped.device_groups.size(), + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + // Reset matching's sharding to the ungrouped one. + matching.hlo()->set_sharding(UngroupSharding(matching_grouped)); + return dot; +} + +// Recursive partitioning function. If there are partial dimensions matching in +// the operands and output, group the devices and recursively partition the +// in-group dot. +StatusOr PartitionDot( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
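+  // Returns the product of the tile-assignment sizes over the requested dot
+  // dimensions, i.e. how many ways that group of dimensions is partitioned
+  // (1 for a tile-maximal sharding).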
+ auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.rhs_non_contracting_dims, 2); + TF_ASSIGN_OR_RETURN( + auto try_partitioned_dot, + PartitionBaseCase( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, + lhs_contracting_partitions, rhs_contracting_partitions, + lhs_non_contracting_partitions, rhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + if (try_partitioned_dot) { + return try_partitioned_dot; } - return DefaultAction(hlo); + // Recursively partition on different types of dimensions. + // + // Case 1: Group partitions by batch. + if (lhs_batch_partitions == rhs_batch_partitions && + lhs_batch_partitions == output_batch_partitions && + lhs_batch_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnBatch( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // Case 2: Group partitions by non-contracting dimensions. + const bool may_group_on_lhs_non_contracting = + lhs_non_contracting_partitions == output_lhs_non_contracting_partitions && + lhs_non_contracting_partitions > 1; + const bool may_group_on_rhs_non_contracting = + rhs_non_contracting_partitions == output_rhs_non_contracting_partitions && + rhs_non_contracting_partitions > 1; + if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) { + // If both match output non-contracting dimensions, choose the one which + // will result in smaller replication of the other operand. 
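+    // Grouping on the LHS non-contracting dimensions forces pieces of the RHS
+    // to be replicated within each group (and vice versa), so pick the LHS as
+    // the matching side only if lhs_non_contracting_partitions * |RHS shard|
+    // is no larger than rhs_non_contracting_partitions * |LHS shard|.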
+ const bool lhs_matching = + may_group_on_lhs_non_contracting && + (!may_group_on_rhs_non_contracting || + lhs_non_contracting_partitions * + ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <= + rhs_non_contracting_partitions * + ShapeUtil::ByteSizeOf(lhs.hlo()->shape())); + + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs, + lhs_matching ? lhs_contracting_partitions + : rhs_contracting_partitions, + lhs_matching ? rhs_contracting_partitions + : lhs_contracting_partitions, + lhs_matching ? lhs_non_contracting_partitions + : rhs_non_contracting_partitions, + lhs_matching ? rhs_non_contracting_partitions + : lhs_non_contracting_partitions, + lhs_matching ? output_rhs_non_contracting_partitions + : output_lhs_non_contracting_partitions, + output_base_shape, output_sharding, dims_mapping, num_partitions, + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // Default action. + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(), + rhs.Replicate().hlo(), b)); + dot->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +} // namespace + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + TF_ASSIGN_OR_RETURN( + auto partitioned_dot, + PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping, + num_partitions_, create_sharded_dot, module_, hlo, + options_.threshold_for_windowed_einsum_mib, &b_, + &windowed_dot_general_loops_)); + SetPartitionedHlo(hlo, [&] { return partitioned_dot; }); + return Status::OK(); } namespace { @@ -780,6 +1082,7 @@ Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( [](const HloInstruction* a, const HloInstruction* b) { return a->unique_id() < b->unique_id(); }); + worklist.reserve(nullaries_to_sink.size()); for (auto inst : nullaries_to_sink) { worklist.push_back(inst); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index bac5c812814..7aaa3e32b2a 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -165,16 +165,6 @@ template namespace { -// Returns the replica group configuration where each replica belongs to its own -// group. -std::vector CreateReplicaGroups(int64 num_replicas) { - std::vector groups(num_replicas); - for (int64 i = 0; i < num_replicas; ++i) { - groups[i].add_replica_ids(i); - } - return groups; -} - // Clears all sharding attributes from instructions in the module. This must be // called only after all SPMD transformation is complete. 
Status ClearShardingAttributes(HloModule* module) { @@ -195,6 +185,28 @@ Status ClearShardingAttributes(HloModule* module) { return Status::OK(); } +std::vector> GetPartitionGroupsForReplication( + const HloSharding& sharding, absl::Span replication_dims) { + int64 group_size = 1; + for (int64 i : replication_dims) { + group_size *= sharding.tile_assignment().dim(i); + } + std::vector> partition_groups( + sharding.tile_assignment().num_elements() / group_size); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 partition) { + int64 group_id = 0; + for (int64 i = 0; i < indices.size(); ++i) { + if (!absl::c_linear_search(replication_dims, i)) { + group_id *= sharding.tile_assignment().dim(i); + group_id += indices[i]; + } + } + partition_groups[group_id].push_back(partition); + }); + return partition_groups; +} + } // namespace HloInstruction* SpmdBuilder::AddInstruction( @@ -664,42 +676,57 @@ PartitionedHlo PartitionedHlo::Replicate() { } // 'Tiled' to 'Replicated'. + std::vector all_dims(shape.rank()); + std::iota(all_dims.begin(), all_dims.end(), 0); + HloInstruction* result = ReplicatePartial(all_dims); + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span dims) { + CHECK(!sharding().IsTileMaximal()); + const Shape& shard_shape = hlo()->shape(); + Shape target_shape = shard_shape; + Shape padded_target_shape = shard_shape; + for (int64 i : dims) { + padded_target_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i)); + target_shape.set_dimensions(i, base_shape().dimensions(i)); + } + HloInstruction* result = nullptr; if (state_.collective_ops_creator.create_cross_partition_all_gather) { - result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding, - NewChannel()); - } - Shape padded_base_shape = shape; - for (int64 i = 0; i < padded_base_shape.rank(); ++i) { - padded_base_shape.set_dimensions( - i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(), + NewChannel(), dims, + state_.collective_ops_creator); } if (result == nullptr) { auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()))); + LiteralUtil::Zero(shard_shape.element_type()))); auto zero_bcast = state_.b->AddInstruction( - HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + HloInstruction::CreateBroadcast(padded_target_shape, zero, {})); + auto offsets = MakePartitionOffsets(padded_target_shape, sharding(), + state_.partition_id, state_.b, dims); auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - padded_base_shape, zero_bcast, hlo_, - MakePartitionOffsets(padded_base_shape, sharding, - state_.partition_id, state_.b))); + padded_target_shape, zero_bcast, hlo_, offsets)); HloComputation* reduction = - MakeBinaryAdd(shape.element_type(), state_.module); + MakeBinaryAdd(shard_shape.element_type(), state_.module); auto all_reduce = state_.collective_ops_creator.create_cross_partition_all_reduce( - state_.b, dus, reduction, NewChannel()); + state_.b, dus, reduction, + GetPartitionGroupsForReplication(sharding(), dims), NewChannel()); result = all_reduce; } - if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { - std::vector start_indices(shape.rank(), 0); - std::vector strides(shape.rank(), 1); - result = 
state_.b->AddInstruction(HloInstruction::CreateSlice( - base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) { + std::vector start_indices(target_shape.rank(), 0); + std::vector strides(target_shape.rank(), 1); + result = state_.b->AddInstruction( + HloInstruction::CreateSlice(target_shape, result, start_indices, + base_shape_.dimensions(), strides)); } - result->set_sharding(HloSharding::Replicate()); - return update_cache(PartitionedHlo(result, base_shape_, state_)); + return result; } PartitionedHlo PartitionedHlo::Broadcast() const { @@ -728,7 +755,7 @@ PartitionedHlo PartitionedHlo::Broadcast() const { MakeBinaryAdd(shape.element_type(), state_.module); auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( - state_.b, operand, reduction, NewChannel()); + state_.b, operand, reduction, {}, NewChannel()); result->set_sharding(HloSharding::Replicate()); return PartitionedHlo(result, base_shape_, state_); } @@ -796,7 +823,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b); // The order of ids in the group must follow the temp_target sharding. - std::vector groups( + std::vector> groups( temp_target.tile_assignment().num_elements() / group_size); temp_target.tile_assignment().Each( [&](absl::Span indices, int64 device) { @@ -810,7 +837,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( group_id += indices[dim]; } } - groups[group_id].add_replica_ids(device); + groups[group_id].push_back(device); }); HloInstruction* result = nullptr; @@ -1027,7 +1054,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { offset += operand->shape().dimensions(dimension); } auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), + &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), {}, NewChannel()); SetPartitionedHlo(hlo, [&] { auto start_indices = @@ -2153,7 +2180,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { // Combine from different partitions. auto ar = collective_ops_creator_.create_cross_partition_all_reduce( &b_, filtered, - MakeBinaryAdd(filtered->shape().element_type(), module_), + MakeBinaryAdd(filtered->shape().element_type(), module_), {}, NewChannel()); ar->set_sharding(HloSharding::Replicate()); SetPartitionedHlo(hlo, [&]() { @@ -2449,7 +2476,7 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { if (reduce_sharded_dimension) { CHECK(local_reduce->shape().IsArray()); reduce = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, local_reduce, hlo->to_apply(), NewChannel()); + &b_, local_reduce, hlo->to_apply(), {}, NewChannel()); reduce->set_sharding(HloSharding::Replicate()); } else { reduce = local_reduce; @@ -2917,13 +2944,36 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, [](SpmdBuilder* b) { return b->AddInstruction(HloInstruction::CreatePartitionId()); }, - [num_replicas](SpmdBuilder* b, HloInstruction* operand, - HloComputation* reduction, int64 channel_id) { + [num_replicas, num_partitions]( + SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, + const std::vector>& partition_subgroups, + int64 channel_id) { + if (partition_subgroups.size() <= 1) { + std::vector groups(num_replicas); + // TODO(yuanzx): Unify subgroup definition with AllToAll. 
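+          // With no partition subgroups, keep the old behavior: one group per
+          // replica, reducing across all partitions within that replica
+          // (use_global_device_ids=false). Otherwise, expand each subgroup to
+          // global device ids (replica * num_partitions + partition) and set
+          // use_global_device_ids=true below.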
+ for (int64 i = 0; i < num_replicas; ++i) { + groups[i].add_replica_ids(i); + } + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + } + + std::vector device_groups; + device_groups.reserve(partition_subgroups.size() * num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + for (const auto& pgroup : partition_subgroups) { + device_groups.emplace_back(); + for (int64 pid : pgroup) { + device_groups.back().add_replica_ids(i * num_partitions + pid); + } + } + } return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, - CreateReplicaGroups(num_replicas), + operand->shape(), {operand}, reduction, device_groups, /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); + /*use_global_device_ids=*/true)); }, [](SpmdBuilder* b, HloInstruction* operand, std::vector>& src_dst_pairs, @@ -2932,14 +2982,20 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, operand->shape(), operand, src_dst_pairs, channel_id)); }, [](SpmdBuilder* b, absl::Span operands, - const std::vector& replica_groups, int64 channel_id, - absl::optional split_dimension) { + const std::vector>& partition_subgroups, + int64 channel_id, absl::optional split_dimension) { std::vector shapes(operands.size(), operands[0]->shape()); const Shape output_shape = (shapes.size() == 1) ? shapes[0] : ShapeUtil::MakeTupleShape(shapes); + std::vector groups(partition_subgroups.size()); + for (int64 i = 0; i < groups.size(); ++i) { + for (int64 id : partition_subgroups[i]) { + groups[i].add_replica_ids(id); + } + } return b->AddInstruction(HloInstruction::CreateAllToAll( - output_shape, operands, replica_groups, + output_shape, operands, groups, /*constrain_layout=*/false, channel_id, split_dimension)); }, [num_replicas, num_partitions]( @@ -2970,10 +3026,10 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, num_partitions, num_replicas, std::move(options), GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {} -HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, - HloInstruction* operand, - const HloSharding& sharding, - int64 channel_id) { +HloInstruction* SpmdPartitioner::AllGatherShards( + SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, + int64 channel_id, absl::Span selected_dims, + const SPMDCollectiveOpsCreator& collectives_creator) { CHECK(!sharding.IsTileMaximal()); // Add one leading dimension to gather all partitions. std::vector shape; @@ -2983,18 +3039,17 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, } auto reshape = b->AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand)); - std::vector> partition_subgroups(1); - for (int64 pid : sharding.tile_assignment()) { - partition_subgroups[0].push_back(pid); - } - shape[0] = sharding.tile_assignment().num_elements(); - auto result = collective_ops_creator_.create_cross_partition_all_gather( + auto partition_subgroups = + GetPartitionGroupsForReplication(sharding, selected_dims); + shape[0] = partition_subgroups[0].size(); + auto result = collectives_creator.create_cross_partition_all_gather( b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape), partition_subgroups, channel_id, /*all_gather_dimension=*/0); // If n > 1 dimensions are partitioned, split the leading dimension to n. 
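  // The all-gather above concatenated the shards along the new leading
  // dimension; reshape that dimension into one dimension per selected tiled
  // dim, then transpose each piece back next to its original dimension before
  // merging.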
std::vector tiled_dims; for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { - if (sharding.tile_assignment().dim(i) > 1) { + if (sharding.tile_assignment().dim(i) > 1 && + absl::c_linear_search(selected_dims, i)) { tiled_dims.push_back(i); } } @@ -3016,7 +3071,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, std::vector xpose_permutation(result->shape().rank()); int64 split_dims_added = 0; for (int64 i = 0; i < xpose_permutation.size(); ++i) { - if (sharding.tile_assignment().dim(i - split_dims_added) == 1) { + if (sharding.tile_assignment().dim(i - split_dims_added) == 1 || + !absl::c_linear_search(selected_dims, i - split_dims_added)) { xpose_permutation[i] = i + tiled_dims.size() - split_dims_added; } else { xpose_permutation[i] = split_dims_added; diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index 606a7ae5f14..d844ac3af1f 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -82,8 +83,10 @@ struct SPMDCollectiveOpsCreator { std::function create_partition_id; // Function used to create a cross-partition all-reduce HLO. - std::function + std::function>& partition_subgroups, + int64 channel_id)> create_cross_partition_all_reduce; // Function used to create a cross-partition collective-permute HLO. @@ -96,8 +99,8 @@ struct SPMDCollectiveOpsCreator { // Function used to create a cross-partition all-to-all HLO. std::function operands, - const std::vector& replica_groups, int64 channel_id, - absl::optional split_dimension)> + const std::vector>& partition_subgroups, + int64 channel_id, absl::optional split_dimension)> create_cross_partition_all_to_all; // Function used to create a cross-partition all-gather HLO. This is optional: @@ -169,10 +172,13 @@ class SpmdPartitioner : public HloModulePass { // The default uses a single all-gather even if there are multiple sharded // dimensions, and adds potential reshapes and transposes to achieve that. // If it returns false, the partitioner will fall back to all-reduce. - virtual HloInstruction* AllGatherShards(SpmdBuilder* b, - HloInstruction* operand, - const HloSharding& sharding, - int64 channel_id); + // `selected_dims` specifies the dimensions along which the all-gather happens + // in the tiled sharding, which allows potentially creating a subgroup + // all-gather. + virtual HloInstruction* AllGatherShards( + SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, + int64 channel_id, absl::Span selected_dims, + const SPMDCollectiveOpsCreator& collectives_creator); protected: virtual std::unique_ptr CreateVisitor( @@ -215,7 +221,12 @@ class PartitionedHlo { std::tuple> window_reshard_cache; }; + // Use std::unordered_map for pointer stability. std::unordered_map per_hlo_cache; + // Caches for nested partitioning of grouped sharding. Each string key + // represents a unique way of grouping devices. + absl::flat_hash_map> + groupd_caches; }; struct PartitioningState { SpmdBuilder* b; @@ -270,15 +281,18 @@ class PartitionedHlo { const PartitioningState& state() const { return state_; } + // Helper function to replicate the data on all devices. 
+  // the reshard cache.
+  PartitionedHlo Replicate();
+
+  // Helper function to replicate the data for partitions along the given dims.
+  HloInstruction* ReplicatePartial(absl::Span<const int64> dims);
+
  private:
   // Same as Reshard except that it does not explicitly modify the reshard
   // cache, although it would indirectly modify by calling Replicate().
   PartitionedHlo ReshardNoCache(const HloSharding& target);
 
-  // Helper function to replicate the data on all devices. Could only modify
-  // the reshard cache.
-  PartitionedHlo Replicate();
-
   // Helper function to broadcast data from a single device to all devices.
   PartitionedHlo Broadcast() const;
 
@@ -417,6 +431,16 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
   StatusOr<bool> DoPartition(HloComputation* computation,
                              const HloSharding& root_sharding);
 
+  // Information about a loop created for windowed dot-general. Used when
+  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
+  // finishes traversing the graph.
+  struct WindowedDotGeneralLoop {
+    HloInstruction* while_loop;
+    int64 windowed_operand;
+    bool windowed_in_contracting_dims;
+    bool windowed_in_batch_dims;
+  };
+
  private:
   Status Preprocess(HloInstruction* hlo) override;
   Status Postprocess(HloInstruction* hlo) override;
@@ -445,15 +469,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
   // partitioned instruction.
   ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;
 
-  // Information about a loop created for windowed dot-general. Used when
-  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
-  // finishes traversing the graph.
-  struct WindowedDotGeneralLoop {
-    HloInstruction* while_loop;
-    int64 windowed_operand;
-    bool windowed_in_contracting_dims;
-    bool windowed_in_batch_dims;
-  };
   std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;
 
   HloInstruction* visiting_hlo_;
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 1045d1187b8..5f3fd8d53e7 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -2218,7 +2218,7 @@ ENTRY entry {
 
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
-  std::cout << module->ToString();
+  VLOG(1) << module->ToString();
   auto sort = FindInstruction(module.get(), "sort");
   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@@ -2294,7 +2294,7 @@ ENTRY entry
 
   TF_ASSERT_OK_AND_ASSIGN(auto module,
                           PartitionComputation(hlo_string, /*num_devices=*/2));
-  std::cout << module->ToString();
+  VLOG(1) << module->ToString();
   auto sort = FindInstruction(module.get(), "sort");
   EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
   EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@@ -3842,6 +3842,154 @@ ENTRY entry {
   EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
 }
 
+TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
+  %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3}
+  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
+    lhs_batch_dims={}, rhs_batch_dims={},
+    lhs_contracting_dims={1}, rhs_contracting_dims={1},
+    sharding={devices=[2,2]0,1,2,3}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string,
/*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0)); + auto partial_replicated_lhs = + AllOf(op::Shape("f32[24,12]"), + op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _))); + auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice( + _, op::CollectivePermute(rhs), _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs), + op::Shape("f32[24,16]"))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3} + %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3} + ROOT %dot = f32[48,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[32,50]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[24,16]"), + op::DynamicSlice( + op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs), + op::Shape("f32[24,32]"))), + _, _))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[2,32,100]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"), + op::Dot(lhs, partial_replicated_rhs))); +} + +TEST_F(SpmdPartitioningTest, + Dot2DPartitionedBatchNonContractingAndContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1)); + auto partial_replicated_lhs = + AllOf(op::Shape("f32[2,24,100]"), + 
op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"), + op::Dot(partial_replicated_lhs, rhs))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3} + %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3} + ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={3}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[2,8,32,100]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _))); + auto dot = + AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs)); + auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot)); + auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape)); + auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 6beed5a15e5..454a1da4646 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -16,7 +16,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" #include +#include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -143,10 +148,10 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, return partition_shape; } -std::vector MakePartitionOffsets(const Shape& shape, - const HloSharding& sharding, - HloInstruction* partition_id, - SpmdBuilder* b) { +std::vector MakePartitionOffsets( + const Shape& shape, const HloSharding& sharding, + HloInstruction* partition_id, SpmdBuilder* b, + absl::Span dims) { CHECK(!shape.IsTuple()); Array2D offset_array( @@ -158,7 +163,8 @@ std::vector MakePartitionOffsets(const Shape& shape, LiteralUtil::CreateR2FromArray2D(offset_array))); std::vector offsets; for (int64 i = 0; i < shape.rank(); ++i) { - if (sharding.tile_assignment().dim(i) == 1) { + if (sharding.tile_assignment().dim(i) == 1 || + (!dims.empty() && !absl::c_linear_search(dims, i))) { offsets.push_back(b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); } else { @@ -978,5 +984,255 @@ bool CanReshardWithCollectivePermute(const HloSharding& source, source.tile_assignment() != target.tile_assignment(); } +GroupedSharding GroupShardingOnDims(const HloSharding& sharding, + absl::Span group_dims) { + CHECK(!sharding.IsTileMaximal()); + std::vector grouped_tiling_dims = + sharding.tile_assignment().dimensions(); + std::vector group_dim_sizes(group_dims.size()); + for (int64 i = 0; i < group_dims.size(); ++i) { + group_dim_sizes[i] = grouped_tiling_dims[group_dims[i]]; + grouped_tiling_dims[group_dims[i]] = 1; + } + std::vector> device_groups(Product(group_dim_sizes)); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + int64 group_id = 0; + for (int64 dim : group_dims) { + group_id *= sharding.tile_assignment().dim(dim); + group_id += indices[dim]; + } + device_groups[group_id].push_back(device); + }); + Array grouped_tiling(grouped_tiling_dims); + grouped_tiling.FillIota(0); + return GroupedSharding( + std::move(device_groups), + std::vector(group_dims.begin(), group_dims.end()), + std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), + HloSharding::Tile(grouped_tiling)); +} + +HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { + CHECK(!grouped_sharding.sharding.IsTileMaximal()); + std::vector tiling_dims = + grouped_sharding.sharding.tile_assignment().dimensions(); + for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { + tiling_dims[grouped_sharding.group_dims[i]] = + grouped_sharding.group_dim_sizes[i]; + } + Array tiling(tiling_dims); + grouped_sharding.sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + std::vector ungrouped_inds(indices.begin(), indices.end()); + for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { + int64 remaining_group_index = g; + for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { + ungrouped_inds[grouped_sharding.group_dims[i]] = + remaining_group_index % grouped_sharding.group_dim_sizes[i]; + remaining_group_index /= grouped_sharding.group_dim_sizes[i]; + } + tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; + } + }); + return HloSharding::Tile(tiling); +} + +GroupedSharding AlignGroupsWith(GroupedSharding 
+                                const GroupedSharding& reference,
+                                bool ignore_group_order) {
+  // Returns src -> dst index mapping.
+  auto get_permutation = [](absl::Span<const int64> src,
+                            absl::Span<const int64> dst) {
+    CHECK_EQ(src.size(), dst.size());
+    absl::flat_hash_map<int64, int64> dst_reverse_map;
+    for (int64 i = 0; i < dst.size(); ++i) {
+      dst_reverse_map[dst[i]] = i;
+    }
+    std::vector<int64> permutation(src.size());
+    for (int64 i = 0; i < src.size(); ++i) {
+      auto it = dst_reverse_map.find(src[i]);
+      CHECK(it != dst_reverse_map.end());
+      permutation[i] = it->second;
+    }
+    return permutation;
+  };
+  CHECK_EQ(grouped_sharding.device_groups.size(),
+           reference.device_groups.size());
+  absl::flat_hash_map<int64, int64> device_to_ref_group;
+  for (int64 g = 0; g < reference.device_groups.size(); ++g) {
+    for (int64 device : reference.device_groups[g]) {
+      device_to_ref_group[device] = g;
+    }
+  }
+  auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
+    int64 ref_g = -1;
+    for (int64 device : devices) {
+      if (ref_g == -1) {
+        ref_g = device_to_ref_group[device];
+      } else if (ref_g != device_to_ref_group[device]) {
+        return -1;
+      }
+    }
+    return ref_g;
+  };
+  bool matching_groups = true;
+  std::vector<int64> original_src_to_ref_permutation;
+  for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
+    int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
+    if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
+      matching_groups = false;
+      break;
+    }
+    if (g == 0) {
+      original_src_to_ref_permutation = get_permutation(
+          grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
+    }
+  }
+  if (matching_groups) {
+    auto tiles = grouped_sharding.sharding.tile_assignment();
+    tiles.Each([&](absl::Span<const int64> indices, int64* device) {
+      *device = original_src_to_ref_permutation[*device];
+    });
+    grouped_sharding.sharding = HloSharding::Tile(tiles);
+  }
+  grouped_sharding.device_groups = std::move(reference.device_groups);
+  return grouped_sharding;
+}
+
+Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
+                           const Shape& original_base_shape) {
+  auto result = original_base_shape;
+  for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
+    int64 dim = grouped_sharding.group_dims[i];
+    int64 groups = grouped_sharding.group_dim_sizes[i];
+    result.set_dimensions(dim, result.dimensions(dim) / groups);
+  }
+  return result;
+}
+
+namespace {
+
+HloInstruction* GetInGroupPartitionId(
+    HloInstruction* partition_id,
+    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
+  int64 total_devices = device_groups.size() * device_groups[0].size();
+  std::vector<uint32> in_group_ids(total_devices);
+  for (uint32 i = 0; i < device_groups.size(); ++i) {
+    for (uint32 j = 0; j < device_groups[i].size(); ++j) {
+      in_group_ids[device_groups[i][j]] = j;
+    }
+  }
+  auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<uint32>(in_group_ids)));
+  return b->AddInstruction(HloInstruction::CreateReshape(
+      ShapeUtil::MakeScalarShape(U32),
+      b->AddInstruction(HloInstruction::CreateDynamicSlice(
+          ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
+}
+
+SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
+    const SPMDCollectiveOpsCreator& creator,
+    const std::vector<std::vector<int64>>& device_groups) {
+  SPMDCollectiveOpsCreator result;
+  result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
+    return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
+                                 b);
+  };
+  auto expand_partition_groups =
+      [device_groups](
+          const std::vector<std::vector<int64>>& partition_subgroups) {
+        if (partition_subgroups.empty()) {
+          return device_groups;
+        }
+        std::vector<std::vector<int64>> result(partition_subgroups.size() *
+                                               device_groups.size());
+        for (int64 g = 0; g < device_groups.size(); ++g) {
+          for (int64 i = 0; i < partition_subgroups.size(); ++i) {
+            result[g * partition_subgroups.size() + i].resize(
+                partition_subgroups[i].size());
+            for (int64 j = 0; j < partition_subgroups[i].size(); ++j) {
+              result[g * partition_subgroups.size() + i][j] =
+                  device_groups[g][partition_subgroups[i][j]];
+            }
+          }
+        }
+        return result;
+      };
+  result.create_cross_partition_all_reduce =
+      [creator, expand_partition_groups](
+          SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
+          const std::vector<std::vector<int64>>& partition_subgroups,
+          int64 channel_id) {
+        return creator.create_cross_partition_all_reduce(
+            b, operand, reduction, expand_partition_groups(partition_subgroups),
+            channel_id);
+      };
+  result.create_cross_partition_collective_permute =
+      [creator, device_groups](
+          SpmdBuilder* b, HloInstruction* operand,
+          std::vector<std::pair<int64, int64>>& src_dst_pairs,
+          int64 next_channel_id) {
+        std::vector<std::pair<int64, int64>> expanded_pairs(
+            src_dst_pairs.size() * device_groups.size());
+        for (int64 g = 0; g < device_groups.size(); ++g) {
+          for (int64 i = 0; i < src_dst_pairs.size(); ++i) {
+            expanded_pairs[g * src_dst_pairs.size() + i] =
+                std::pair<int64, int64>{
+                    device_groups[g][src_dst_pairs[i].first],
+                    device_groups[g][src_dst_pairs[i].second]};
+          }
+        }
+        return creator.create_cross_partition_collective_permute(
+            b, operand, expanded_pairs, next_channel_id);
+      };
+  result.create_cross_partition_all_to_all =
+      [creator, expand_partition_groups](
+          SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
+          const std::vector<std::vector<int64>>& partition_subgroups,
+          int64 channel_id, absl::optional<int64> split_dimension) {
+        return creator.create_cross_partition_all_to_all(
+            b, operands, expand_partition_groups(partition_subgroups),
+            channel_id, split_dimension);
+      };
+  if (creator.create_cross_partition_all_gather) {
+    result.create_cross_partition_all_gather =
+        [creator, expand_partition_groups](
+            SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
+            const std::vector<std::vector<int64>>& partition_subgroups,
+            int64 channel_id, int64 all_gather_dimension) {
+          return creator.create_cross_partition_all_gather(
+              b, operand, ag_shape,
+              expand_partition_groups(partition_subgroups), channel_id,
+              all_gather_dimension);
+        };
+  }
+  return result;
+}
+
+}  // namespace
+
+PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
+    const PartitionedHlo::PartitioningState& state,
+    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
+  auto result = state;
+  result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
+      state.collective_ops_creator, device_groups);
+  result.partition_id =
+      GetInGroupPartitionId(state.partition_id, device_groups, b);
+  // Create a string key for the groups.
+  std::vector<std::string> per_group_strings(device_groups.size());
+  for (int64 i = 0; i < per_group_strings.size(); ++i) {
+    per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
+  }
+  auto& grouped_cache =
+      state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
+  if (!grouped_cache) {
+    grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
+  }
+  result.reshard_cache = grouped_cache.get();
+  return result;
+}
+
 }  // namespace spmd
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
index 7b737daf78c..6e68375f9b9 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -87,10 +87,12 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
 
 // Generates the HLO instructions that represent the dimension offsets on any
 // device. The size of the returned vector is the rank of the given shape.
-std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
-                                                  const HloSharding& sharding,
-                                                  HloInstruction* partition_id,
-                                                  SpmdBuilder* b);
+// If `dims` is non-empty, the generated offsets will only be non-zero for
+// those dimensions.
+std::vector<HloInstruction*> MakePartitionOffsets(
+    const Shape& shape, const HloSharding& sharding,
+    HloInstruction* partition_id, SpmdBuilder* b,
+    absl::Span<const int64> dims = {});
 
 // Returns the offsets of the partition in the tile assignment.
 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
@@ -276,6 +278,48 @@ GetReshardAllToAllSourceTargetDims(const HloSharding& source,
 bool CanReshardWithCollectivePermute(const HloSharding& source,
                                      const HloSharding& target);
 
+// Represents grouping devices in a tiled sharding along certain dimensions.
+// Elements in group dimensions define different device groups, and the
+// sharding represents the in-group sharding.
+struct GroupedSharding {
+  GroupedSharding(std::vector<std::vector<int64>> device_groups,
+                  std::vector<int64> group_dims,
+                  std::vector<int64> group_dim_sizes, int64 rank,
+                  HloSharding grouped_sharding)
+      : device_groups(std::move(device_groups)),
+        group_dims(std::move(group_dims)),
+        group_dim_sizes(std::move(group_dim_sizes)),
+        sharding(std::move(grouped_sharding)) {}
+  std::vector<std::vector<int64>> device_groups;
+  std::vector<int64> group_dims;
+  std::vector<int64> group_dim_sizes;
+  int64 rank;
+  HloSharding sharding;
+};
+
+// Creates a GroupedSharding for a tiled sharding.
+GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
+                                    absl::Span<const int64> group_dims);
+
+// Reconstructs the ungrouped sharding from a GroupedSharding.
+HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);
+
+// Returns a new GroupedSharding that has the same group definition as
+// `reference`.
+GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
+                                const GroupedSharding& reference,
+                                bool ignore_group_order = false);
+
+// Returns the per-group base shape, i.e., before applying the in-group
+// sharding.
+Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
+                           const Shape& original_base_shape);
+
+// Creates the nested partitioner state for in-group partitioning.
+PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
+    const PartitionedHlo::PartitioningState& state,
+    const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);
+
 }  // namespace spmd
 }  // namespace xla
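Note on the device-grouping arithmetic used by GroupShardingOnDims and the
recursive dot partitioning above: devices that differ only in their
coordinates along the grouped tile dimensions fall into the same group, and
each group then sees a smaller per-group sharding. The sketch below is not
part of the patch; it is a self-contained C++ program that mirrors only that
grouping step, with a made-up helper name (GroupDevicesOnDims) and a plain
struct standing in for the tile assignment instead of xla::Array and
HloSharding.

// Standalone illustration of the grouping idea; standard library only.
#include <cstdint>
#include <iostream>
#include <vector>

// A tile assignment: `dims` are per-dimension tile counts and `devices`
// lists device ids in row-major order over those dimensions.
struct TileAssignment {
  std::vector<int64_t> dims;
  std::vector<int64_t> devices;
};

// Groups devices that differ only in the `group_dims` coordinates, mirroring
// the group_id computation in GroupShardingOnDims (the grouped coordinates
// are mixed together in order to form the group index).
std::vector<std::vector<int64_t>> GroupDevicesOnDims(
    const TileAssignment& tiles, const std::vector<int64_t>& group_dims) {
  int64_t num_groups = 1;
  for (int64_t dim : group_dims) num_groups *= tiles.dims[dim];
  std::vector<std::vector<int64_t>> groups(num_groups);
  const int64_t rank = tiles.dims.size();
  const int64_t num_devices = static_cast<int64_t>(tiles.devices.size());
  for (int64_t linear = 0; linear < num_devices; ++linear) {
    // Recover the multi-dimensional index of this device (row-major order).
    std::vector<int64_t> index(rank);
    int64_t rest = linear;
    for (int64_t d = rank - 1; d >= 0; --d) {
      index[d] = rest % tiles.dims[d];
      rest /= tiles.dims[d];
    }
    int64_t group_id = 0;
    for (int64_t dim : group_dims) {
      group_id = group_id * tiles.dims[dim] + index[dim];
    }
    groups[group_id].push_back(tiles.devices[linear]);
  }
  return groups;
}

int main() {
  // A 2x2 tile assignment like {devices=[2,2]0,1,2,3} in the new tests.
  // Grouping on dimension 0 yields groups {0,1} and {2,3}; within each group
  // the remaining sharding is 1x2, which is what the partitioner recurses on.
  TileAssignment tiles{{2, 2}, {0, 1, 2, 3}};
  for (const auto& group : GroupDevicesOnDims(tiles, {0})) {
    for (int64_t device : group) std::cout << device << " ";
    std::cout << "\n";
  }
  return 0;
}

Under these assumptions, the printed groups correspond to the per-group
shardings that the new dot tests above exercise when one operand dimension is
grouped and the in-group sharding is partitioned recursively.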