Fix build and resubmit: [XLA:SPMD] Recursively handling more Dot cases

PiperOrigin-RevId: 322949994
Change-Id: I44a8a8e958a7ba4995a667d139f793dfa3a4fe7f
Authored by Yuanzhong Xu on 2020-07-24 00:38:06 -07:00; committed by TensorFlower Gardener
parent 8d4711c52d
commit 0b5cc6f1b9
8 changed files with 1121 additions and 298 deletions


@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/core/platform:numbers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",


@@ -226,7 +226,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
hlo->batch_group_count(), new_window,
hlo->convolution_dimension_numbers(), hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
@@ -605,7 +605,7 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
hlo->batch_group_count(), new_window, dnums,
hlo->precision_config()));
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_),
&b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())


@@ -80,12 +80,25 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
return HandleDotHelper(hlo, mapping, create_sharded_dot);
}
Status SpmdPartitioningVisitor::HandleDotHelper(
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
namespace {
StatusOr<HloInstruction*> PartitionBaseCase(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
const HloSharding& lhs_sharding = hlo->operand(0)->sharding();
const HloSharding& rhs_sharding = hlo->operand(1)->sharding();
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions,
int64 rhs_batch_partitions, int64 output_batch_partitions,
int64 lhs_contracting_partitions, int64 rhs_contracting_partitions,
int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions,
int64 output_lhs_non_contracting_partitions,
int64 output_rhs_non_contracting_partitions,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
const HloSharding& lhs_sharding = lhs.sharding();
const HloSharding& rhs_sharding = rhs.sharding();
// Similar to hlo_sharding_util::TransposeSharding(), but allows
// removing/adding non-partitioned dimensions.
@@ -132,12 +145,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
return HloSharding::Tile(reshape_tiles);
};
std::vector<int64> lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1);
std::vector<int64> lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1);
std::vector<int64> rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1);
std::vector<int64> rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1);
std::vector<int64> output_to_lhs_indices(hlo->shape().rank(), -1);
std::vector<int64> output_to_rhs_indices(hlo->shape().rank(), -1);
std::vector<int64> lhs_to_rhs_indices(lhs.base_shape().rank(), -1);
std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1);
std::vector<int64> rhs_to_lhs_indices(rhs.base_shape().rank(), -1);
std::vector<int64> rhs_to_output_indices(rhs.base_shape().rank(), -1);
std::vector<int64> output_to_lhs_indices(output_base_shape.rank(), -1);
std::vector<int64> output_to_rhs_indices(output_base_shape.rank(), -1);
auto populate_indices_mapping =
[&](const DotGeneralDimsMapping::DimsMapping& mapping) {
if (mapping.lhs >= 0) {
@@ -174,127 +187,84 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
auto rhs_sharding_transposed_to_match_output = transpose_sharding(
rhs_sharding, rhs_to_output_indices, output_to_rhs_indices);
auto output_sharding_transposed_to_match_lhs = transpose_sharding(
hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices);
output_sharding, output_to_lhs_indices, lhs_to_output_indices);
auto output_sharding_transposed_to_match_rhs = transpose_sharding(
hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices);
output_sharding, output_to_rhs_indices, rhs_to_output_indices);
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
auto get_partitions_for_dims =
[&](const HloSharding& sharding,
absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
int lhs_rhs_or_output) {
int64 partitions = 1;
if (sharding.IsTileMaximal()) {
return partitions;
}
for (const auto& dim : dims) {
if (lhs_rhs_or_output == 0) {
partitions *= sharding.tile_assignment().dim(dim.lhs);
} else if (lhs_rhs_or_output == 1) {
partitions *= sharding.tile_assignment().dim(dim.rhs);
} else {
CHECK_EQ(lhs_rhs_or_output, 2);
partitions *= sharding.tile_assignment().dim(dim.output);
}
}
return partitions;
};
const int64 lhs_batch_partitions =
get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0);
const int64 rhs_batch_partitions =
get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1);
const int64 output_batch_partitions =
get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2);
const int64 lhs_contracting_partitions =
get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0);
const int64 rhs_contracting_partitions =
get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1);
const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0);
const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1);
const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2);
const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2);
auto& lhs = GetPartitionedHlo(hlo->operand(0));
auto& rhs = GetPartitionedHlo(hlo->operand(1));
// LHS and RHS are partitioned the same way and only partitioned in batch
// dimensions.
if (lhs_batch_partitions == rhs_batch_partitions &&
rhs_batch_partitions == num_partitions_ &&
rhs_batch_partitions == num_partitions &&
lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] {
dot->set_sharding(*lhs_sharding_transposed_to_match_output);
return PartitionedHlo(dot, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
dot->set_sharding(*lhs_sharding_transposed_to_match_output);
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
// Try emit batch-partitioned einsum with one operand resharded. Returns
// whether the attempt succeeds. If may_reshard_with_allreduce is false,
// reshard must be done using all-to-all; otherwise this attempt fails.
// partitioned HLO or nullptr if the attempt fails. If
// may_reshard_with_allreduce is false, reshard must be done using
// all-to-all/collective-permute; otherwise this attempt fails.
auto try_emit_output_batch_partitioned_einsum_with_reshard =
[&](bool may_reshard_with_allreduce) -> StatusOr<bool> {
[&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
// LHS and output are batch partitioned in the same way.
if (lhs_batch_partitions == num_partitions_ &&
output_batch_partitions == num_partitions_ &&
lhs_sharding_transposed_to_match_output == hlo->sharding()) {
if (lhs_batch_partitions == num_partitions &&
output_batch_partitions == num_partitions &&
lhs_sharding_transposed_to_match_output == output_sharding) {
if (!may_reshard_with_allreduce &&
!CanReshardWithCollectivePermute(
rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
!GetReshardAllToAllSourceTargetDims(
rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
return false;
return nullptr;
}
auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return true;
auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b));
return dot;
}
// RHS and output are batch partitioned in the same way.
if (rhs_batch_partitions == num_partitions_ &&
output_batch_partitions == num_partitions_ &&
rhs_sharding_transposed_to_match_output == hlo->sharding()) {
if (rhs_batch_partitions == num_partitions &&
output_batch_partitions == num_partitions &&
rhs_sharding_transposed_to_match_output == output_sharding) {
if (!may_reshard_with_allreduce &&
!CanReshardWithCollectivePermute(
lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
!GetReshardAllToAllSourceTargetDims(
lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
return false;
return nullptr;
}
auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
TF_ASSIGN_OR_RETURN(
auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return true;
auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b));
return dot;
}
return false;
return nullptr;
};
{
// Try batch-parallel by resharding one operand, and not using all-reduce.
TF_ASSIGN_OR_RETURN(
bool emitted,
HloInstruction * partitioned_dot,
try_emit_output_batch_partitioned_einsum_with_reshard(false));
if (emitted) {
return Status::OK();
if (partitioned_dot) {
return partitioned_dot;
}
}
// Try to emit windowed DotGeneral when one operand is partitioned in the same
// way as the output along non-contracting dimensions, but the other operand
// is tiled in other dimensions.
auto emit_windowed_dot_general = [&](int64 matching_operand,
int64 windowing_operand,
bool windowed_at_contracting_dims,
bool windowed_at_batch_dims) {
auto emit_windowed_dot_general =
[&](int64 matching_operand, int64 windowing_operand,
bool windowed_at_contracting_dims,
bool windowed_at_batch_dims) -> StatusOr<HloInstruction*> {
CHECK_EQ(matching_operand + windowing_operand, 1);
CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims);
auto unpadded_result_buffer_shape =
MakePartitionedShape(hlo->shape(), hlo->sharding());
MakePartitionedShape(output_base_shape, output_sharding);
auto padded_result_buffer_shape = unpadded_result_buffer_shape;
// For windowing at batch/non-contracting dims, we produce the result one
// partition at a time, so we need to pad the shape in case of uneven
@@ -310,17 +280,17 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
if (windowed_at_contracting_dims) {
auto& to_mask = windowing_operand == 0 ? lhs : rhs;
to_mask =
to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type()))));
to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type()))));
}
auto result_buffer = CreateZero(padded_result_buffer_shape, &b_);
auto iteration = b_.AddInstruction(
auto result_buffer = CreateZero(padded_result_buffer_shape, b);
auto iteration = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));
// Create a while loop that computes one window per iteration. During each
// iteration, each partition sends its input window to its neighbor using
// collective-permute for the next iteration.
SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_);
SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0,
ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
@@ -335,11 +305,12 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
auto i = body_b.AddInstruction(
HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3));
auto partition_id = collective_ops_creator_.create_partition_id(&body_b);
auto partition_id =
lhs.state().collective_ops_creator.create_partition_id(&body_b);
auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
i->shape(), HloOpcode::kAdd, i, partition_id));
auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(num_partitions_)));
LiteralUtil::CreateR0<uint32>(num_partitions)));
data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count));
auto dot_lhs = l;
@@ -350,7 +321,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
// operand as replicated, and resharding it to match the windowed operand.
auto slice_operand = matching_operand == 0 ? l : r;
slice_operand->set_sharding(HloSharding::Replicate());
auto state = MakePartitioningState();
auto state = lhs.state();
state.b = &body_b;
state.partition_id = data_partition_id;
auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state)
@@ -392,26 +363,27 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), i,
body_b.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(num_partitions_))),
LiteralUtil::CreateR0<uint32>(num_partitions))),
ComparisonDirection::kLt));
// Collective-permute for the next window. We don't need it for the last
// iteration, so we use a conditional around the collective-permute.
HloInstruction* conditional;
{
SpmdBuilder cp_b("window_collective_permute", visiting_hlo_);
SpmdBuilder cp_b("window_collective_permute", original_hlo);
{
auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
std::vector<std::pair<int64, int64>> sd_pairs(num_partitions_);
for (int64 source = 0; source < num_partitions_; ++source) {
std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
for (int64 source = 0; source < num_partitions; ++source) {
// 0 -> n-1, 1 -> 0, 2 -> 1, ...
sd_pairs[source] = {source,
(source - 1 + num_partitions_) % num_partitions_};
(source - 1 + num_partitions) % num_partitions};
}
collective_ops_creator_.create_cross_partition_collective_permute(
&cp_b, p, sd_pairs, (*next_channel_id_)++);
lhs.state()
.collective_ops_creator.create_cross_partition_collective_permute(
&cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
}
SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_);
SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
{
ncp_b.AddInstruction(HloInstruction::CreateParameter(
0, windowing_operand == 0 ? l->shape() : r->shape(), "window"));
@@ -419,9 +391,9 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
windowing_operand == 0 ? l->shape() : r->shape(), has_more,
windowing_operand == 0 ? l : r,
module_->AddEmbeddedComputation(cp_b.Build()),
module->AddEmbeddedComputation(cp_b.Build()),
windowing_operand == 0 ? l : r,
module_->AddEmbeddedComputation(ncp_b.Build())));
module->AddEmbeddedComputation(ncp_b.Build())));
}
if (windowing_operand == 0) {
l = conditional;
@@ -430,7 +402,7 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
}
body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i}));
SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_);
SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0,
ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(),
@@ -441,56 +413,53 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
cond_b.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), cond_i,
cond_b.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(num_partitions_))),
LiteralUtil::CreateR0<uint32>(num_partitions))),
ComparisonDirection::kLt));
auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile(
cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()),
module_->AddEmbeddedComputation(body_b.Build()),
b_.AddInstruction(HloInstruction::CreateTuple(
auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
module->AddEmbeddedComputation(body_b.Build()),
b->AddInstruction(HloInstruction::CreateTuple(
{lhs.hlo(), rhs.hlo(), result_buffer, iteration}))));
windowed_dot_general_loops_.push_back({while_loop, windowing_operand,
windowed_dot_general_loops->push_back({while_loop, windowing_operand,
windowed_at_contracting_dims,
windowed_at_batch_dims});
SetPartitionedHlo(hlo, [&] {
auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
result_buffer->shape(), while_loop, 2));
if (!ShapeUtil::Compatible(padded_result_buffer_shape,
unpadded_result_buffer_shape)) {
result = b_.AddInstruction(HloInstruction::CreateSlice(
unpadded_result_buffer_shape, result,
std::vector<int64>(padded_result_buffer_shape.rank(), 0),
unpadded_result_buffer_shape.dimensions(),
std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
}
return result;
});
return Status::OK();
auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
result_buffer->shape(), while_loop, 2));
if (!ShapeUtil::Compatible(padded_result_buffer_shape,
unpadded_result_buffer_shape)) {
result = b->AddInstruction(HloInstruction::CreateSlice(
unpadded_result_buffer_shape, result,
std::vector<int64>(padded_result_buffer_shape.rank(), 0),
unpadded_result_buffer_shape.dimensions(),
std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
}
return result;
};
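For intuition, the source-destination pairs built in the loop above implement a one-step rotation; a minimal standalone sketch (plain C++, not the XLA collective API) of the pattern:

// Sketch: each iteration, partition `source` forwards its input window to
// partition (source - 1 + n) % n, so after n iterations every partition
// has seen every window exactly once.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const int64_t num_partitions = 4;
  std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
  for (int64_t source = 0; source < num_partitions; ++source) {
    sd_pairs[source] = {source,
                        (source - 1 + num_partitions) % num_partitions};
  }
  for (const auto& p : sd_pairs) {
    std::cout << p.first << " -> " << p.second << "\n";  // 0->3 1->0 2->1 3->2
  }
}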
if (output_lhs_non_contracting_partitions == num_partitions_ &&
if (output_lhs_non_contracting_partitions == num_partitions &&
output_sharding_transposed_to_match_lhs == lhs_sharding &&
ShapeSizeInBytes(hlo->operand(1)->shape()) >=
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (rhs_contracting_partitions == num_partitions_) {
ShapeSizeInBytes(rhs.base_shape()) >=
threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (rhs_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(0, 1, true, false);
}
if (rhs_non_contracting_partitions == num_partitions_) {
if (rhs_non_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(0, 1, false, false);
}
if (rhs_batch_partitions == num_partitions_) {
if (rhs_batch_partitions == num_partitions) {
return emit_windowed_dot_general(0, 1, false, true);
}
}
if (output_rhs_non_contracting_partitions == num_partitions_ &&
if (output_rhs_non_contracting_partitions == num_partitions &&
output_sharding_transposed_to_match_rhs == rhs_sharding &&
ShapeSizeInBytes(hlo->operand(0)->shape()) >=
options_.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (lhs_contracting_partitions == num_partitions_) {
ShapeSizeInBytes(lhs.base_shape()) >=
threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (lhs_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(1, 0, true, false);
}
if (lhs_non_contracting_partitions == num_partitions_) {
if (lhs_non_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(1, 0, false, false);
}
if (lhs_batch_partitions == num_partitions_) {
if (lhs_batch_partitions == num_partitions) {
return emit_windowed_dot_general(1, 0, false, true);
}
}
@@ -498,18 +467,18 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
{
// Try batch-parallel by resharding one operand, and allowing all-reduce.
TF_ASSIGN_OR_RETURN(
bool emitted,
HloInstruction * partitioned_dot,
try_emit_output_batch_partitioned_einsum_with_reshard(true));
if (emitted) {
return Status::OK();
if (partitioned_dot) {
return partitioned_dot;
}
}
// LHS and RHS have the same partitioned contracting dimensions.
if (lhs_contracting_partitions == rhs_contracting_partitions &&
lhs_contracting_partitions == num_partitions_) {
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
lhs_contracting_partitions == num_partitions) {
auto zero = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type())));
// Pad both sides with zero, since NaN at one side cannot be masked by zero
// on the other side.
if (ShapeSizeInBytes(lhs.base_shape()) <
@@ -522,100 +491,91 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] {
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
auto ar =
lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
(*lhs.state().next_channel_id)++);
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
// LHS and output have the same partitioned non-contracting dimensions.
if (lhs_non_contracting_partitions == num_partitions_ &&
output_lhs_non_contracting_partitions == num_partitions_ &&
lhs_sharding_transposed_to_match_output == hlo->sharding()) {
if (lhs_non_contracting_partitions == num_partitions &&
output_lhs_non_contracting_partitions == num_partitions &&
lhs_sharding_transposed_to_match_output == output_sharding) {
auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs.hlo(), rhs_replicated, &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return Status::OK();
create_sharded_dot(lhs.hlo(), rhs_replicated, b));
return dot;
}
// RHS and output have the same partitioned non-contracting dimensions.
if (rhs_non_contracting_partitions == num_partitions_ &&
output_rhs_non_contracting_partitions == num_partitions_ &&
rhs_sharding_transposed_to_match_output == hlo->sharding()) {
if (rhs_non_contracting_partitions == num_partitions &&
output_rhs_non_contracting_partitions == num_partitions &&
rhs_sharding_transposed_to_match_output == output_sharding) {
auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs_replicated, rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return Status::OK();
create_sharded_dot(lhs_replicated, rhs.hlo(), b));
return dot;
}
// Output is batch partitioned.
if (output_batch_partitions == num_partitions_) {
if (output_batch_partitions == num_partitions) {
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
resharded_rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return Status::OK();
resharded_rhs.hlo(), b));
return dot;
}
// Output is partitioned along LHS non-contracting dimensions.
if (output_lhs_non_contracting_partitions == num_partitions_) {
if (output_lhs_non_contracting_partitions == num_partitions) {
auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs);
auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
TF_ASSIGN_OR_RETURN(
auto dot,
create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return Status::OK();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(),
replicated_rhs.hlo(), b));
return dot;
}
// Output is partitioned along RHS non-contracting dimensions.
if (output_rhs_non_contracting_partitions == num_partitions_) {
if (output_rhs_non_contracting_partitions == num_partitions) {
auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs);
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(),
resharded_rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] { return dot; });
return Status::OK();
resharded_rhs.hlo(), b));
return dot;
}
// Returns true if it is beneficial to reshard the operand at `operand_idx`
// across the contracting dimension.
const auto should_partition_contracting_dim = [&](int64 operand_idx) {
if (!hlo->sharding().IsReplicated()) {
if (!output_sharding.IsReplicated()) {
return false;
}
if (operand_idx == 0) {
// If LHS and output are replicated, we compare the cost of all-gather
// on RHS vs all-reduce on the output.
return (rhs_contracting_partitions == num_partitions_) &&
return (rhs_contracting_partitions == num_partitions) &&
lhs.sharding().IsReplicated() &&
ShapeUtil::ElementsIn(hlo->operand(1)->shape()) >
ShapeUtil::ElementsIn(hlo->shape());
ShapeUtil::ElementsIn(rhs.base_shape()) >
ShapeUtil::ElementsIn(output_base_shape);
} else {
return (lhs_contracting_partitions == num_partitions_) &&
return (lhs_contracting_partitions == num_partitions) &&
rhs.sharding().IsReplicated() &&
ShapeUtil::ElementsIn(hlo->operand(0)->shape()) >
ShapeUtil::ElementsIn(hlo->shape());
ShapeUtil::ElementsIn(lhs.base_shape()) >
ShapeUtil::ElementsIn(output_base_shape);
}
};
// When the output is replicated and one of the operands is partitioned along
// contracting dimension, align the other operand to be partitioned along
// the contracting dimensions.
if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) ||
if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
should_partition_contracting_dim(1))) {
auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(hlo->shape().element_type())));
auto zero = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(output_base_shape.element_type())));
if (should_partition_contracting_dim(0)) {
lhs =
lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
@@ -625,19 +585,361 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
rhs =
rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
}
TF_ASSIGN_OR_RETURN(auto dot,
create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_));
SetPartitionedHlo(hlo, [&] {
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_),
NewChannel());
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo();
});
return Status::OK();
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b));
return lhs.state().collective_ops_creator.create_cross_partition_all_reduce(
b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {},
(*lhs.state().next_channel_id)++);
}
return nullptr;
}
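The contracting-dimension branches in PartitionBaseCase compute a local dot per partition and sum the partials with an all-reduce; a toy numeric sketch of that reduction, with the all-reduce modeled as a plain sum:

#include <iostream>
#include <vector>

int main() {
  // 1x4 times 4x1 dot, contracting dimension split across 2 partitions.
  std::vector<float> lhs = {1, 2, 3, 4};
  std::vector<float> rhs = {5, 6, 7, 8};
  const int num_partitions = 2, shard = 2;
  std::vector<float> partials(num_partitions, 0);
  for (int p = 0; p < num_partitions; ++p) {
    for (int k = 0; k < shard; ++k) {
      partials[p] += lhs[p * shard + k] * rhs[p * shard + k];
    }
  }
  // The all-reduce sums the per-partition partial dots: 17 + 53 = 70.
  std::cout << partials[0] + partials[1] << "\n";
}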
StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops);
StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
std::vector<int64> lhs_dims;
std::vector<int64> rhs_dims;
std::vector<int64> output_dims;
auto lhs_sharding_dims_adjusted_to_output =
lhs.sharding().tile_assignment().dimensions();
auto rhs_sharding_dims_adjusted_to_output =
lhs.sharding().tile_assignment().dimensions();
auto output_sharding_dims_adjusted_to_lhs =
output_sharding.tile_assignment().dimensions();
bool lhs_rhs_dims_matching = true;
for (const auto& dim : dims_mapping.batch_dims) {
lhs_dims.push_back(dim.lhs);
rhs_dims.push_back(dim.rhs);
output_dims.push_back(dim.output);
if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
lhs_rhs_dims_matching = false;
}
lhs_sharding_dims_adjusted_to_output[dim.lhs] =
output_sharding.tile_assignment().dim(dim.output);
rhs_sharding_dims_adjusted_to_output[dim.rhs] =
output_sharding.tile_assignment().dim(dim.output);
output_sharding_dims_adjusted_to_lhs[dim.output] =
lhs.sharding().tile_assignment().dim(dim.lhs);
}
auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
if (lhs_rhs_dims_matching) {
if (ShapeUtil::ByteSizeOf(lhs.base_shape()) >
ShapeUtil::ByteSizeOf(rhs.base_shape())) {
rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
} else {
lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
}
auto reshaped_output_tiling = output_sharding.tile_assignment();
reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
output_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling),
output_dims),
lhs_grouped);
} else {
auto reshaped_lhs_tiling = lhs.sharding().tile_assignment();
reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output);
lhs_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims),
output_grouped);
lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
auto reshaped_rhs_tiling = rhs.sharding().tile_assignment();
reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output);
rhs_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims),
output_grouped);
rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
}
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
lhs.state(), lhs_grouped.device_groups, b);
lhs.hlo()->set_sharding(lhs_grouped.sharding);
rhs.hlo()->set_sharding(rhs_grouped.sharding);
CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding);
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDot(
PartitionedHlo(lhs.hlo(),
GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
per_group_partitioner_state),
PartitionedHlo(rhs.hlo(),
GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
per_group_partitioner_state),
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / lhs_grouped.device_groups.size(), create_sharded_dot,
module, original_hlo, threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
// Reset the LHS sharding to the ungrouped one.
lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped));
rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped));
dot->set_sharding(UngroupSharding(output_grouped));
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
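The batch-grouped recursion shrinks the effective partition count by the number of device groups; a toy illustration of that arithmetic under an assumed devices=[2,2] tiling:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t num_partitions = 4;        // devices=[2,2]0,1,2,3
  const int64_t batch_tiles = 2;           // tiling along the batch dimension
  const int64_t num_groups = batch_tiles;  // one device group per batch tile
  // The recursive PartitionDot call sees only the in-group partitions.
  std::cout << num_partitions / num_groups << "\n";  // 2
}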
StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
int64 matching_contracting_partitions, int64 other_contracting_partitions,
int64 matching_non_contracting_partitions,
int64 other_non_contracting_partitions,
int64 output_other_non_contracting_partitions,
const Shape& output_base_shape, const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
const bool may_replicate_other_contracting_dims =
(other_contracting_partitions == matching_non_contracting_partitions &&
other_non_contracting_partitions ==
output_other_non_contracting_partitions);
const bool may_replicate_other_non_contracting_dims =
matching_non_contracting_partitions == other_non_contracting_partitions &&
matching_contracting_partitions == other_contracting_partitions;
std::vector<int64> other_group_dims;
if (may_replicate_other_contracting_dims &&
(!may_replicate_other_non_contracting_dims ||
ShapeUtil::ByteSizeOf(other.base_shape()) <=
ShapeUtil::ByteSizeOf(output_base_shape))) {
for (const auto& dim : dims_mapping.contracting_dims) {
other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
}
} else if (may_replicate_other_non_contracting_dims) {
for (const auto& dim : lhs_matching
? dims_mapping.rhs_non_contracting_dims
: dims_mapping.lhs_non_contracting_dims) {
other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
}
} else {
return nullptr;
}
auto matching_sharding_dims =
matching.sharding().tile_assignment().dimensions();
std::vector<int64> matching_dims;
std::vector<int64> output_dims;
// Make sure the partitioning on matching's non-contracting dimensions
// defines the same device groups for both matching and output.
for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims
: dims_mapping.rhs_non_contracting_dims) {
int64 md = lhs_matching ? dim.lhs : dim.rhs;
matching_sharding_dims[md] =
output_sharding.tile_assignment().dim(dim.output);
matching_dims.push_back(md);
output_dims.push_back(dim.output);
}
auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
auto reshaped_matching_tiling = matching.sharding().tile_assignment();
reshaped_matching_tiling.Reshape(matching_sharding_dims);
auto matching_grouped = AlignGroupsWith(
GroupShardingOnDims(HloSharding::Tile(reshaped_matching_tiling),
matching_dims),
output_grouped);
matching = matching.Reshard(UngroupSharding(matching_grouped));
auto other_grouped =
AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims),
output_grouped, /*ignore_group_order=*/true);
other = other.Reshard(UngroupSharding(other_grouped));
auto partially_replicated_other =
other.ReplicatePartial(other_grouped.group_dims);
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
matching.state(), matching_grouped.device_groups, b);
matching.hlo()->set_sharding(matching_grouped.sharding);
partially_replicated_other->set_sharding(other_grouped.sharding);
auto matching_p = PartitionedHlo(
matching.hlo(),
GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
per_group_partitioner_state);
auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
per_group_partitioner_state);
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDot(lhs_matching ? matching_p : other_p,
lhs_matching ? other_p : matching_p,
GetPerGroupBaseShape(output_grouped, output_base_shape),
output_grouped.sharding, dims_mapping,
num_partitions / matching_grouped.device_groups.size(),
create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b,
windowed_dot_general_loops));
// Reset matching's sharding to the ungrouped one.
matching.hlo()->set_sharding(UngroupSharding(matching_grouped));
return dot;
}
// Recursive partitioning function. If there are partial dimensions matching in
// the operands and output, group the devices and recursively partition the
// in-group dot.
StatusOr<HloInstruction*> PartitionDot(
PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
const HloSharding& output_sharding,
const DotGeneralDimsMapping& dims_mapping, int64 num_partitions,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot,
HloModule* module, HloInstruction* original_hlo,
int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b,
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
windowed_dot_general_loops) {
// lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
auto get_partitions_for_dims =
[&](const HloSharding& sharding,
absl::Span<const DotGeneralDimsMapping::DimsMapping> dims,
int lhs_rhs_or_output) {
int64 partitions = 1;
if (sharding.IsTileMaximal()) {
return partitions;
}
for (const auto& dim : dims) {
if (lhs_rhs_or_output == 0) {
partitions *= sharding.tile_assignment().dim(dim.lhs);
} else if (lhs_rhs_or_output == 1) {
partitions *= sharding.tile_assignment().dim(dim.rhs);
} else {
CHECK_EQ(lhs_rhs_or_output, 2);
partitions *= sharding.tile_assignment().dim(dim.output);
}
}
return partitions;
};
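For reference, a standalone sketch of the partition-count computation above, with a plain vector standing in for sharding.tile_assignment():

#include <cstdint>
#include <iostream>
#include <vector>

struct DimMapping { int64_t lhs, rhs, output; };

int64_t PartitionsForDims(const std::vector<int64_t>& tile_dims,
                          const std::vector<DimMapping>& dims,
                          int lhs_rhs_or_output) {
  // The partition count over a set of mapped dimensions is the product of
  // the tile-assignment sizes of those dimensions.
  int64_t partitions = 1;
  for (const DimMapping& dim : dims) {
    int64_t d = lhs_rhs_or_output == 0   ? dim.lhs
                : lhs_rhs_or_output == 1 ? dim.rhs
                                         : dim.output;
    partitions *= tile_dims[d];
  }
  return partitions;
}

int main() {
  // LHS f32[48,12] with sharding devices=[2,2]: contracting dim 1 is split
  // 2 ways, so lhs_contracting_partitions == 2.
  std::cout << PartitionsForDims({2, 2}, {{1, 1, -1}}, 0) << "\n";
}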
const int64 lhs_batch_partitions =
get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
const int64 rhs_batch_partitions =
get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
const int64 output_batch_partitions =
get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
const int64 lhs_contracting_partitions =
get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
const int64 rhs_contracting_partitions =
get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
const int64 lhs_non_contracting_partitions = get_partitions_for_dims(
lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
const int64 rhs_non_contracting_partitions = get_partitions_for_dims(
rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims(
output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims(
output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
TF_ASSIGN_OR_RETURN(
auto try_partitioned_dot,
PartitionBaseCase(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
lhs_contracting_partitions, rhs_contracting_partitions,
lhs_non_contracting_partitions, rhs_non_contracting_partitions,
output_lhs_non_contracting_partitions,
output_rhs_non_contracting_partitions,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
if (try_partitioned_dot) {
return try_partitioned_dot;
}
return DefaultAction(hlo);
// Recursively partition on different types of dimensions.
//
// Case 1: Group partitions by batch.
if (lhs_batch_partitions == rhs_batch_partitions &&
lhs_batch_partitions == output_batch_partitions &&
lhs_batch_partitions > 1) {
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnBatch(
lhs, rhs, output_base_shape, output_sharding, dims_mapping,
num_partitions, create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
}
// Case 2: Group partitions by non-contracting dimensions.
const bool may_group_on_lhs_non_contracting =
lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
lhs_non_contracting_partitions > 1;
const bool may_group_on_rhs_non_contracting =
rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
rhs_non_contracting_partitions > 1;
if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
// If both match output non-contracting dimensions, choose the one which
// will result in smaller replication of the other operand.
const bool lhs_matching =
may_group_on_lhs_non_contracting &&
(!may_group_on_rhs_non_contracting ||
lhs_non_contracting_partitions *
ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <=
rhs_non_contracting_partitions *
ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
TF_ASSIGN_OR_RETURN(
auto dot,
PartitionDotGroupOnNonContracting(
lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
lhs_matching ? lhs_contracting_partitions
: rhs_contracting_partitions,
lhs_matching ? rhs_contracting_partitions
: lhs_contracting_partitions,
lhs_matching ? lhs_non_contracting_partitions
: rhs_non_contracting_partitions,
lhs_matching ? rhs_non_contracting_partitions
: lhs_non_contracting_partitions,
lhs_matching ? output_rhs_non_contracting_partitions
: output_lhs_non_contracting_partitions,
output_base_shape, output_sharding, dims_mapping, num_partitions,
create_sharded_dot, module, original_hlo,
threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops));
if (dot) {
return dot;
}
}
// Default action.
TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(),
rhs.Replicate().hlo(), b));
dot->set_sharding(HloSharding::Replicate());
return PartitionedHlo(dot, output_base_shape, lhs.state())
.Reshard(output_sharding)
.hlo();
}
} // namespace
Status SpmdPartitioningVisitor::HandleDotHelper(
HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
const std::function<StatusOr<HloInstruction*>(
HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) {
auto& lhs = GetPartitionedHlo(hlo->operand(0));
auto& rhs = GetPartitionedHlo(hlo->operand(1));
TF_ASSIGN_OR_RETURN(
auto partitioned_dot,
PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
num_partitions_, create_sharded_dot, module_, hlo,
options_.threshold_for_windowed_einsum_mib, &b_,
&windowed_dot_general_loops_));
SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
return Status::OK();
}
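A hedged, pseudocode-style sketch (illustrative names, not the XLA API) of the control flow HandleDotHelper now delegates to: try the base case first, then grouped strategies that recurse on a smaller per-group partition count, and finally fall back to a replicated dot:

#include <cstdint>
#include <iostream>
#include <optional>

// Each strategy either produces a partitioned dot (modeled as int) or
// declines with nullopt, mirroring the StatusOr<HloInstruction*>/nullptr
// convention above.
std::optional<int> PartitionDotSketch(int64_t num_partitions,
                                      int64_t group_count,
                                      bool base_case_matches) {
  if (base_case_matches) return 1;  // PartitionBaseCase succeeded
  if (group_count > 1) {
    // Group devices on matching dimensions and recurse on the per-group
    // subproblem; here we assume the grouped shardings line up, so the
    // inner call hits the base case.
    return PartitionDotSketch(num_partitions / group_count, 1, true);
  }
  return std::nullopt;  // caller emits the replicated default dot
}

int main() {
  std::cout << PartitionDotSketch(4, 2, false).value() << "\n";  // 1
}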
namespace {
@@ -780,6 +1082,7 @@ Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
[](const HloInstruction* a, const HloInstruction* b) {
return a->unique_id() < b->unique_id();
});
worklist.reserve(nullaries_to_sink.size());
for (auto inst : nullaries_to_sink) {
worklist.push_back(inst);
}


@@ -165,16 +165,6 @@ template <typename F>
namespace {
// Returns the replica group configuration where each replica belongs to its own
// group.
std::vector<ReplicaGroup> CreateReplicaGroups(int64 num_replicas) {
std::vector<ReplicaGroup> groups(num_replicas);
for (int64 i = 0; i < num_replicas; ++i) {
groups[i].add_replica_ids(i);
}
return groups;
}
// Clears all sharding attributes from instructions in the module. This must be
// called only after all SPMD transformation is complete.
Status ClearShardingAttributes(HloModule* module) {
@@ -195,6 +185,28 @@ Status ClearShardingAttributes(HloModule* module) {
return Status::OK();
}
std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
const HloSharding& sharding, absl::Span<const int64> replication_dims) {
int64 group_size = 1;
for (int64 i : replication_dims) {
group_size *= sharding.tile_assignment().dim(i);
}
std::vector<std::vector<int64>> partition_groups(
sharding.tile_assignment().num_elements() / group_size);
sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 partition) {
int64 group_id = 0;
for (int64 i = 0; i < indices.size(); ++i) {
if (!absl::c_linear_search(replication_dims, i)) {
group_id *= sharding.tile_assignment().dim(i);
group_id += indices[i];
}
}
partition_groups[group_id].push_back(partition);
});
return partition_groups;
}
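A standalone sketch of the grouping rule above: partitions whose tile-assignment indices differ only along replication_dims land in the same group. A row-major vector stands in for the tile assignment:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<int64_t>> GroupsForReplication(
    const std::vector<int64_t>& dims, const std::vector<int64_t>& devices,
    const std::vector<int64_t>& replication_dims) {
  int64_t group_size = 1;
  for (int64_t d : replication_dims) group_size *= dims[d];
  std::vector<std::vector<int64_t>> groups(devices.size() / group_size);
  for (size_t flat = 0; flat < devices.size(); ++flat) {
    // Decode the row-major multi-index for this tile-assignment entry.
    std::vector<int64_t> idx(dims.size());
    int64_t rem = flat;
    for (int64_t i = dims.size() - 1; i >= 0; --i) {
      idx[i] = rem % dims[i];
      rem /= dims[i];
    }
    // Group id is formed from the non-replicated dimensions only.
    int64_t group_id = 0;
    for (size_t i = 0; i < dims.size(); ++i) {
      if (std::find(replication_dims.begin(), replication_dims.end(),
                    static_cast<int64_t>(i)) == replication_dims.end()) {
        group_id = group_id * dims[i] + idx[i];
      }
    }
    groups[group_id].push_back(devices[flat]);
  }
  return groups;
}

int main() {
  // devices=[2,2]0,1,2,3 replicated along dim 1: groups {0,1} and {2,3}.
  for (const auto& g : GroupsForReplication({2, 2}, {0, 1, 2, 3}, {1})) {
    for (int64_t id : g) std::cout << id << " ";
    std::cout << "\n";
  }
}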
} // namespace
HloInstruction* SpmdBuilder::AddInstruction(
@@ -664,42 +676,57 @@ PartitionedHlo PartitionedHlo::Replicate() {
}
// 'Tiled' to 'Replicated'.
std::vector<int64> all_dims(shape.rank());
std::iota(all_dims.begin(), all_dims.end(), 0);
HloInstruction* result = ReplicatePartial(all_dims);
result->set_sharding(HloSharding::Replicate());
return update_cache(PartitionedHlo(result, base_shape_, state_));
}
HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
CHECK(!sharding().IsTileMaximal());
const Shape& shard_shape = hlo()->shape();
Shape target_shape = shard_shape;
Shape padded_target_shape = shard_shape;
for (int64 i : dims) {
padded_target_shape.set_dimensions(
i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
target_shape.set_dimensions(i, base_shape().dimensions(i));
}
HloInstruction* result = nullptr;
if (state_.collective_ops_creator.create_cross_partition_all_gather) {
result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding,
NewChannel());
}
Shape padded_base_shape = shape;
for (int64 i = 0; i < padded_base_shape.rank(); ++i) {
padded_base_shape.set_dimensions(
i, shape.dimensions(i) * sharding.tile_assignment().dim(i));
result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(),
NewChannel(), dims,
state_.collective_ops_creator);
}
if (result == nullptr) {
auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(shape.element_type())));
LiteralUtil::Zero(shard_shape.element_type())));
auto zero_bcast = state_.b->AddInstruction(
HloInstruction::CreateBroadcast(padded_base_shape, zero, {}));
HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
state_.partition_id, state_.b, dims);
auto dus =
state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
padded_base_shape, zero_bcast, hlo_,
MakePartitionOffsets(padded_base_shape, sharding,
state_.partition_id, state_.b)));
padded_target_shape, zero_bcast, hlo_, offsets));
HloComputation* reduction =
MakeBinaryAdd(shape.element_type(), state_.module);
MakeBinaryAdd(shard_shape.element_type(), state_.module);
auto all_reduce =
state_.collective_ops_creator.create_cross_partition_all_reduce(
state_.b, dus, reduction, NewChannel());
state_.b, dus, reduction,
GetPartitionGroupsForReplication(sharding(), dims), NewChannel());
result = all_reduce;
}
if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) {
std::vector<int64> start_indices(shape.rank(), 0);
std::vector<int64> strides(shape.rank(), 1);
result = state_.b->AddInstruction(HloInstruction::CreateSlice(
base_shape_, result, start_indices, base_shape_.dimensions(), strides));
if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
std::vector<int64> start_indices(target_shape.rank(), 0);
std::vector<int64> strides(target_shape.rank(), 1);
result = state_.b->AddInstruction(
HloInstruction::CreateSlice(target_shape, result, start_indices,
base_shape_.dimensions(), strides));
}
result->set_sharding(HloSharding::Replicate());
return update_cache(PartitionedHlo(result, base_shape_, state_));
return result;
}
PartitionedHlo PartitionedHlo::Broadcast() const {
@@ -728,7 +755,7 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
MakeBinaryAdd(shape.element_type(), state_.module);
auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
state_.b, operand, reduction, NewChannel());
state_.b, operand, reduction, {}, NewChannel());
result->set_sharding(HloSharding::Replicate());
return PartitionedHlo(result, base_shape_, state_);
}
@@ -796,7 +823,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
// The order of ids in the group must follow the temp_target sharding.
std::vector<ReplicaGroup> groups(
std::vector<std::vector<int64>> groups(
temp_target.tile_assignment().num_elements() / group_size);
temp_target.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
@@ -810,7 +837,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
group_id += indices[dim];
}
}
groups[group_id].add_replica_ids(device);
groups[group_id].push_back(device);
});
HloInstruction* result = nullptr;
@@ -1027,7 +1054,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
offset += operand->shape().dimensions(dimension);
}
auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_),
&b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), {},
NewChannel());
SetPartitionedHlo(hlo, [&] {
auto start_indices =
@@ -2153,7 +2180,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
// Combine from different partitions.
auto ar = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, filtered,
MakeBinaryAdd(filtered->shape().element_type(), module_),
MakeBinaryAdd(filtered->shape().element_type(), module_), {},
NewChannel());
ar->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
@@ -2449,7 +2476,7 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
if (reduce_sharded_dimension) {
CHECK(local_reduce->shape().IsArray());
reduce = collective_ops_creator_.create_cross_partition_all_reduce(
&b_, local_reduce, hlo->to_apply(), NewChannel());
&b_, local_reduce, hlo->to_apply(), {}, NewChannel());
reduce->set_sharding(HloSharding::Replicate());
} else {
reduce = local_reduce;
@@ -2917,13 +2944,36 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
[](SpmdBuilder* b) {
return b->AddInstruction(HloInstruction::CreatePartitionId());
},
[num_replicas](SpmdBuilder* b, HloInstruction* operand,
HloComputation* reduction, int64 channel_id) {
[num_replicas, num_partitions](
SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id) {
if (partition_subgroups.size() <= 1) {
std::vector<ReplicaGroup> groups(num_replicas);
// TODO(yuanzx): Unify subgroup definition with AllToAll.
for (int64 i = 0; i < num_replicas; ++i) {
groups[i].add_replica_ids(i);
}
return b->AddInstruction(HloInstruction::CreateAllReduce(
operand->shape(), {operand}, reduction, groups,
/*constrain_layout=*/false, channel_id,
/*use_global_device_ids=*/false));
}
std::vector<ReplicaGroup> device_groups;
device_groups.reserve(partition_subgroups.size() * num_replicas);
for (int64 i = 0; i < num_replicas; ++i) {
for (const auto& pgroup : partition_subgroups) {
device_groups.emplace_back();
for (int64 pid : pgroup) {
device_groups.back().add_replica_ids(i * num_partitions + pid);
}
}
}
return b->AddInstruction(HloInstruction::CreateAllReduce(
operand->shape(), {operand}, reduction,
CreateReplicaGroups(num_replicas),
operand->shape(), {operand}, reduction, device_groups,
/*constrain_layout=*/false, channel_id,
/*use_global_device_ids=*/false));
/*use_global_device_ids=*/true));
},
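The subgroup branch above relies on use_global_device_ids; a small sketch of how per-partition subgroups expand across replicas, where device (replica r, partition p) has global id r * num_partitions + p:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_replicas = 2, num_partitions = 4;
  std::vector<std::vector<int64_t>> partition_subgroups = {{0, 1}, {2, 3}};
  std::vector<std::vector<int64_t>> device_groups;
  for (int64_t r = 0; r < num_replicas; ++r) {
    // Each partition subgroup is replicated once per replica, rewritten in
    // terms of global device ids.
    for (const auto& pgroup : partition_subgroups) {
      device_groups.emplace_back();
      for (int64_t pid : pgroup) {
        device_groups.back().push_back(r * num_partitions + pid);
      }
    }
  }
  for (const auto& g : device_groups) {  // {0,1} {2,3} {4,5} {6,7}
    for (int64_t id : g) std::cout << id << " ";
    std::cout << "\n";
  }
}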
[](SpmdBuilder* b, HloInstruction* operand,
std::vector<std::pair<int64, int64>>& src_dst_pairs,
@@ -2932,14 +2982,20 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
operand->shape(), operand, src_dst_pairs, channel_id));
},
[](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension) {
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension) {
std::vector<Shape> shapes(operands.size(), operands[0]->shape());
const Shape output_shape = (shapes.size() == 1)
? shapes[0]
: ShapeUtil::MakeTupleShape(shapes);
std::vector<ReplicaGroup> groups(partition_subgroups.size());
for (int64 i = 0; i < groups.size(); ++i) {
for (int64 id : partition_subgroups[i]) {
groups[i].add_replica_ids(id);
}
}
return b->AddInstruction(HloInstruction::CreateAllToAll(
output_shape, operands, replica_groups,
output_shape, operands, groups,
/*constrain_layout=*/false, channel_id, split_dimension));
},
[num_replicas, num_partitions](
@@ -2970,10 +3026,10 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas,
num_partitions, num_replicas, std::move(options),
GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
HloInstruction* operand,
const HloSharding& sharding,
int64 channel_id) {
HloInstruction* SpmdPartitioner::AllGatherShards(
SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
int64 channel_id, absl::Span<const int64> selected_dims,
const SPMDCollectiveOpsCreator& collectives_creator) {
CHECK(!sharding.IsTileMaximal());
// Add one leading dimension to gather all partitions.
std::vector<int64> shape;
@@ -2983,18 +3039,17 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
}
auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
std::vector<std::vector<int64>> partition_subgroups(1);
for (int64 pid : sharding.tile_assignment()) {
partition_subgroups[0].push_back(pid);
}
shape[0] = sharding.tile_assignment().num_elements();
auto result = collective_ops_creator_.create_cross_partition_all_gather(
auto partition_subgroups =
GetPartitionGroupsForReplication(sharding, selected_dims);
shape[0] = partition_subgroups[0].size();
auto result = collectives_creator.create_cross_partition_all_gather(
b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
partition_subgroups, channel_id, /*all_gather_dimension=*/0);
// If n > 1 dimensions are partitioned, split the leading dimension to n.
std::vector<int64> tiled_dims;
for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (sharding.tile_assignment().dim(i) > 1) {
if (sharding.tile_assignment().dim(i) > 1 &&
absl::c_linear_search(selected_dims, i)) {
tiled_dims.push_back(i);
}
}
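For the shape bookkeeping in AllGatherShards, a toy trace under an assumed [2,2] tiling with both dimensions selected: the shard gains a leading dimension, is gathered within the subgroup, and the leading dimension is later split and transposed back next to its data dimensions:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shard = {4, 3};   // per-partition shard shape
  std::vector<int64_t> tiling = {2, 2};  // devices=[2,2], both dims selected
  int64_t subgroup_size = tiling[0] * tiling[1];
  // All-gather result: [subgroup_size, 4, 3] = [4, 4, 3]; splitting the
  // leading dim into the tiled dims and transposing gives [2, 4, 2, 3];
  // the final reshape recovers the full shape.
  std::cout << "[" << shard[0] * tiling[0] << "," << shard[1] * tiling[1]
            << "]\n";  // [8,6]
}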
@@ -3016,7 +3071,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
std::vector<int64> xpose_permutation(result->shape().rank());
int64 split_dims_added = 0;
for (int64 i = 0; i < xpose_permutation.size(); ++i) {
if (sharding.tile_assignment().dim(i - split_dims_added) == 1) {
if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
!absl::c_linear_search(selected_dims, i - split_dims_added)) {
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
} else {
xpose_permutation[i] = split_dims_added;


@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -82,8 +83,10 @@ struct SPMDCollectiveOpsCreator {
std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;
// Function used to create a cross-partition all-reduce HLO.
std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand,
HloComputation* reduction, int64 channel_id)>
std::function<HloInstruction*(
SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id)>
create_cross_partition_all_reduce;
// Function used to create a cross-partition collective-permute HLO.
@@ -96,8 +99,8 @@ struct SPMDCollectiveOpsCreator {
// Function used to create a cross-partition all-to-all HLO.
std::function<HloInstruction*(
SpmdBuilder*, absl::Span<HloInstruction* const> operands,
const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
absl::optional<int64> split_dimension)>
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension)>
create_cross_partition_all_to_all;
// Function used to create a cross-partition all-gather HLO. This is optional:
@@ -169,10 +172,13 @@ class SpmdPartitioner : public HloModulePass {
// The default uses a single all-gather even if there are multiple sharded
// dimensions, and adds potential reshapes and transposes to achieve that.
// If it returns false, the partitioner will fall back to all-reduce.
virtual HloInstruction* AllGatherShards(SpmdBuilder* b,
HloInstruction* operand,
const HloSharding& sharding,
int64 channel_id);
// `selected_dims` specifies the dimensions along which the all-gather happens
// in the tiled sharding, which allows potentially creating a subgroup
// all-gather.
virtual HloInstruction* AllGatherShards(
SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
int64 channel_id, absl::Span<const int64> selected_dims,
const SPMDCollectiveOpsCreator& collectives_creator);
protected:
virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
@@ -215,7 +221,12 @@ class PartitionedHlo {
std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
window_reshard_cache;
};
// Use std::unordered_map for pointer stability.
std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
// Caches for nested partitioning of grouped sharding. Each string key
// represents a unique way of grouping devices.
absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
grouped_caches;
};
struct PartitioningState {
SpmdBuilder* b;
@ -270,15 +281,18 @@ class PartitionedHlo {
const PartitioningState& state() const { return state_; }
// Helper function to replicate the data on all devices. This may only
// modify the reshard cache.
PartitionedHlo Replicate();
// Helper function to replicate the data for partitions along the given dims.
HloInstruction* ReplicatePartial(absl::Span<const int64> dims);
private:
// Same as Reshard, except that it does not explicitly modify the reshard
// cache, although it may modify it indirectly by calling Replicate().
PartitionedHlo ReshardNoCache(const HloSharding& target);
// Helper function to replicate the data on all devices. This may only
// modify the reshard cache.
PartitionedHlo Replicate();
// Helper function to broadcast data from a single device to all devices.
PartitionedHlo Broadcast() const;
@ -417,6 +431,16 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> DoPartition(HloComputation* computation,
const HloSharding& root_sharding);
// Information about a loop created for windowed dot-general. Used when
// DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
// finishes traversing the graph.
struct WindowedDotGeneralLoop {
HloInstruction* while_loop;
int64 windowed_operand;
bool windowed_in_contracting_dims;
bool windowed_in_batch_dims;
};
private:
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
@ -445,15 +469,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
// partitioned instruction.
ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;
// Information about a loop created for windowed dot-general. Used when
// DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
// finishes traversing the graph.
struct WindowedDotGeneralLoop {
HloInstruction* while_loop;
int64 windowed_operand;
bool windowed_in_contracting_dims;
bool windowed_in_batch_dims;
};
std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;
HloInstruction* visiting_hlo_;

View File

@ -2218,7 +2218,7 @@ ENTRY entry {
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
std::cout << module->ToString();
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@ -2294,7 +2294,7 @@ ENTRY entry
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/2));
std::cout << module->ToString();
VLOG(1) << module->ToString();
auto sort = FindInstruction(module.get(), "sort");
EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
@ -3842,6 +3842,154 @@ ENTRY entry {
EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2)));
}
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
%rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3}
ROOT %dot = f32[48,32] dot(%lhs, %rhs),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[2,2]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0));
auto partial_replicated_lhs =
AllOf(op::Shape("f32[24,12]"),
op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _)));
auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice(
_, op::CollectivePermute(rhs), _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root,
AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs),
op::Shape("f32[24,16]")));
}
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
%rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
ROOT %dot = f32[48,32] dot(%lhs, %rhs),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[2,2]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[32,50]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(
root, AllOf(op::Shape("f32[24,16]"),
op::DynamicSlice(
op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs),
op::Shape("f32[24,32]"))),
_, _)));
}
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
%rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,2,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[2,32,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
op::Dot(lhs, partial_replicated_rhs)));
}
TEST_F(SpmdPartitioningTest,
Dot2DPartitionedBatchNonContractingAndContracting) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
%rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={2},
sharding={devices=[2,1,2]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
auto partial_replicated_lhs =
AllOf(op::Shape("f32[2,24,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _)));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"),
op::Dot(partial_replicated_lhs, rhs)));
}
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
const char* const hlo_string = R"(
HloModule module
ENTRY entry {
%lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
%rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
lhs_contracting_dims={3}, rhs_contracting_dims={3},
sharding={devices=[1,2,2,1]0,1,2,3}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/4));
VLOG(1) << module->ToString();
auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0));
auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1));
auto partial_replicated_rhs =
AllOf(op::Shape("f32[2,8,32,100]"),
op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _)));
auto dot =
AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs));
auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot));
auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape));
auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose)));
}
} // namespace
} // namespace spmd
} // namespace xla

View File

@ -16,7 +16,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include <algorithm>
#include <memory>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@ -143,10 +148,10 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
return partition_shape;
}
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b) {
std::vector<HloInstruction*> MakePartitionOffsets(
const Shape& shape, const HloSharding& sharding,
HloInstruction* partition_id, SpmdBuilder* b,
absl::Span<const int64> dims) {
CHECK(!shape.IsTuple());
Array2D<int32> offset_array(
@ -158,7 +163,8 @@ std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
LiteralUtil::CreateR2FromArray2D(offset_array)));
std::vector<HloInstruction*> offsets;
for (int64 i = 0; i < shape.rank(); ++i) {
if (sharding.tile_assignment().dim(i) == 1) {
if (sharding.tile_assignment().dim(i) == 1 ||
(!dims.empty() && !absl::c_linear_search(dims, i))) {
offsets.push_back(b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
} else {
@ -978,5 +984,255 @@ bool CanReshardWithCollectivePermute(const HloSharding& source,
source.tile_assignment() != target.tile_assignment();
}
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> group_dims) {
CHECK(!sharding.IsTileMaximal());
std::vector<int64> grouped_tiling_dims =
sharding.tile_assignment().dimensions();
std::vector<int64> group_dim_sizes(group_dims.size());
for (int64 i = 0; i < group_dims.size(); ++i) {
group_dim_sizes[i] = grouped_tiling_dims[group_dims[i]];
grouped_tiling_dims[group_dims[i]] = 1;
}
std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
int64 group_id = 0;
for (int64 dim : group_dims) {
group_id *= sharding.tile_assignment().dim(dim);
group_id += indices[dim];
}
device_groups[group_id].push_back(device);
});
Array<int64> grouped_tiling(grouped_tiling_dims);
grouped_tiling.FillIota(0);
return GroupedSharding(
std::move(device_groups),
std::vector<int64>(group_dims.begin(), group_dims.end()),
std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
HloSharding::Tile(grouped_tiling));
}
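For intuition, a minimal standalone sketch of the grouping math above (illustrative only, no XLA types): grouping devices=[2,2]0,1,2,3 on dim 0 produces device groups {0,1} and {2,3}, while the in-group tiling collapses dim 0 and is refilled with iota, i.e. devices=[1,2]0,1 within each group.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Tile assignment [[0,1],[2,3]], i.e. devices=[2,2]0,1,2,3, grouped on
  // dim 0: the dim-0 index of each element is its group id.
  std::vector<std::vector<int64_t>> tile = {{0, 1}, {2, 3}};
  std::vector<std::vector<int64_t>> device_groups(tile.size());
  for (size_t i = 0; i < tile.size(); ++i)
    for (size_t j = 0; j < tile[i].size(); ++j)
      device_groups[i].push_back(tile[i][j]);
  for (const auto& g : device_groups) {
    for (int64_t d : g) std::cout << d << ' ';
    std::cout << '\n';  // prints "0 1" then "2 3"
  }
  // The grouped tiling has dims [1,2] and is filled with iota, so every
  // group sees the same local sharding devices=[1,2]0,1.
}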
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
CHECK(!grouped_sharding.sharding.IsTileMaximal());
std::vector<int64> tiling_dims =
grouped_sharding.sharding.tile_assignment().dimensions();
for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
tiling_dims[grouped_sharding.group_dims[i]] =
grouped_sharding.group_dim_sizes[i];
}
Array<int64> tiling(tiling_dims);
grouped_sharding.sharding.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
int64 remaining_group_index = g;
for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
ungrouped_inds[grouped_sharding.group_dims[i]] =
remaining_group_index % grouped_sharding.group_dim_sizes[i];
remaining_group_index /= grouped_sharding.group_dim_sizes[i];
}
tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
}
});
return HloSharding::Tile(tiling);
}
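And the inverse direction, sketching UngroupSharding's index arithmetic under the same assumptions: the group id addresses the collapsed dimension, and the local device id selects the global device within that group, recovering the original tile assignment.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> device_groups = {{0, 1}, {2, 3}};
  std::vector<int64_t> local = {0, 1};  // in-group tiling [1,2], iota-filled
  // Rebuild the full [2,2] tiling: dim 0 indexes the group, and the local
  // id picks the global device inside that group.
  std::vector<std::vector<int64_t>> tile(2, std::vector<int64_t>(2));
  for (size_t g = 0; g < device_groups.size(); ++g)
    for (size_t j = 0; j < local.size(); ++j)
      tile[g][j] = device_groups[g][local[j]];
  for (const auto& row : tile) {
    for (int64_t d : row) std::cout << d << ' ';
    std::cout << '\n';  // prints "0 1" then "2 3", the original assignment
  }
}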
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
const GroupedSharding& reference,
bool ignore_group_order) {
// Returns src -> dst index mapping.
auto get_permutation = [](absl::Span<const int64> src,
absl::Span<const int64> dst) {
CHECK_EQ(src.size(), dst.size());
absl::flat_hash_map<int64, int64> dst_reverse_map;
for (int64 i = 0; i < dst.size(); ++i) {
dst_reverse_map[dst[i]] = i;
}
std::vector<int64> permutation(src.size());
for (int64 i = 0; i < src.size(); ++i) {
auto it = dst_reverse_map.find(src[i]);
CHECK(it != dst_reverse_map.end());
permutation[i] = it->second;
}
return permutation;
};
CHECK_EQ(grouped_sharding.device_groups.size(),
reference.device_groups.size());
absl::flat_hash_map<int64, int64> device_to_ref_group;
for (int64 g = 0; g < reference.device_groups.size(); ++g) {
for (int64 device : reference.device_groups[g]) {
device_to_ref_group[device] = g;
}
}
auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
int64 ref_g = -1;
for (int64 device : devices) {
if (ref_g == -1) {
ref_g = device_to_ref_group[device];
} else if (ref_g != device_to_ref_group[device]) {
return -1;
}
}
return ref_g;
};
bool matching_groups = true;
std::vector<int64> original_src_to_ref_permutation;
for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) {
int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
matching_groups = false;
break;
}
if (g == 0) {
original_src_to_ref_permutation = get_permutation(
grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
}
}
if (matching_groups) {
auto tiles = grouped_sharding.sharding.tile_assignment();
tiles.Each([&](absl::Span<const int64> indices, int64* device) {
*device = original_src_to_ref_permutation[*device];
});
grouped_sharding.sharding = HloSharding::Tile(tiles);
}
// `reference` is const, so std::move here would silently copy; copy openly.
grouped_sharding.device_groups = reference.device_groups;
return grouped_sharding;
}
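A small standalone sketch of the permutation step above (illustrative names, not the XLA code): if source group {1,0} matches reference group {0,1}, then in-group id 0 (device 1) must be renumbered to id 1 and id 1 (device 0) to id 0, which is exactly the remapping applied to the grouped tile assignment.

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Returns, for each in-group position in `src`, the position its device
// occupies in `dst` (the src -> dst index mapping used for alignment).
std::vector<int64_t> GetPermutation(const std::vector<int64_t>& src,
                                    const std::vector<int64_t>& dst) {
  std::map<int64_t, int64_t> dst_pos;
  for (size_t i = 0; i < dst.size(); ++i) dst_pos[dst[i]] = i;
  std::vector<int64_t> perm(src.size());
  for (size_t i = 0; i < src.size(); ++i) perm[i] = dst_pos.at(src[i]);
  return perm;
}

int main() {
  auto perm = GetPermutation({1, 0}, {0, 1});
  std::cout << perm[0] << ' ' << perm[1] << '\n';  // prints "1 0"
}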
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
const Shape& original_base_shape) {
auto result = original_base_shape;
for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) {
int64 dim = grouped_sharding.group_dims[i];
int64 groups = grouped_sharding.group_dim_sizes[i];
result.set_dimensions(dim, result.dimensions(dim) / groups);
}
return result;
}
namespace {
HloInstruction* GetInGroupPartitionId(
HloInstruction* partition_id,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
int64 total_devices = device_groups.size() * device_groups[0].size();
std::vector<uint32> in_group_ids(total_devices);
for (uint32 i = 0; i < device_groups.size(); ++i) {
for (uint32 j = 0; j < device_groups[i].size(); ++j) {
in_group_ids[device_groups[i][j]] = j;
}
}
auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<uint32>(in_group_ids)));
return b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeScalarShape(U32),
b->AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
}
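As a concrete illustration of the id table (a sketch, not the XLA code): for device_groups {{0,2},{1,3}} the table is [0,0,1,1], so global partition 2 maps to in-group id 1. The SPMD code embeds this table as a constant and dynamic-slices it with the runtime partition id; the sketch just indexes it directly.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> device_groups = {{0, 2}, {1, 3}};
  std::vector<uint32_t> in_group_ids(device_groups.size() *
                                     device_groups[0].size());
  for (uint32_t i = 0; i < device_groups.size(); ++i)
    for (uint32_t j = 0; j < device_groups[i].size(); ++j)
      in_group_ids[device_groups[i][j]] = j;  // position within the group
  std::cout << in_group_ids[2] << '\n';  // global partition 2 -> id 1
}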
SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
const SPMDCollectiveOpsCreator& creator,
const std::vector<std::vector<int64>>& device_groups) {
SPMDCollectiveOpsCreator result;
result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
b);
};
auto expand_partition_groups =
[device_groups](
const std::vector<std::vector<int64>>& partition_subgroups) {
if (partition_subgroups.empty()) {
return device_groups;
}
std::vector<std::vector<int64>> result(partition_subgroups.size() *
device_groups.size());
for (int64 g = 0; g < device_groups.size(); ++g) {
for (int64 i = 0; i < partition_subgroups.size(); ++i) {
result[g * partition_subgroups.size() + i].resize(
partition_subgroups[i].size());
for (int64 j = 0; j < partition_subgroups[i].size(); ++j) {
result[g * partition_subgroups.size() + i][j] =
device_groups[g][partition_subgroups[i][j]];
}
}
}
return result;
};
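// Example (illustrative): with device_groups {{0,1},{2,3}} and in-group
// subgroups {{0,1}}, the expanded groups are {{0,1},{2,3}}: each in-group
// index is translated to the global device id of its group. An empty
// subgroup list stands for "all partitions within each group", which is
// device_groups itself.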
result.create_cross_partition_all_reduce =
[creator, expand_partition_groups](
SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id) {
return creator.create_cross_partition_all_reduce(
b, operand, reduction, expand_partition_groups(partition_subgroups),
channel_id);
};
result.create_cross_partition_collective_permute =
[creator, device_groups](
SpmdBuilder* b, HloInstruction* operand,
std::vector<std::pair<int64, int64>>& src_dst_pairs,
int64 next_channel_id) {
std::vector<std::pair<int64, int64>> expanded_pairs(
src_dst_pairs.size() * device_groups.size());
for (int64 g = 0; g < device_groups.size(); ++g) {
for (int64 i = 0; i < src_dst_pairs.size(); ++i) {
expanded_pairs[g * src_dst_pairs.size() + i] =
std::pair<int64, int64>{
device_groups[g][src_dst_pairs[i].first],
device_groups[g][src_dst_pairs[i].second]};
}
}
return creator.create_cross_partition_collective_permute(
b, operand, expanded_pairs, next_channel_id);
};
result.create_cross_partition_all_to_all =
[creator, expand_partition_groups](
SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, absl::optional<int64> split_dimension) {
return creator.create_cross_partition_all_to_all(
b, operands, expand_partition_groups(partition_subgroups),
channel_id, split_dimension);
};
if (creator.create_cross_partition_all_gather) {
result.create_cross_partition_all_gather =
[creator, expand_partition_groups](
SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
const std::vector<std::vector<int64>>& partition_subgroups,
int64 channel_id, int64 all_gather_dimension) {
return creator.create_cross_partition_all_gather(
b, operand, ag_shape,
expand_partition_groups(partition_subgroups), channel_id,
all_gather_dimension);
};
}
return result;
}
} // namespace
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
const PartitionedHlo::PartitioningState& state,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
auto result = state;
result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
state.collective_ops_creator, device_groups);
result.partition_id =
GetInGroupPartitionId(state.partition_id, device_groups, b);
// Create a string key for the groups.
std::vector<std::string> per_group_strings(device_groups.size());
for (int64 i = 0; i < per_group_strings.size(); ++i) {
per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
}
auto& grouped_cache =
state.reshard_cache->grouped_caches[absl::StrJoin(per_group_strings, ";")];
if (!grouped_cache) {
grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
}
result.reshard_cache = grouped_cache.get();
return result;
}
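A tiny sketch of the resulting cache key (assuming only absl; names illustrative): device groups {{0,1},{2,3}} produce the key "0,1;2,3".

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"

int main() {
  std::vector<std::vector<int64_t>> device_groups = {{0, 1}, {2, 3}};
  std::vector<std::string> per_group(device_groups.size());
  for (size_t i = 0; i < per_group.size(); ++i)
    per_group[i] = absl::StrJoin(device_groups[i], ",");
  std::cout << absl::StrJoin(per_group, ";") << '\n';  // prints "0,1;2,3"
}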
} // namespace spmd
} // namespace xla

View File

@ -87,10 +87,12 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
const HloSharding& sharding,
HloInstruction* partition_id,
SpmdBuilder* b);
// If `dims` is non-empty, the generated offsets will only be non-zero for those
// dimensions.
std::vector<HloInstruction*> MakePartitionOffsets(
const Shape& shape, const HloSharding& sharding,
HloInstruction* partition_id, SpmdBuilder* b,
absl::Span<const int64> dims = {});
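// Illustrative example (hypothetical values): for a shape sharded as
// devices=[2,2]0,1,2,3, calling MakePartitionOffsets(shape, sharding,
// partition_id, b, /*dims=*/{0}) returns a real offset HLO for dim 0 and a
// constant zero for dim 1, since dim 1 is not listed in `dims`.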
// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
@ -276,6 +278,48 @@ GetReshardAllToAllSourceTargetDims(const HloSharding& source,
bool CanReshardWithCollectivePermute(const HloSharding& source,
const HloSharding& target);
// Represents a grouping of the devices in a tiled sharding along certain
// dimensions. The indices along the group dimensions select the device group,
// and `sharding` describes how data is sharded within each group.
struct GroupedSharding {
GroupedSharding(std::vector<std::vector<int64>> device_groups,
std::vector<int64> group_dims,
std::vector<int64> group_dim_sizes, int64 rank,
HloSharding grouped_sharding)
: device_groups(std::move(device_groups)),
group_dims(std::move(group_dims)),
group_dim_sizes(std::move(group_dim_sizes)),
sharding(std::move(grouped_sharding)) {}
std::vector<std::vector<int64>> device_groups;
std::vector<int64> group_dims;
std::vector<int64> group_dim_sizes;
int64 rank;
HloSharding sharding;
};
// Creates a GroupedSharding for a tiled sharding.
GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> group_dims);
// Reconstructs the ungrouped sharding from a GroupedSharding.
HloSharding UngroupSharding(const GroupedSharding& grouped_sharding);
// Returns a new GroupedSharding that has the same group definition as
// `reference`.
GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
const GroupedSharding& reference,
bool ignore_group_order = false);
// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
const Shape& original_base_shape);
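// Example: with group_dims={0} and group_dim_sizes={2}, an original base
// shape f32[8,16] has per-group base shape f32[4,16].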
// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
const PartitionedHlo::PartitioningState& state,
const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);
} // namespace spmd
} // namespace xla