Fix build and resubmit: [XLA:SPMD] Recursively handling more Dot cases
PiperOrigin-RevId: 322949994
Change-Id: I44a8a8e958a7ba4995a667d139f793dfa3a4fe7f
parent 8d4711c52d
commit 0b5cc6f1b9
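For orientation: the change below splits the old monolithic HandleDotHelper into a recursive PartitionDot that tries a sequence of strategies (base case, group-on-batch, group-on-non-contracting), each returning nullptr when it does not apply. A minimal, self-contained C++ sketch of that control flow (illustrative only; PartitionDotSketch and the Hlo alias are invented here and are not code from this commit):

#include <functional>
#include <iostream>
#include <vector>

using Hlo = const char*;  // stand-in for HloInstruction*

Hlo PartitionDotSketch(const std::vector<std::function<Hlo()>>& strategies) {
  for (const auto& strategy : strategies) {
    if (Hlo result = strategy()) return result;  // first applicable strategy wins
  }
  return "replicated-dot";  // default action: replicate both operands
}

int main() {
  auto no = [] { return static_cast<Hlo>(nullptr); };
  auto yes = [] { return static_cast<Hlo>("batch-grouped-dot"); };
  // Base case fails, batch grouping succeeds, later strategies never run.
  std::cout << PartitionDotSketch({no, yes, no}) << "\n";
}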
@@ -50,6 +50,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:tuple_simplifier",
         "//tensorflow/core/platform:numbers",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -226,7 +226,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
                           hlo->batch_group_count(), new_window,
                           hlo->convolution_dimension_numbers(), hlo->precision_config()));
   auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
-      &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
+      &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
       NewChannel());
   ar->set_sharding(HloSharding::Replicate());
   return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
@@ -605,7 +605,7 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
                             hlo->batch_group_count(), new_window, dnums,
                             hlo->precision_config()));
     auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
-        &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
+        &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
         NewChannel());
     ar->set_sharding(HloSharding::Replicate());
     return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
@@ -80,12 +80,25 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
   return HandleDotHelper(hlo, mapping, create_sharded_dot);
 }

-Status SpmdPartitioningVisitor::HandleDotHelper(
-    HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
+namespace {
+
+StatusOr<HloInstruction*> PartitionBaseCase(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding,
+    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
     const std::function<StatusOr<HloInstruction*>(
-        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
-  const HloSharding& lhs_sharding = hlo->operand(0)->sharding();
-  const HloSharding& rhs_sharding = hlo->operand(1)->sharding();
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
+    HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
+    int64 rhs_batch_partitions, int64 output_batch_partitions,
+    int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
+    int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
+    int64 output_lhs_non_contracting_partitions,
+    int64 output_rhs_non_contracting_partitions,
+    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+        windowed_dot_general_loops) {
+  const HloSharding& lhs_sharding = lhs.sharding();
+  const HloSharding& rhs_sharding = rhs.sharding();

   // Similar to hlo_sharding_util::TransposeSharding(), but allows
   // removing/adding non-partitioned dimensions.
@@ -132,12 +145,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     return HloSharding::Tile(reshape_tiles);
   };

-  std::vector<int64> lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1);
-  std::vector<int64> lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1);
-  std::vector<int64> rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1);
-  std::vector<int64> rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1);
-  std::vector<int64> output_to_lhs_indices(hlo->shape().rank(), -1);
-  std::vector<int64> output_to_rhs_indices(hlo->shape().rank(), -1);
+  std::vector<int64> lhs_to_rhs_indices(lhs.base_shape().rank(), -1);
+  std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
+  std::vector<int64> rhs_to_lhs_indices(rhs.base_shape().rank(), -1);
+  std::vector<int64> rhs_to_output_indices(rhs.base_shape().rank(), -1);
+  std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
+  std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
   auto populate_indices_mapping =
       [&](const DotGeneralDimsMapping::DimsMapping& mapping) {
         if (mapping.lhs >= 0) {
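For reference, a self-contained illustration (not commit code) of what these index-mapping vectors hold, for a dot lhs[b, m, k] * rhs[b, k, n] -> out[b, m, n]; -1 marks a dimension with no counterpart (e.g. contracting dims have no output dimension):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> lhs_to_output = {0, 1, -1};  // b->0, m->1, k->none
  std::vector<int> rhs_to_output = {0, -1, 2};  // b->0, k->none, n->2
  std::vector<int> lhs_to_rhs = {0, -1, 1};     // b->b, m->none, k->k
  for (int i = 0; i < 3; ++i)
    std::cout << "lhs dim " << i << " -> output dim " << lhs_to_output[i] << "\n";
}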
@@ -174,127 +187,84 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
   auto rhs_sharding_transposed_to_match_output = transpose_sharding(
       rhs_sharding, rhs_to_output_indices, output_to_rhs_indices);
   auto output_sharding_transposed_to_match_lhs = transpose_sharding(
-      hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices);
+      output_sharding, output_to_lhs_indices, lhs_to_output_indices);
   auto output_sharding_transposed_to_match_rhs = transpose_sharding(
-      hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices);
+      output_sharding, output_to_rhs_indices, rhs_to_output_indices);

-  // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
-  auto get_partitions_for_dims =
-      [&](const HloSharding& sharding,
-          absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
-          int lhs_rhs_or_output) {
-        int64 partitions = 1;
-        if (sharding.IsTileMaximal()) {
-          return partitions;
-        }
-        for (const auto& dim : dims) {
-          if (lhs_rhs_or_output == 0) {
-            partitions *= sharding.tile_assignment().dim(dim.lhs);
-          } else if (lhs_rhs_or_output == 1) {
-            partitions *= sharding.tile_assignment().dim(dim.rhs);
-          } else {
-            CHECK_EQ(lhs_rhs_or_output, 2);
-            partitions *= sharding.tile_assignment().dim(dim.output);
-          }
-        }
-        return partitions;
-      };
-  const int64 lhs_batch_partitions =
-      get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0);
-  const int64 rhs_batch_partitions =
-      get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1);
-  const int64 output_batch_partitions =
-      get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2);
-  const int64 lhs_contracting_partitions =
-      get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0);
-  const int64 rhs_contracting_partitions =
-      get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1);
-  const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
-      lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0);
-  const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
-      rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1);
-  const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
-      hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2);
-  const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
-      hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2);
-
-  auto& lhs = GetPartitionedHlo(hlo->operand(0));
-  auto& rhs = GetPartitionedHlo(hlo->operand(1));
   // LHS and RHS are partitioned the same way and only partitioned in batch
   // dimensions.
   if (lhs_batch_partitions == rhs_batch_partitions &&
-      rhs_batch_partitions == num_partitions_ &&
+      rhs_batch_partitions == num_partitions &&
       lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] {
-      dot->set_sharding(*lhs_sharding_transposed_to_match_output);
-      return PartitionedHlo(dot, hlo->shape(), MakePartitioningState())
-          .Reshard(hlo->sharding())
-          .hlo();
-    });
-    return Status::OK();
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    dot->set_sharding(*lhs_sharding_transposed_to_match_output);
+    return PartitionedHlo(dot, output_base_shape, lhs.state())
+        .Reshard(output_sharding)
+        .hlo();
   }

   // Try emit batch-partitioned einsum with one operand resharded. Returns
-  // whether the attempt succeeds. If may_reshard_with_allreduce is false,
-  // reshard must be done using all-to-all; otherwise this attempt fails.
+  // partitioned HLO or nullptr if the attempt fails. If
+  // may_reshard_with_allreduce is false, reshard must be done using
+  // all-to-all/collective-permute; otherwise this attempt fails.
   auto try_emit_output_batch_partitioned_einsum_with_reshard =
-      [&](bool may_reshard_with_allreduce) -> StatusOr<bool> {
+      [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
     // LHS and output are batch partitioned in the same way.
-    if (lhs_batch_partitions == num_partitions_ &&
-        output_batch_partitions == num_partitions_ &&
-        lhs_sharding_transposed_to_match_output == hlo->sharding()) {
+    if (lhs_batch_partitions == num_partitions &&
+        output_batch_partitions == num_partitions &&
+        lhs_sharding_transposed_to_match_output == output_sharding) {
       if (!may_reshard_with_allreduce &&
+          !CanReshardWithCollectivePermute(
+              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
           !GetReshardAllToAllSourceTargetDims(
               rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
-        return false;
+        return nullptr;
       }
       auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
       TF_ASSIGN_OR_RETURN(
-          auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_));
-      SetPartitionedHlo(hlo, [&] { return dot; });
-      return true;
+          auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
+      return dot;
     }
     // RHS and output are batch partitioned in the same way.
-    if (rhs_batch_partitions == num_partitions_ &&
-        output_batch_partitions == num_partitions_ &&
-        rhs_sharding_transposed_to_match_output == hlo->sharding()) {
+    if (rhs_batch_partitions == num_partitions &&
+        output_batch_partitions == num_partitions &&
+        rhs_sharding_transposed_to_match_output == output_sharding) {
       if (!may_reshard_with_allreduce &&
+          !CanReshardWithCollectivePermute(
+              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
          !GetReshardAllToAllSourceTargetDims(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
-        return false;
+        return nullptr;
       }
       auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
       TF_ASSIGN_OR_RETURN(
-          auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_));
-      SetPartitionedHlo(hlo, [&] { return dot; });
-      return true;
+          auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
+      return dot;
     }
-    return false;
+    return nullptr;
   };

   {
     // Try batch-parallel by resharding one operand, and not using all-reduce.
     TF_ASSIGN_OR_RETURN(
-        bool emitted,
+        HloInstruction * partitioned_dot,
         try_emit_output_batch_partitioned_einsum_with_reshard(false));
-    if (emitted) {
-      return Status::OK();
+    if (partitioned_dot) {
+      return partitioned_dot;
     }
   }

   // Try to emit windowed DotGeneral when one operand is partitioned in the same
   // way as the output along non-contracting dimensions, but the other operand
   // is tiled in other dimensions.
-  auto emit_windowed_dot_general = [&](int64 matching_operand,
-                                       int64 windowing_operand,
-                                       bool windowed_at_contracting_dims,
-                                       bool windowed_at_batch_dims) {
+  auto emit_windowed_dot_general =
+      [&](int64 matching_operand, int64 windowing_operand,
+          bool windowed_at_contracting_dims,
+          bool windowed_at_batch_dims) -> StatusOr<HloInstruction*> {
     CHECK_EQ(matching_operand + windowing_operand, 1);
     CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims);
     auto unpadded_result_buffer_shape =
-        MakePartitionedShape(hlo->shape(), hlo->sharding());
+        MakePartitionedShape(output_base_shape, output_sharding);
     auto padded_result_buffer_shape = unpadded_result_buffer_shape;
     // For windowing at batch/non-contracting dims, we produce the result one
     // partition at a time, so we need to pad the shape in case of uneven
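A standalone sketch (not commit code) of the get_partitions_for_dims logic removed above and re-added later in PartitionDot: the partition count along a set of dimensions is the product of the tile-assignment sizes of those dimensions.

#include <iostream>
#include <vector>

long PartitionsForDims(const std::vector<long>& tile_dims,
                       const std::vector<int>& dims) {
  long partitions = 1;
  for (int d : dims) partitions *= tile_dims[d];
  return partitions;
}

int main() {
  // Tile assignment of shape [2, 4, 1]: dim 0 split 2 ways, dim 1 split 4 ways.
  std::vector<long> tile_dims = {2, 4, 1};
  std::cout << PartitionsForDims(tile_dims, {0}) << "\n";     // 2 (batch dims)
  std::cout << PartitionsForDims(tile_dims, {1, 2}) << "\n";  // 4
}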
@@ -310,17 +280,17 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     if (windowed_at_contracting_dims) {
       auto& to_mask = windowing_operand == 0 ? lhs : rhs;
       to_mask =
-          to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant(
-              LiteralUtil::Zero(hlo->shape().element_type()))));
+          to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
+              LiteralUtil::Zero(output_base_shape.element_type()))));
     }
-    auto result_buffer = CreateZero(padded_result_buffer_shape, &b_);
-    auto iteration = b_.AddInstruction(
+    auto result_buffer = CreateZero(padded_result_buffer_shape, b);
+    auto iteration = b->AddInstruction(
         HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));

     // Create a while loop that computes one window per iteration. During each
     // iteration, each partition sends its input window to its neighbor using
     // collective-permute for the next iteration.
-    SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_);
+    SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
     auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
         /*parameter_number=*/0,
         ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
@@ -335,11 +305,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     auto i = body_b.AddInstruction(
         HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3));

-    auto partition_id = collective_ops_creator_.create_partition_id(&body_b);
+    auto partition_id =
+        lhs.state().collective_ops_creator.create_partition_id(&body_b);
     auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
         i->shape(), HloOpcode::kAdd, i, partition_id));
     auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::CreateR0<uint32>(num_partitions_)));
+        LiteralUtil::CreateR0<uint32>(num_partitions)));
     data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
         i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count));
     auto dot_lhs = l;
@@ -350,7 +321,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
       // operand as replicated, and resharding it to match the windowed operand.
       auto slice_operand = matching_operand == 0 ? l : r;
       slice_operand->set_sharding(HloSharding::Replicate());
-      auto state = MakePartitioningState();
+      auto state = lhs.state();
       state.b = &body_b;
       state.partition_id = data_partition_id;
       auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
@@ -392,26 +363,27 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
         ShapeUtil::MakeShape(PRED, {}), i,
         body_b.AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::CreateR0<uint32>(num_partitions_))),
+            LiteralUtil::CreateR0<uint32>(num_partitions))),
         ComparisonDirection::kLt));
     // Collective-permute for the next window. We don't need it for the last
     // iteration, so we use a conditional around the collective-permute.
     HloInstruction* conditional;
     {
-      SpmdBuilder cp_b("window_collective_permute", visiting_hlo_);
+      SpmdBuilder cp_b("window_collective_permute", original_hlo);
       {
         auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
             0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
-        std::vector<std::pair<int64, int64>> sd_pairs(num_partitions_);
-        for (int64 source = 0; source < num_partitions_; ++source) {
+        std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
+        for (int64 source = 0; source < num_partitions; ++source) {
           // 0 -> n-1, 1 -> 0, 2 -> 1, ...
           sd_pairs[source] = {source,
-                              (source - 1 + num_partitions_) % num_partitions_};
+                              (source - 1 + num_partitions) % num_partitions};
         }
-        collective_ops_creator_.create_cross_partition_collective_permute(
-            &cp_b, p, sd_pairs, (*next_channel_id_)++);
+        lhs.state()
+            .collective_ops_creator.create_cross_partition_collective_permute(
+                &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
       }
-      SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_);
+      SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
       {
         ncp_b.AddInstruction(HloInstruction::CreateParameter(
             0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
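A standalone demo (not commit code) of the collective-permute schedule built above: each partition passes its window to partition (source - 1 + n) % n, so after n - 1 steps every partition has seen every window.

#include <iostream>
#include <utility>
#include <vector>

int main() {
  const long num_partitions = 4;
  std::vector<std::pair<long, long>> sd_pairs(num_partitions);
  for (long source = 0; source < num_partitions; ++source) {
    // 0 -> 3, 1 -> 0, 2 -> 1, 3 -> 2
    sd_pairs[source] = {source, (source - 1 + num_partitions) % num_partitions};
  }
  for (const auto& p : sd_pairs)
    std::cout << p.first << " -> " << p.second << "\n";
}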
@@ -419,9 +391,9 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
       conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
           windowing_operand == 0 ? l->shape() : r->shape(), has_more,
           windowing_operand == 0 ? l : r,
-          module_->AddEmbeddedComputation(cp_b.Build()),
+          module->AddEmbeddedComputation(cp_b.Build()),
           windowing_operand == 0 ? l : r,
-          module_->AddEmbeddedComputation(ncp_b.Build())));
+          module->AddEmbeddedComputation(ncp_b.Build())));
     }
     if (windowing_operand == 0) {
       l = conditional;
@@ -430,7 +402,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     }
     body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i}));

-    SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_);
+    SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
     auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
         /*parameter_number=*/0,
         ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
@@ -441,56 +413,53 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
     cond_b.AddInstruction(HloInstruction::CreateCompare(
         ShapeUtil::MakeShape(PRED, {}), cond_i,
         cond_b.AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::CreateR0<uint32>(num_partitions_))),
+            LiteralUtil::CreateR0<uint32>(num_partitions))),
         ComparisonDirection::kLt));
-    auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile(
-        cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()),
-        module_->AddEmbeddedComputation(body_b.Build()),
-        b_.AddInstruction(HloInstruction::CreateTuple(
+    auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
+        cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
+        module->AddEmbeddedComputation(body_b.Build()),
+        b->AddInstruction(HloInstruction::CreateTuple(
             {lhs.hlo(), rhs.hlo(), result_buffer, iteration}))));
-    windowed_dot_general_loops_.push_back({while_loop, windowing_operand,
+    windowed_dot_general_loops->push_back({while_loop, windowing_operand,
                                            windowed_at_contracting_dims,
                                            windowed_at_batch_dims});
-    SetPartitionedHlo(hlo, [&] {
-      auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
-          result_buffer->shape(), while_loop, 2));
-      if (!ShapeUtil::Compatible(padded_result_buffer_shape,
-                                 unpadded_result_buffer_shape)) {
-        result = b_.AddInstruction(HloInstruction::CreateSlice(
-            unpadded_result_buffer_shape, result,
-            std::vector<int64>(padded_result_buffer_shape.rank(), 0),
-            unpadded_result_buffer_shape.dimensions(),
-            std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
-      }
-      return result;
-    });
-    return Status::OK();
+    auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
+        result_buffer->shape(), while_loop, 2));
+    if (!ShapeUtil::Compatible(padded_result_buffer_shape,
+                               unpadded_result_buffer_shape)) {
+      result = b->AddInstruction(HloInstruction::CreateSlice(
+          unpadded_result_buffer_shape, result,
+          std::vector<int64>(padded_result_buffer_shape.rank(), 0),
+          unpadded_result_buffer_shape.dimensions(),
+          std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
+    }
+    return result;
   };
-  if (output_lhs_non_contracting_partitions == num_partitions_ &&
+  if (output_lhs_non_contracting_partitions == num_partitions &&
       output_sharding_transposed_to_match_lhs == lhs_sharding &&
-      ShapeSizeInBytes(hlo->operand(1)->shape()) >=
-          options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
-    if (rhs_contracting_partitions == num_partitions_) {
+      ShapeSizeInBytes(rhs.base_shape()) >=
+          threshold_for_windowed_einsum_mib * 1024 * 1024) {
+    if (rhs_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(0, 1, true, false);
     }
-    if (rhs_non_contracting_partitions == num_partitions_) {
+    if (rhs_non_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(0, 1, false, false);
     }
-    if (rhs_batch_partitions == num_partitions_) {
+    if (rhs_batch_partitions == num_partitions) {
       return emit_windowed_dot_general(0, 1, false, true);
     }
   }
-  if (output_rhs_non_contracting_partitions == num_partitions_ &&
+  if (output_rhs_non_contracting_partitions == num_partitions &&
       output_sharding_transposed_to_match_rhs == rhs_sharding &&
-      ShapeSizeInBytes(hlo->operand(0)->shape()) >=
-          options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
-    if (lhs_contracting_partitions == num_partitions_) {
+      ShapeSizeInBytes(lhs.base_shape()) >=
+          threshold_for_windowed_einsum_mib * 1024 * 1024) {
+    if (lhs_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(1, 0, true, false);
     }
-    if (lhs_non_contracting_partitions == num_partitions_) {
+    if (lhs_non_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(1, 0, false, false);
     }
-    if (lhs_batch_partitions == num_partitions_) {
+    if (lhs_batch_partitions == num_partitions) {
       return emit_windowed_dot_general(1, 0, false, true);
     }
   }
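A standalone sketch (not commit code) of why the windowed loop's result buffer is padded and then sliced back above: when a dimension is not evenly divisible by the partition count, each iteration writes ceil(dim / n) rows and the final slice trims the padding to the true size.

#include <iostream>

int main() {
  const long dim = 10, num_partitions = 4;
  const long per_partition = (dim + num_partitions - 1) / num_partitions;  // 3
  const long padded = per_partition * num_partitions;                      // 12
  std::cout << "padded " << padded << ", sliced back to " << dim << "\n";
}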
@@ -498,18 +467,18 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
   {
     // Try batch-parallel by resharding one operand, and allowing all-reduce.
     TF_ASSIGN_OR_RETURN(
-        bool emitted,
+        HloInstruction * partitioned_dot,
         try_emit_output_batch_partitioned_einsum_with_reshard(true));
-    if (emitted) {
-      return Status::OK();
+    if (partitioned_dot) {
+      return partitioned_dot;
     }
   }

   // LHS and RHS have the same partitioned contracting dimensions.
   if (lhs_contracting_partitions == rhs_contracting_partitions &&
-      lhs_contracting_partitions == num_partitions_) {
-    auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::Zero(hlo->shape().element_type())));
+      lhs_contracting_partitions == num_partitions) {
+    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::Zero(output_base_shape.element_type())));
     // Pad both sides with zero, since NaN at one side cannot be masked by zero
     // on the other side.
     if (ShapeSizeInBytes(lhs.base_shape()) <
@@ -522,100 +491,91 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
       rhs =
           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
     }
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] {
-      auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
-          &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
-          NewChannel());
-      ar->set_sharding(HloSharding::Replicate());
-      return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
-          .Reshard(hlo->sharding())
-          .hlo();
-    });
-    return Status::OK();
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    auto ar =
+        lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
+            b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
+            (*lhs.state().next_channel_id)++);
+    ar->set_sharding(HloSharding::Replicate());
+    return PartitionedHlo(ar, output_base_shape, lhs.state())
+        .Reshard(output_sharding)
+        .hlo();
   }

   // LHS and output have the same partitioned non-contracting dimensions.
-  if (lhs_non_contracting_partitions == num_partitions_ &&
-      output_lhs_non_contracting_partitions == num_partitions_ &&
-      lhs_sharding_transposed_to_match_output == hlo->sharding()) {
+  if (lhs_non_contracting_partitions == num_partitions &&
+      output_lhs_non_contracting_partitions == num_partitions &&
+      lhs_sharding_transposed_to_match_output == output_sharding) {
     auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
     TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs.hlo(), rhs_replicated, &b_));
-    SetPartitionedHlo(hlo, [&] { return dot; });
-    return Status::OK();
+                        create_sharded_dot(lhs.hlo(), rhs_replicated, b));
+    return dot;
   }

   // RHS and output have the same partitioned non-contracting dimensions.
-  if (rhs_non_contracting_partitions == num_partitions_ &&
-      output_rhs_non_contracting_partitions == num_partitions_ &&
-      rhs_sharding_transposed_to_match_output == hlo->sharding()) {
+  if (rhs_non_contracting_partitions == num_partitions &&
+      output_rhs_non_contracting_partitions == num_partitions &&
+      rhs_sharding_transposed_to_match_output == output_sharding) {
     auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
     TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs_replicated, rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] { return dot; });
-    return Status::OK();
+                        create_sharded_dot(lhs_replicated, rhs.hlo(), b));
+    return dot;
   }

   // Output is batch partitioned.
-  if (output_batch_partitions == num_partitions_) {
+  if (output_batch_partitions == num_partitions) {
     auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
     auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
-                                                     resharded_rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] { return dot; });
-    return Status::OK();
+                                                     resharded_rhs.hlo(), b));
+    return dot;
   }
   // Output is partitioned along LHS non-contracting dimensions.
-  if (output_lhs_non_contracting_partitions == num_partitions_) {
+  if (output_lhs_non_contracting_partitions == num_partitions) {
     auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
     auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
-    TF_ASSIGN_OR_RETURN(
-        auto dot,
-        create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] { return dot; });
-    return Status::OK();
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
+                                                     replicated_rhs.hlo(), b));
+    return dot;
   }
   // Output is partitioned along RHS non-contracting dimensions.
-  if (output_rhs_non_contracting_partitions == num_partitions_) {
+  if (output_rhs_non_contracting_partitions == num_partitions) {
     auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
     auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
-                                                     resharded_rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] { return dot; });
-    return Status::OK();
+                                                     resharded_rhs.hlo(), b));
+    return dot;
   }

   // Returns true if it is beneficial to reshard the operand at `operand_idx`
   // across the contracting dimension.
   const auto should_partition_contracting_dim = [&](int64 operand_idx) {
-    if (!hlo->sharding().IsReplicated()) {
+    if (!output_sharding.IsReplicated()) {
       return false;
     }

     if (operand_idx == 0) {
       // If LHS and output are replicated, we compare the cost of all-gather
       // on RHS vs all-reduce on the output.
-      return (rhs_contracting_partitions == num_partitions_) &&
+      return (rhs_contracting_partitions == num_partitions) &&
              lhs.sharding().IsReplicated() &&
-             ShapeUtil::ElementsIn(hlo->operand(1)->shape()) >
-                 ShapeUtil::ElementsIn(hlo->shape());
+             ShapeUtil::ElementsIn(rhs.base_shape()) >
+                 ShapeUtil::ElementsIn(output_base_shape);
     } else {
-      return (lhs_contracting_partitions == num_partitions_) &&
+      return (lhs_contracting_partitions == num_partitions) &&
              rhs.sharding().IsReplicated() &&
-             ShapeUtil::ElementsIn(hlo->operand(0)->shape()) >
-                 ShapeUtil::ElementsIn(hlo->shape());
+             ShapeUtil::ElementsIn(lhs.base_shape()) >
+                 ShapeUtil::ElementsIn(output_base_shape);
     }
   };

   // When the output is replicated and one of the operands is partitioned along
   // contracting dimension, align the other operand to be partitioned along
   // the contracting dimensions.
-  if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) ||
+  if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
                                          should_partition_contracting_dim(1))) {
-    auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::Zero(hlo->shape().element_type())));
+    auto zero = b->AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::Zero(output_base_shape.element_type())));
     if (should_partition_contracting_dim(0)) {
       lhs =
           lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
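A standalone numeric check (not commit code) of the contracting-dimension cases above: when each partition holds a slice of the contracting dimension (zero-padded where uneven), the per-partition dots are partial sums and an all-reduce (here a plain sum) reproduces the full dot product.

#include <iostream>
#include <vector>

int main() {
  std::vector<double> lhs = {1, 2, 3, 4}, rhs = {5, 6, 7, 8};
  const int num_partitions = 2, shard = 2;
  double all_reduced = 0;
  for (int p = 0; p < num_partitions; ++p) {
    double partial = 0;  // local dot over this partition's shard of k
    for (int i = 0; i < shard; ++i)
      partial += lhs[p * shard + i] * rhs[p * shard + i];
    all_reduced += partial;  // stands in for the cross-partition all-reduce
  }
  std::cout << all_reduced << "\n";  // 70 == full dot product
}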
@@ -625,19 +585,361 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
       rhs =
           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
     }
-    TF_ASSIGN_OR_RETURN(auto dot,
-                        create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
-    SetPartitionedHlo(hlo, [&] {
-      auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
-          &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
-          NewChannel());
-      ar->set_sharding(HloSharding::Replicate());
-      return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo();
-    });
-    return Status::OK();
+    TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
+    return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
+        b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
+        (*lhs.state().next_channel_id)++);
   }
+  return nullptr;
+}
+
+StatusOr<HloInstruction*> PartitionDot(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding,
+    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
+    HloModule* module, HloInstruction* original_hlo,
+    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+        windowed_dot_general_loops);
+
+StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding,
+    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
+    HloModule* module, HloInstruction* original_hlo,
+    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+        windowed_dot_general_loops) {
+  std::vector<int64> lhs_dims;
+  std::vector<int64> rhs_dims;
+  std::vector<int64> output_dims;
+  auto lhs_sharding_dims_adjusted_to_output =
+      lhs.sharding().tile_assignment().dimensions();
+  auto rhs_sharding_dims_adjusted_to_output =
+      rhs.sharding().tile_assignment().dimensions();
+  auto output_sharding_dims_adjusted_to_lhs =
+      output_sharding.tile_assignment().dimensions();
+  bool lhs_rhs_dims_matching = true;
+  for (const auto& dim : dims_mapping.batch_dims) {
+    lhs_dims.push_back(dim.lhs);
+    rhs_dims.push_back(dim.rhs);
+    output_dims.push_back(dim.output);
+    if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
+        rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
+      lhs_rhs_dims_matching = false;
+    }
+    lhs_sharding_dims_adjusted_to_output[dim.lhs] =
+        output_sharding.tile_assignment().dim(dim.output);
+    rhs_sharding_dims_adjusted_to_output[dim.rhs] =
+        output_sharding.tile_assignment().dim(dim.output);
+    output_sharding_dims_adjusted_to_lhs[dim.output] =
+        lhs.sharding().tile_assignment().dim(dim.lhs);
+  }
+  auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
+  auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
+  auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
+  if (lhs_rhs_dims_matching) {
+    if (ShapeUtil::ByteSizeOf(lhs.base_shape()) >
+        ShapeUtil::ByteSizeOf(rhs.base_shape())) {
+      rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
+      rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
+    } else {
+      lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
+      lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
+    }
+    auto reshaped_output_tiling = output_sharding.tile_assignment();
+    reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
+    output_grouped = AlignGroupsWith(
+        GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling),
+                            output_dims),
+        lhs_grouped);
+  } else {
+    auto reshaped_lhs_tiling = lhs.sharding().tile_assignment();
+    reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output);
+    lhs_grouped = AlignGroupsWith(
+        GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims),
+        output_grouped);
+    lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
+    auto reshaped_rhs_tiling = rhs.sharding().tile_assignment();
+    reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output);
+    rhs_grouped = AlignGroupsWith(
+        GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims),
+        output_grouped);
+    rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
+  }
+  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+      lhs.state(), lhs_grouped.device_groups, b);
+  lhs.hlo()->set_sharding(lhs_grouped.sharding);
+  rhs.hlo()->set_sharding(rhs_grouped.sharding);
+  CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding);
+  TF_ASSIGN_OR_RETURN(
+      auto dot,
+      PartitionDot(
+          PartitionedHlo(lhs.hlo(),
+                         GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
+                         per_group_partitioner_state),
+          PartitionedHlo(rhs.hlo(),
+                         GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
+                         per_group_partitioner_state),
+          GetPerGroupBaseShape(output_grouped, output_base_shape),
+          output_grouped.sharding, dims_mapping,
+          num_partitions / lhs_grouped.device_groups.size(), create_sharded_dot,
+          module, original_hlo, threshold_for_windowed_einsum_mib, b,
+          windowed_dot_general_loops));
+  // Reset the LHS sharding to the ungrouped one.
+  lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped));
+  rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped));
+  dot->set_sharding(UngroupSharding(output_grouped));
+  return PartitionedHlo(dot, output_base_shape, lhs.state())
+      .Reshard(output_sharding)
+      .hlo();
+}
+
+StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
+    bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
+    int64 matching_contracting_partitions, int64 other_contracting_partitions,
+    int64 matching_non_contracting_partitions,
+    int64 other_non_contracting_partitions,
+    int64 output_other_non_contracting_partitions,
+    const Shape& output_base_shape, const HloSharding& output_sharding,
+    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
+    HloModule* module, HloInstruction* original_hlo,
+    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+        windowed_dot_general_loops) {
+  const bool may_replicate_other_contracting_dims =
+      (other_contracting_partitions == matching_non_contracting_partitions &&
+       other_non_contracting_partitions ==
+           output_other_non_contracting_partitions);
+  const bool may_replicate_other_non_contracting_dims =
+      matching_non_contracting_partitions == other_non_contracting_partitions &&
+      matching_contracting_partitions == other_contracting_partitions;
+  std::vector<int64> other_group_dims;
+  if (may_replicate_other_contracting_dims &&
+      (!may_replicate_other_non_contracting_dims ||
+       ShapeUtil::ByteSizeOf(other.base_shape()) <=
+           ShapeUtil::ByteSizeOf(output_base_shape))) {
+    for (const auto& dim : dims_mapping.contracting_dims) {
+      other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
+    }
+  } else if (may_replicate_other_non_contracting_dims) {
+    for (const auto& dim : lhs_matching
+                               ? dims_mapping.rhs_non_contracting_dims
+                               : dims_mapping.lhs_non_contracting_dims) {
+      other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
+    }
+  } else {
+    return nullptr;
+  }
+  auto matching_sharding_dims =
+      matching.sharding().tile_assignment().dimensions();
+  std::vector<int64> matching_dims;
+  std::vector<int64> output_dims;
+  // Make sure the partitioning on matching's non-contracting dimensions
+  // defines the same device groups for both matching and output.
+  for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims
+                                      : dims_mapping.rhs_non_contracting_dims) {
+    int64 md = lhs_matching ? dim.lhs : dim.rhs;
+    matching_sharding_dims[md] =
+        output_sharding.tile_assignment().dim(dim.output);
+    matching_dims.push_back(md);
+    output_dims.push_back(dim.output);
+  }
+  auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
+  auto reshaped_matching_tiling = matching.sharding().tile_assignment();
+  reshaped_matching_tiling.Reshape(matching_sharding_dims);
+  auto matching_grouped = AlignGroupsWith(
+      GroupShardingOnDims(HloSharding::Tile(reshaped_matching_tiling),
+                          matching_dims),
+      output_grouped);
+  matching = matching.Reshard(UngroupSharding(matching_grouped));
+
+  auto other_grouped =
+      AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims),
+                      output_grouped, /*ignore_group_order=*/true);
+  other = other.Reshard(UngroupSharding(other_grouped));
+  auto partially_replicated_other =
+      other.ReplicatePartial(other_grouped.group_dims);
+  auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+      matching.state(), matching_grouped.device_groups, b);
+  matching.hlo()->set_sharding(matching_grouped.sharding);
+  partially_replicated_other->set_sharding(other_grouped.sharding);
+  auto matching_p = PartitionedHlo(
+      matching.hlo(),
+      GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
+      per_group_partitioner_state);
+  auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
+                                per_group_partitioner_state);
+  TF_ASSIGN_OR_RETURN(
+      auto dot,
+      PartitionDot(lhs_matching ? matching_p : other_p,
+                   lhs_matching ? other_p : matching_p,
+                   GetPerGroupBaseShape(output_grouped, output_base_shape),
+                   output_grouped.sharding, dims_mapping,
+                   num_partitions / matching_grouped.device_groups.size(),
+                   create_sharded_dot, module, original_hlo,
+                   threshold_for_windowed_einsum_mib, b,
+                   windowed_dot_general_loops));
+  // Reset matching's sharding to the ungrouped one.
+  matching.hlo()->set_sharding(UngroupSharding(matching_grouped));
+  return dot;
+}
+
+// Recursive partitioning function. If there are partial dimensions matching in
+// the operands and output, group the devices and recursively partition the
+// in-group dot.
+StatusOr<HloInstruction*> PartitionDot(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding,
+    const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
+    HloModule* module, HloInstruction* original_hlo,
+    int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
+    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
+        windowed_dot_general_loops) {
+  // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
+  auto get_partitions_for_dims =
+      [&](const HloSharding& sharding,
+          absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
+          int lhs_rhs_or_output) {
+        int64 partitions = 1;
+        if (sharding.IsTileMaximal()) {
+          return partitions;
+        }
+        for (const auto& dim : dims) {
+          if (lhs_rhs_or_output == 0) {
+            partitions *= sharding.tile_assignment().dim(dim.lhs);
+          } else if (lhs_rhs_or_output == 1) {
+            partitions *= sharding.tile_assignment().dim(dim.rhs);
+          } else {
+            CHECK_EQ(lhs_rhs_or_output, 2);
+            partitions *= sharding.tile_assignment().dim(dim.output);
+          }
+        }
+        return partitions;
+      };
+  const int64 lhs_batch_partitions =
+      get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
+  const int64 rhs_batch_partitions =
+      get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
+  const int64 output_batch_partitions =
+      get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
+  const int64 lhs_contracting_partitions =
+      get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
+  const int64 rhs_contracting_partitions =
+      get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
+  const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
+      lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
+  const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
+      rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
+  const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
+      output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
+  const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
+      output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
+  TF_ASSIGN_OR_RETURN(
+      auto try_partitioned_dot,
+      PartitionBaseCase(
+          lhs, rhs, output_base_shape, output_sharding, dims_mapping,
+          num_partitions, create_sharded_dot, module, original_hlo,
+          lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
+          lhs_contracting_partitions, rhs_contracting_partitions,
+          lhs_non_contracting_partitions, rhs_non_contracting_partitions,
+          output_lhs_non_contracting_partitions,
+          output_rhs_non_contracting_partitions,
+          threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
+  if (try_partitioned_dot) {
+    return try_partitioned_dot;
+  }

-  return DefaultAction(hlo);
+  // Recursively partition on different types of dimensions.
+  //
+  // Case 1: Group partitions by batch.
+  if (lhs_batch_partitions == rhs_batch_partitions &&
+      lhs_batch_partitions == output_batch_partitions &&
+      lhs_batch_partitions > 1) {
+    TF_ASSIGN_OR_RETURN(
+        auto dot,
+        PartitionDotGroupOnBatch(
+            lhs, rhs, output_base_shape, output_sharding, dims_mapping,
+            num_partitions, create_sharded_dot, module, original_hlo,
+            threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
+    if (dot) {
+      return dot;
+    }
+  }
+
+  // Case 2: Group partitions by non-contracting dimensions.
+  const bool may_group_on_lhs_non_contracting =
+      lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
+      lhs_non_contracting_partitions > 1;
+  const bool may_group_on_rhs_non_contracting =
+      rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
+      rhs_non_contracting_partitions > 1;
+  if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
+    // If both match output non-contracting dimensions, choose the one which
+    // will result in smaller replication of the other operand.
+    const bool lhs_matching =
+        may_group_on_lhs_non_contracting &&
+        (!may_group_on_rhs_non_contracting ||
+         lhs_non_contracting_partitions *
+                 ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
+             rhs_non_contracting_partitions *
+                 ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
+
+    TF_ASSIGN_OR_RETURN(
+        auto dot,
+        PartitionDotGroupOnNonContracting(
+            lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
+            lhs_matching ? lhs_contracting_partitions
+                         : rhs_contracting_partitions,
+            lhs_matching ? rhs_contracting_partitions
+                         : lhs_contracting_partitions,
+            lhs_matching ? lhs_non_contracting_partitions
+                         : rhs_non_contracting_partitions,
+            lhs_matching ? rhs_non_contracting_partitions
+                         : lhs_non_contracting_partitions,
+            lhs_matching ? output_rhs_non_contracting_partitions
+                         : output_lhs_non_contracting_partitions,
+            output_base_shape, output_sharding, dims_mapping, num_partitions,
+            create_sharded_dot, module, original_hlo,
+            threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
+    if (dot) {
+      return dot;
+    }
+  }
+
+  // Default action.
+  TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
+                                                   rhs.Replicate().hlo(), b));
+  dot->set_sharding(HloSharding::Replicate());
+  return PartitionedHlo(dot, output_base_shape, lhs.state())
+      .Reshard(output_sharding)
+      .hlo();
+}
+
+}  // namespace
+
+Status SpmdPartitioningVisitor::HandleDotHelper(
+    HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
+    const std::function<StatusOr<HloInstruction*>(
+        HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
+  auto& lhs = GetPartitionedHlo(hlo->operand(0));
+  auto& rhs = GetPartitionedHlo(hlo->operand(1));
+  TF_ASSIGN_OR_RETURN(
+      auto partitioned_dot,
+      PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
+                   num_partitions_, create_sharded_dot, module_, hlo,
+                   options_.threshold_for_windowed_einsum_mib, &b_,
+                   &windowed_dot_general_loops_));
+  SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
+  return Status::OK();
+}
+
 namespace {
@@ -780,6 +1082,7 @@ Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
       [](const HloInstruction* a, const HloInstruction* b) {
         return a->unique_id() < b->unique_id();
       });
+  worklist.reserve(nullaries_to_sink.size());
   for (auto inst : nullaries_to_sink) {
     worklist.push_back(inst);
   }
@@ -165,16 +165,6 @@ template <typename F>

 namespace {

-// Returns the replica group configuration where each replica belongs to its own
-// group.
-std::vector<ReplicaGroup> CreateReplicaGroups(int64 num_replicas) {
-  std::vector<ReplicaGroup> groups(num_replicas);
-  for (int64 i = 0; i < num_replicas; ++i) {
-    groups[i].add_replica_ids(i);
-  }
-  return groups;
-}
-
 // Clears all sharding attributes from instructions in the module. This must be
 // called only after all SPMD transformation is complete.
 Status ClearShardingAttributes(HloModule* module) {
@@ -195,6 +185,28 @@ Status ClearShardingAttributes(HloModule* module) {
   return Status::OK();
 }

+std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
+    const HloSharding& sharding, absl::Span<const int64> replication_dims) {
+  int64 group_size = 1;
+  for (int64 i : replication_dims) {
+    group_size *= sharding.tile_assignment().dim(i);
+  }
+  std::vector<std::vector<int64>> partition_groups(
+      sharding.tile_assignment().num_elements() / group_size);
+  sharding.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 partition) {
+        int64 group_id = 0;
+        for (int64 i = 0; i < indices.size(); ++i) {
+          if (!absl::c_linear_search(replication_dims, i)) {
+            group_id *= sharding.tile_assignment().dim(i);
+            group_id += indices[i];
+          }
+        }
+        partition_groups[group_id].push_back(partition);
+      });
+  return partition_groups;
+}
+
 }  // namespace

 HloInstruction* SpmdBuilder::AddInstruction(
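A standalone, std-only re-implementation (illustrative, not commit code) of GetPartitionGroupsForReplication for a row-major [2, 2] tile assignment {{0, 1}, {2, 3}}: replicating along dimension 1 yields groups {0, 1} and {2, 3}.

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const std::vector<long> dims = {2, 2};
  const std::vector<long> replication_dims = {1};
  long group_size = 1;
  for (long d : replication_dims) group_size *= dims[d];
  std::vector<std::vector<long>> groups(4 / group_size);
  for (long partition = 0; partition < 4; ++partition) {
    const std::vector<long> indices = {partition / dims[1], partition % dims[1]};
    long group_id = 0;  // mix only the non-replicated dimensions
    for (long i = 0; i < static_cast<long>(indices.size()); ++i) {
      if (std::find(replication_dims.begin(), replication_dims.end(), i) ==
          replication_dims.end()) {
        group_id = group_id * dims[i] + indices[i];
      }
    }
    groups[group_id].push_back(partition);
  }
  for (const auto& g : groups) {
    for (long p : g) std::cout << p << " ";
    std::cout << "\n";  // prints "0 1" then "2 3"
  }
}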
@@ -664,42 +676,57 @@ PartitionedHlo PartitionedHlo::Replicate() {
   }

   // 'Tiled' to 'Replicated'.
+  std::vector<int64> all_dims(shape.rank());
+  std::iota(all_dims.begin(), all_dims.end(), 0);
+  HloInstruction* result = ReplicatePartial(all_dims);
+  result->set_sharding(HloSharding::Replicate());
+  return update_cache(PartitionedHlo(result, base_shape_, state_));
+}
+
+HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
+  CHECK(!sharding().IsTileMaximal());
+  const Shape& shard_shape = hlo()->shape();
+  Shape target_shape = shard_shape;
+  Shape padded_target_shape = shard_shape;
+  for (int64 i : dims) {
+    padded_target_shape.set_dimensions(
+        i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
+    target_shape.set_dimensions(i, base_shape().dimensions(i));
+  }
+
   HloInstruction* result = nullptr;
   if (state_.collective_ops_creator.create_cross_partition_all_gather) {
-    result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding,
-                                                 NewChannel());
-  }
-  Shape padded_base_shape = shape;
-  for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
-    padded_base_shape.set_dimensions(
-        i, shape.dimensions(i) * sharding.tile_assignment().dim(i));
+    result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
+                                                 NewChannel(), dims,
+                                                 state_.collective_ops_creator);
   }
   if (result == nullptr) {
     auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::Zero(shape.element_type())));
+        LiteralUtil::Zero(shard_shape.element_type())));
     auto zero_bcast = state_.b->AddInstruction(
-        HloInstruction::CreateBroadcast(padded_base_shape, zero, {}));
+        HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
+    auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
+                                        state_.partition_id, state_.b, dims);
     auto dus =
         state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
-            padded_base_shape, zero_bcast, hlo_,
-            MakePartitionOffsets(padded_base_shape, sharding,
-                                 state_.partition_id, state_.b)));
+            padded_target_shape, zero_bcast, hlo_, offsets));
     HloComputation* reduction =
-        MakeBinaryAdd(shape.element_type(), state_.module);
+        MakeBinaryAdd(shard_shape.element_type(), state_.module);

     auto all_reduce =
         state_.collective_ops_creator.create_cross_partition_all_reduce(
-            state_.b, dus, reduction, NewChannel());
+            state_.b, dus, reduction,
+            GetPartitionGroupsForReplication(sharding(), dims), NewChannel());
     result = all_reduce;
   }
-  if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) {
-    std::vector<int64> start_indices(shape.rank(), 0);
-    std::vector<int64> strides(shape.rank(), 1);
-    result = state_.b->AddInstruction(HloInstruction::CreateSlice(
-        base_shape_, result, start_indices, base_shape_.dimensions(), strides));
+  if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
+    std::vector<int64> start_indices(target_shape.rank(), 0);
+    std::vector<int64> strides(target_shape.rank(), 1);
+    result = state_.b->AddInstruction(
+        HloInstruction::CreateSlice(target_shape, result, start_indices,
+                                    base_shape_.dimensions(), strides));
   }
-  result->set_sharding(HloSharding::Replicate());
-  return update_cache(PartitionedHlo(result, base_shape_, state_));
+  return result;
 }

 PartitionedHlo PartitionedHlo::Broadcast() const {
@@ -728,7 +755,7 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
       MakeBinaryAdd(shape.element_type(), state_.module);

   auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
-      state_.b, operand, reduction, NewChannel());
+      state_.b, operand, reduction, {}, NewChannel());
   result->set_sharding(HloSharding::Replicate());
   return PartitionedHlo(result, base_shape_, state_);
 }
@@ -796,7 +823,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
   auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);

   // The order of ids in the group must follow the temp_target sharding.
-  std::vector<ReplicaGroup> groups(
+  std::vector<std::vector<int64>> groups(
       temp_target.tile_assignment().num_elements() / group_size);
   temp_target.tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
@@ -810,7 +837,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
             group_id += indices[dim];
           }
         }
-        groups[group_id].add_replica_ids(device);
+        groups[group_id].push_back(device);
       });

   HloInstruction* result = nullptr;
@@ -1027,7 +1054,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
     offset += operand->shape().dimensions(dimension);
   }
   auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce(
-      &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_),
+      &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
       NewChannel());
   SetPartitionedHlo(hlo, [&] {
     auto start_indices =
@@ -2153,7 +2180,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
       // Combine from different partitions.
       auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
           &b_, filtered,
-          MakeBinaryAdd(filtered->shape().element_type(), module_),
+          MakeBinaryAdd(filtered->shape().element_type(), module_), {},
           NewChannel());
       ar->set_sharding(HloSharding::Replicate());
       SetPartitionedHlo(hlo, [&]() {
@ -2449,7 +2476,7 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
if (reduce_sharded_dimension) {
|
||||
CHECK(local_reduce->shape().IsArray());
|
||||
reduce = collective_ops_creator_.create_cross_partition_all_reduce(
|
||||
&b_, local_reduce, hlo->to_apply(), NewChannel());
|
||||
&b_, local_reduce, hlo->to_apply(), {}, NewChannel());
|
||||
reduce->set_sharding(HloSharding::Replicate());
|
||||
} else {
|
||||
reduce = local_reduce;
|
||||
@ -2917,13 +2944,36 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
|
||||
[](SpmdBuilder* b) {
|
||||
return b->AddInstruction(HloInstruction::CreatePartitionId());
|
||||
},
|
||||
[num_replicas](SpmdBuilder* b, HloInstruction* operand,
|
||||
HloComputation* reduction, int64 channel_id) {
|
||||
[num_replicas, num_partitions](
|
||||
SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
|
||||
const std::vector<std::vector<int64>>& partition_subgroups,
|
||||
int64 channel_id) {
|
||||
if (partition_subgroups.size() <= 1) {
|
||||
std::vector<ReplicaGroup> groups(num_replicas);
|
||||
// TODO(yuanzx): Unify subgroup definition with AllToAll.
|
||||
for (int64 i = 0; i < num_replicas; ++i) {
|
||||
groups[i].add_replica_ids(i);
|
||||
}
|
||||
return b->AddInstruction(HloInstruction::CreateAllReduce(
|
||||
operand->shape(), {operand}, reduction, groups,
|
||||
/*constrain_layout=*/false, channel_id,
|
||||
/*use_global_device_ids=*/false));
|
||||
}
|
||||
|
||||
std::vector<ReplicaGroup> device_groups;
|
||||
device_groups.reserve(partition_subgroups.size() * num_replicas);
|
||||
for (int64 i = 0; i < num_replicas; ++i) {
|
||||
for (const auto& pgroup : partition_subgroups) {
|
||||
device_groups.emplace_back();
|
||||
for (int64 pid : pgroup) {
|
||||
device_groups.back().add_replica_ids(i * num_partitions + pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
return b->AddInstruction(HloInstruction::CreateAllReduce(
|
||||
operand->shape(), {operand}, reduction,
|
||||
CreateReplicaGroups(num_replicas),
|
||||
operand->shape(), {operand}, reduction, device_groups,
|
||||
/*constrain_layout=*/false, channel_id,
|
||||
/*use_global_device_ids=*/false));
|
||||
/*use_global_device_ids=*/true));
|
||||
},
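// Worked example (illustration, not from the original change): with
// num_replicas=2, num_partitions=4 and partition_subgroups={{0,1},{2,3}},
// the loop above builds the groups {0,1},{2,3},{4,5},{6,7} in global ids,
// where global id = replica_id * num_partitions + partition_id; setting
// use_global_device_ids=true makes the backend interpret them that way.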
[](SpmdBuilder* b, HloInstruction* operand,
std::vector<std::pair<int64, int64>>& src_dst_pairs,
@ -2932,14 +2982,20 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
operand->shape(), operand, src_dst_pairs, channel_id));
},
[](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension) {
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension) {
std::vector<Shape> shapes(operands.size(), operands[0]->shape());
const Shape output_shape = (shapes.size() == 1)
? shapes[0]
: ShapeUtil::MakeTupleShape(shapes);
std::vector<ReplicaGroup> groups(partition_subgroups.size());
for (int64 i = 0; i < groups.size(); ++i) {
for (int64 id : partition_subgroups[i]) {
groups[i].add_replica_ids(id);
}
}
return b->AddInstruction(HloInstruction::CreateAllToAll(
output_shape, operands, replica_groups,
output_shape, operands, groups,
/*constrain_layout=*/false, channel_id, split_dimension));
},
[num_replicas, num_partitions](
@ -2970,10 +3026,10 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
num_partitions, num_replicas, std::move(options),
GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}

HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
HloInstruction* operand,
const HloSharding& sharding,
int64 channel_id) {
HloInstruction* SpmdPartitioner::AllGatherShards(
SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
int64 channel_id, absl::Span<const int64> selected_dims,
const SPMDCollectiveOpsCreator& collectives_creator) {
CHECK(!sharding.IsTileMaximal());
// Add one leading dimension to gather all partitions.
std::vector<int64> shape;
@ -2983,18 +3039,17 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
}
auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
std::vector<std::vector<int64>> partition_subgroups(1);
for (int64 pid : sharding.tile_assignment()) {
partition_subgroups[0].push_back(pid);
}
shape[0] = sharding.tile_assignment().num_elements();
auto result = collective_ops_creator_.create_cross_partition_all_gather(
auto partition_subgroups =
GetPartitionGroupsForReplication(sharding, selected_dims);
shape[0] = partition_subgroups[0].size();
auto result = collectives_creator.create_cross_partition_all_gather(
b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
partition_subgroups, channel_id, /*all_gather_dimension=*/0);
// If n > 1 dimensions are partitioned, split the leading dimension to n.
std::vector<int64> tiled_dims;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (sharding.tile_assignment().dim(i) > 1) {
if (sharding.tile_assignment().dim(i) > 1 &&
absl::c_linear_search(selected_dims, i)) {
tiled_dims.push_back(i);
}
}
@ -3016,7 +3071,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
std::vector<int64> xpose_permutation(result->shape().rank());
int64 split_dims_added = 0;
for (int64 i = 0; i < xpose_permutation.size(); ++i) {
if (sharding.tile_assignment().dim(i - split_dims_added) == 1) {
if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
!absl::c_linear_search(selected_dims, i - split_dims_added)) {
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
} else {
xpose_permutation[i] = split_dims_added;

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -82,8 +83,10 @@ struct SPMDCollectiveOpsCreator {
std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

// Function used to create a cross-partition all-reduce HLO.
std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand,
HloComputation* reduction, int64 channel_id)>
std::function<HloInstruction*(
SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id)>
create_cross_partition_all_reduce;

// Function used to create a cross-partition collective-permute HLO.
@ -96,8 +99,8 @@ struct SPMDCollectiveOpsCreator {
// Function used to create a cross-partition all-to-all HLO.
std::function<HloInstruction*(
SpmdBuilder*, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension)>
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension)>
create_cross_partition_all_to_all;

// Function used to create a cross-partition all-gather HLO. This is optional:
@ -169,10 +172,13 @@ class SpmdPartitioner : public HloModulePass {
// The default uses a single all-gather even if there are multiple sharded
// dimensions, and adds potential reshapes and transposes to achieve that.
// If it returns nullptr, the partitioner will fall back to all-reduce.
virtual HloInstruction* AllGatherShards(SpmdBuilder* b,
HloInstruction* operand,
const HloSharding& sharding,
int64 channel_id);
// `selected_dims` specifies the dimensions along which the all-gather happens
// in the tiled sharding, which allows potentially creating a subgroup
// all-gather.
virtual HloInstruction* AllGatherShards(
SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
int64 channel_id, absl::Span<const int64> selected_dims,
const SPMDCollectiveOpsCreator& collectives_creator);
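// For example (illustration, not from the original change): with tile
// assignment {{0,1},{2,3}} and selected_dims={1}, partitions {0,1} and
// {2,3} each run an independent subgroup all-gather that restores
// dimension 1, while dimension 0 stays partitioned.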

protected:
virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
@ -215,7 +221,12 @@ class PartitionedHlo {
std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
window_reshard_cache;
};
// Use std::unordered_map for pointer stability.
std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
// Caches for nested partitioning of grouped sharding. Each string key
// represents a unique way of grouping devices.
absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
groupd_caches;
};
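// For example (illustration, not from the original change): the device
// groups {{0,1},{2,3}} are keyed by the string "0,1;2,3" built in
// CreatePerGroupPartitioningState below.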
struct PartitioningState {
SpmdBuilder* b;
@ -270,15 +281,18 @@ class PartitionedHlo {

const PartitioningState& state() const { return state_; }

// Helper function to replicate the data on all devices. Could only modify
// the reshard cache.
PartitionedHlo Replicate();

// Helper function to replicate the data for partitions along the given dims.
HloInstruction* ReplicatePartial(absl::Span<const int64> dims);

private:
// Same as Reshard except that it does not explicitly modify the reshard
// cache, although it would indirectly modify by calling Replicate().
PartitionedHlo ReshardNoCache(const HloSharding& target);

// Helper function to replicate the data on all devices. Could only modify
// the reshard cache.
PartitionedHlo Replicate();

// Helper function to broadcast data from a single device to all devices.
PartitionedHlo Broadcast() const;

@ -417,6 +431,16 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> DoPartition(HloComputation* computation,
const HloSharding& root_sharding);

// Information about a loop created for windowed dot-general. Used when
// DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
// finishes traversing the graph.
struct WindowedDotGeneralLoop {
HloInstruction* while_loop;
int64 windowed_operand;
bool windowed_in_contracting_dims;
bool windowed_in_batch_dims;
};

private:
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
@ -445,15 +469,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// partitioned instruction.
ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

// Information about a loop created for windowed dot-general. Used when
// DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
// finishes traversing the graph.
struct WindowedDotGeneralLoop {
HloInstruction* while_loop;
int64 windowed_operand;
bool windowed_in_contracting_dims;
bool windowed_in_batch_dims;
};
std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

HloInstruction* visiting_hlo_;

@ -2218,7 +2218,7 @@ ENTRY entry {

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
std::cout << module->ToString();
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@ -2294,7 +2294,7 @@ ENTRY entry

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
std::cout << module->ToString();
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@ -3842,6 +3842,154 @@ ENTRY entry {
EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
}

TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
const char* const hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
%rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3}
ROOT %dot = f32[48,32] dot(%lhs, %rhs),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[2,2]0,1,2,3}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();

auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0));
auto partial_replicated_lhs =
AllOf(op::Shape("f32[24,12]"),
op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _)));
auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice(
_, op::CollectivePermute(rhs), _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs),
op::Shape("f32[24,16]")));
}

TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
const char* const hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
%rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
ROOT %dot = f32[48,32] dot(%lhs, %rhs),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[2,2]0,1,2,3}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();

auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[32,50]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, AllOf(op::Shape("f32[24,16]"),
op::DynamicSlice(
op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs),
op::Shape("f32[24,32]"))),
_, _)));
}

TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
const char* const hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
%rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,2,1]0,1,2,3}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();

auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[2,32,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
op::Dot(lhs, partial_replicated_rhs)));
}

TEST_F(SpmdPartitioningTest,
Dot2DPartitionedBatchNonContractingAndContracting) {
const char* const hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
%rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,2]0,1,2,3}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();

auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
auto partial_replicated_lhs =
AllOf(op::Shape("f32[2,24,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"),
op::Dot(partial_replicated_lhs, rhs)));
}

TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
const char* const hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
%rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={3},
sharding={devices=[1,2,2,1]0,1,2,3}
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();

auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[2,8,32,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _)));
auto dot =
AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs));
auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot));
auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape));
auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose)));
}

} // namespace
} // namespace spmd
} // namespace xla

@ -16,7 +16,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"

#include <algorithm>
#include <memory>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@ -143,10 +148,10 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
return partition_shape;
}

std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b) {
std::vector<HloInstruction*> MakePartitionOffsets(
const Shape& shape, const HloSharding& sharding,
HloInstruction* partition_id, SpmdBuilder* b,
absl::Span<const int64> dims) {
CHECK(!shape.IsTuple());

Array2D<int32> offset_array(
@ -158,7 +163,8 @@ std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
LiteralUtil::CreateR2FromArray2D(offset_array)));
std::vector<HloInstruction*> offsets;
for (int64 i = 0; i < shape.rank(); ++i) {
if (sharding.tile_assignment().dim(i) == 1) {
if (sharding.tile_assignment().dim(i) == 1 ||
(!dims.empty() && !absl::c_linear_search(dims, i))) {
offsets.push_back(b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
} else {
@ -978,5 +984,255 @@ bool CanReshardWithCollectivePermute(const HloSharding& source,
source.tile_assignment() != target.tile_assignment();
}

GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> group_dims) {
CHECK(!sharding.IsTileMaximal());
std::vector<int64> grouped_tiling_dims =
sharding.tile_assignment().dimensions();
std::vector<int64> group_dim_sizes(group_dims.size());
for (int64 i = 0; i < group_dims.size(); ++i) {
group_dim_sizes[i] = grouped_tiling_dims[group_dims[i]];
grouped_tiling_dims[group_dims[i]] = 1;
}
std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
int64 group_id = 0;
for (int64 dim : group_dims) {
group_id *= sharding.tile_assignment().dim(dim);
group_id += indices[dim];
}
device_groups[group_id].push_back(device);
});
Array<int64> grouped_tiling(grouped_tiling_dims);
grouped_tiling.FillIota(0);
return GroupedSharding(
std::move(device_groups),
std::vector<int64>(group_dims.begin(), group_dims.end()),
std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
HloSharding::Tile(grouped_tiling));
}
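
A standalone sketch (illustration only, not code from this change) of the group-id computation above, run on the 2x2 tile assignment {{0,1},{2,3}} grouped on dimension 0:

// Sketch only; mirrors the Each() lambda above on plain vectors.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<std::vector<int64_t>> tile = {{0, 1}, {2, 3}};
  const std::vector<int64_t> group_dims = {0};
  const std::vector<int64_t> tile_dims = {2, 2};
  std::vector<std::vector<int64_t>> device_groups(2);  // Product of dim(0)=2
  for (int64_t i = 0; i < tile_dims[0]; ++i) {
    for (int64_t j = 0; j < tile_dims[1]; ++j) {
      // Accumulate the indices of group_dims into a single group id.
      int64_t group_id = 0;
      for (int64_t dim : group_dims) {
        group_id = group_id * tile_dims[dim] + (dim == 0 ? i : j);
      }
      device_groups[group_id].push_back(tile[i][j]);
    }
  }
  for (const auto& group : device_groups) {
    for (int64_t d : group) std::cout << d << ' ';
    std::cout << '\n';  // prints "0 1" then "2 3"
  }
  return 0;
}

The in-group sharding is then the iota tiling over the remaining dimensions, here of shape [1,2].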

HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
CHECK(!grouped_sharding.sharding.IsTileMaximal());
std::vector<int64> tiling_dims =
grouped_sharding.sharding.tile_assignment().dimensions();
for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
tiling_dims[grouped_sharding.group_dims[i]] =
grouped_sharding.group_dim_sizes[i];
}
Array<int64> tiling(tiling_dims);
grouped_sharding.sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
int64 remaining_group_index = g;
for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
ungrouped_inds[grouped_sharding.group_dims[i]] =
remaining_group_index % grouped_sharding.group_dim_sizes[i];
remaining_group_index /= grouped_sharding.group_dim_sizes[i];
}
tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
}
});
return HloSharding::Tile(tiling);
}
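// Worked example (illustration, not from the original change): for
// device_groups={{0,1},{2,3}}, group_dims={0}, group_dim_sizes={2} and
// in-group tiling [1,2]={0,1}, group 0 writes tiling(0,0)=0 and
// tiling(0,1)=1 while group 1 writes tiling(1,0)=2 and tiling(1,1)=3,
// reconstructing the original [2,2] assignment {{0,1},{2,3}}.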

GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
const GroupedSharding& reference,
bool ignore_group_order) {
// Returns src -> dst index mapping.
auto get_permutation = [](absl::Span<const int64> src,
absl::Span<const int64> dst) {
CHECK_EQ(src.size(), dst.size());
absl::flat_hash_map<int64, int64> dst_reverse_map;
for (int64 i = 0; i < dst.size(); ++i) {
dst_reverse_map[dst[i]] = i;
}
std::vector<int64> permutation(src.size());
for (int64 i = 0; i < src.size(); ++i) {
auto it = dst_reverse_map.find(src[i]);
CHECK(it != dst_reverse_map.end());
permutation[i] = it->second;
}
return permutation;
};
CHECK_EQ(grouped_sharding.device_groups.size(),
reference.device_groups.size());
absl::flat_hash_map<int64, int64> device_to_ref_group;
for (int64 g = 0; g < reference.device_groups.size(); ++g) {
for (int64 device : reference.device_groups[g]) {
device_to_ref_group[device] = g;
}
}
auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
int64 ref_g = -1;
for (int64 device : devices) {
if (ref_g == -1) {
ref_g = device_to_ref_group[device];
} else if (ref_g != device_to_ref_group[device]) {
return -1;
}
}
return ref_g;
};
bool matching_groups = true;
std::vector<int64> original_src_to_ref_permutation;
for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
matching_groups = false;
break;
}
if (g == 0) {
original_src_to_ref_permutation = get_permutation(
grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
}
}
if (matching_groups) {
auto tiles = grouped_sharding.sharding.tile_assignment();
tiles.Each([&](absl::Span<const int64> indices, int64* device) {
*device = original_src_to_ref_permutation[*device];
});
grouped_sharding.sharding = HloSharding::Tile(tiles);
}
grouped_sharding.device_groups = std::move(reference.device_groups);
return grouped_sharding;
}
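// Worked example (illustration, not from the original change): aligning
// device_groups={{1,0},{3,2}} to reference groups {{0,1},{2,3}} with
// ignore_group_order=false: the groups match as sets,
// get_permutation({1,0},{0,1}) returns {1,0}, so in-group device 0 is
// relabeled 1 (and vice versa) in the tile assignment before the reference
// groups are adopted.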

Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
const Shape& original_base_shape) {
auto result = original_base_shape;
for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
int64 dim = grouped_sharding.group_dims[i];
int64 groups = grouped_sharding.group_dim_sizes[i];
result.set_dimensions(dim, result.dimensions(dim) / groups);
}
return result;
}

namespace {

HloInstruction* GetInGroupPartitionId(
HloInstruction* partition_id,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
int64 total_devices = device_groups.size() * device_groups[0].size();
std::vector<uint32> in_group_ids(total_devices);
for (uint32 i = 0; i < device_groups.size(); ++i) {
for (uint32 j = 0; j < device_groups[i].size(); ++j) {
in_group_ids[device_groups[i][j]] = j;
}
}
auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<uint32>(in_group_ids)));
return b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeScalarShape(U32),
b->AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
}
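// Worked example (illustration, not from the original change):
// device_groups={{0,2},{1,3}} yields the table in_group_ids=[0,0,1,1]
// indexed by global partition id, so partitions 0 and 1 map to in-group
// id 0 while partitions 2 and 3 map to in-group id 1.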

SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
const SPMDCollectiveOpsCreator& creator,
const std::vector<std::vector<int64>>& device_groups) {
SPMDCollectiveOpsCreator result;
result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
b);
};
auto expand_partition_groups =
[device_groups](
const std::vector<std::vector<int64>>& partition_subgroups) {
if (partition_subgroups.empty()) {
return device_groups;
}
std::vector<std::vector<int64>> result(partition_subgroups.size() *
device_groups.size());
for (int64 g = 0; g < device_groups.size(); ++g) {
for (int64 i = 0; i < partition_subgroups.size(); ++i) {
result[g * partition_subgroups.size() + i].resize(
partition_subgroups[i].size());
for (int64 j = 0; j < partition_subgroups[i].size(); ++j) {
result[g * partition_subgroups.size() + i][j] =
device_groups[g][partition_subgroups[i][j]];
}
}
}
return result;
};
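// Worked example (illustration, not from the original change): with
// device_groups={{0,1},{2,3}}, empty partition_subgroups expands to the
// full groups {{0,1},{2,3}}, while in-group subgroups {{0},{1}} expand to
// {{0},{1},{2},{3}} in global partition ids.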
result.create_cross_partition_all_reduce =
[creator, expand_partition_groups](
SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id) {
return creator.create_cross_partition_all_reduce(
b, operand, reduction, expand_partition_groups(partition_subgroups),
channel_id);
};
result.create_cross_partition_collective_permute =
[creator, device_groups](
SpmdBuilder* b, HloInstruction* operand,
std::vector<std::pair<int64, int64>>& src_dst_pairs,
int64 next_channel_id) {
std::vector<std::pair<int64, int64>> expanded_pairs(
src_dst_pairs.size() * device_groups.size());
for (int64 g = 0; g < device_groups.size(); ++g) {
for (int64 i = 0; i < src_dst_pairs.size(); ++i) {
expanded_pairs[g * src_dst_pairs.size() + i] =
std::pair<int64, int64>{
device_groups[g][src_dst_pairs[i].first],
device_groups[g][src_dst_pairs[i].second]};
}
}
return creator.create_cross_partition_collective_permute(
b, operand, expanded_pairs, next_channel_id);
};
result.create_cross_partition_all_to_all =
[creator, expand_partition_groups](
SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension) {
return creator.create_cross_partition_all_to_all(
b, operands, expand_partition_groups(partition_subgroups),
channel_id, split_dimension);
};
if (creator.create_cross_partition_all_gather) {
result.create_cross_partition_all_gather =
[creator, expand_partition_groups](
SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, int64 all_gather_dimension) {
return creator.create_cross_partition_all_gather(
b, operand, ag_shape,
expand_partition_groups(partition_subgroups), channel_id,
all_gather_dimension);
};
}
return result;
}

} // namespace

PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
const PartitionedHlo::PartitioningState& state,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
auto result = state;
result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
state.collective_ops_creator, device_groups);
result.partition_id =
GetInGroupPartitionId(state.partition_id, device_groups, b);
// Create a string key for the groups.
std::vector<std::string> per_group_strings(device_groups.size());
for (int64 i = 0; i < per_group_strings.size(); ++i) {
per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
}
auto& grouped_cache =
state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
if (!grouped_cache) {
grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
}
result.reshard_cache = grouped_cache.get();
return result;
}
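
A typical call sequence for these grouping helpers (a sketch of assumed usage; the recursive dot-partitioning code that drives them is not part of this excerpt):

// Sketch only; local variable names are illustrative.
// 1. Group both operands' shardings on the shared (e.g. batch) dimensions,
//    and align the second grouping to the first so both sides agree on the
//    device groups:
//      auto lhs_grouped = GroupShardingOnDims(lhs_sharding, {0});
//      auto rhs_grouped = AlignGroupsWith(
//          GroupShardingOnDims(rhs_sharding, {0}), lhs_grouped);
// 2. Build a per-group PartitioningState whose collectives stay inside each
//    device group:
//      auto group_state = CreatePerGroupPartitioningState(
//          state, lhs_grouped.device_groups, b);
// 3. Partition the operation recursively against the smaller per-group
//    "mesh", then reconstruct the full sharding with UngroupSharding().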

} // namespace spmd
} // namespace xla

@ -87,10 +87,12 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b);
// If `dims` is non-empty, the generated offsets will only be non-zero for those
// dimensions.
std::vector<HloInstruction*> MakePartitionOffsets(
const Shape& shape, const HloSharding& sharding,
HloInstruction* partition_id, SpmdBuilder* b,
absl::Span<const int64> dims = {});
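// For example (illustration, not from the original change): for f32[8,8]
// with tile assignment {{0,1},{2,3}} and dims={0}, device 2 (tile index
// (1,0)) gets offsets {4, 0}: the dimension 0 offset is generated, while
// dimension 1 is forced to a zero constant even though it is also
// partitioned.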

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
@ -276,6 +278,48 @@ GetReshardAllToAllSourceTargetDims(const HloSharding& source,
bool CanReshardWithCollectivePermute(const HloSharding& source,
const HloSharding& target);

// Represents grouping devices in a tiled sharding along certain dimensions.
// Elements in group dimensions define different device groups, and the sharding
// represents the in-group sharding.
struct GroupedSharding {
GroupedSharding(std::vector<std::vector<int64>> device_groups,
std::vector<int64> group_dims,
std::vector<int64> group_dim_sizes, int64 rank,
HloSharding grouped_sharding)
: device_groups(std::move(device_groups)),
group_dims(std::move(group_dims)),
group_dim_sizes(std::move(group_dim_sizes)),
rank(rank),
sharding(std::move(grouped_sharding)) {}
std::vector<std::vector<int64>> device_groups;
std::vector<int64> group_dims;
std::vector<int64> group_dim_sizes;
int64 rank;
HloSharding sharding;
};

// Creates a GroupedSharding for a tiled sharding.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> group_dims);

// Reconstructs the ungrouped sharding from a GroupedSharding.
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);

// Returns a new GroupedSharding that has the same group definition as
// `reference`.
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
const GroupedSharding& reference,
bool ignore_group_order = false);

// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
const Shape& original_base_shape);

// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
const PartitionedHlo::PartitioningState& state,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);

} // namespace spmd
} // namespace xla