[XLA] Make sharding propagation faster
1. Avoid a copy in MergeSharding by refining the candidate sharding in place.
2. Keep track of a workset of instructions so each iteration only revisits instructions whose neighbors changed, avoiding unnecessary recomputation.

PiperOrigin-RevId: 326768403
Change-Id: Iea3f1ff3c448864a06f4ebb14c37f73a16ebea1e
parent a74a993732
commit ea0a469bdd
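The change does two things: MergeSharding now refines the candidate sharding in place through a pointer and returns whether it became more specific, so callers can move the result instead of copying HloSharding objects, and the fixed-point loop in ShardingPropagation::Run keeps a workset of instructions so that an iteration only revisits instructions whose operands or users changed. Below is a minimal, self-contained C++ sketch of that workset pattern on a toy graph; Graph, TryImprove, and RunToFixedPoint are illustrative stand-ins and not XLA APIs.

// Illustrative worklist sketch (not the XLA code in the diff below): start
// with every eligible node in a workset and only revisit a node when one of
// its neighbors changed, instead of rescanning everything until a full sweep
// makes no progress.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

struct Graph {
  std::vector<std::vector<int>> operands;  // inputs of each node
  std::vector<std::vector<int>> users;     // users of each node
  std::vector<int64_t> label;              // stand-in for a per-node sharding

  // Stand-in for InferShardingFromOperands: refine a node's label from its
  // operands and report whether anything improved.
  bool TryImprove(int n) {
    int64_t best = label[n];
    for (int p : operands[n]) best = std::max(best, label[p]);
    if (best > label[n]) {
      label[n] = best;
      return true;
    }
    return false;
  }
};

void RunToFixedPoint(Graph& g) {
  std::unordered_set<int> workset;
  for (int n = 0; n < static_cast<int>(g.label.size()); ++n) workset.insert(n);
  while (!workset.empty()) {
    int n = *workset.begin();
    workset.erase(workset.begin());
    if (g.TryImprove(n)) {
      // Only users of a changed node can improve on the next visit.
      for (int u : g.users[n]) workset.insert(u);
    }
  }
}

int main() {
  // Chain 0 -> 1 -> 2; only node 0 starts with a nontrivial label.
  Graph g{{{}, {0}, {1}}, {{1}, {2}, {}}, {7, 0, 0}};
  RunToFixedPoint(g);
  std::cout << g.label[2] << "\n";  // prints 7 after the label propagates
}

The fixed point reached is the same one a full re-sweep would reach, but nodes far from any change are never re-examined.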
@@ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
   return IsShardingMoreSpecific(a, b) ? a : b;
 }
 
-// Returns a sharding that is refined by merging old and to_merge. May combine
-// partial sharding in addition to MergeForMoreSpecificSharding().
-HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
-                          bool may_combine_partial_sharding) {
+// Tries to refine `to_merge` by combining with `old`. Returns if the final
+// `to_merge` is more specific than `old`. May combine partial sharding in
+// addition to MergeForMoreSpecificSharding().
+bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
+                   bool may_combine_partial_sharding) {
   if (old.IsTuple()) {
-    HloSharding result = old;
-    CHECK(to_merge.IsTuple());
-    CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size());
-    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
-      result.tuple_elements()[i] =
-          MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i],
+    CHECK(to_merge->IsTuple());
+    bool changed = false;
+    for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
+      changed |=
+          MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
                         may_combine_partial_sharding);
     }
-    return result;
+    return changed;
   }
   if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
-      !to_merge.ReplicateOnLastTileDim() ||
+      !to_merge->ReplicateOnLastTileDim() ||
       old.tile_assignment().num_elements() !=
-          to_merge.tile_assignment().num_elements()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+          to_merge->tile_assignment().num_elements()) {
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   // Combine the tile dimension sizes from new and old.
   int64 num_devices = old.tile_assignment().num_elements();
   std::vector<int64> new_tile_dims;
   bool compatible = true;
-  new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions());
-  for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) {
-    int64 new_dim = to_merge.tile_assignment().dim(i);
+  new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
+  for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
+    int64 new_dim = to_merge->tile_assignment().dim(i);
     int64 old_dim = old.tile_assignment().dim(i);
     if (new_dim == 1) {
       new_tile_dims.push_back(old_dim);
@@ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
   int64 replication = num_devices / Product(new_tile_dims);
   if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
       replication >= old.tile_assignment().dimensions().back()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   new_tile_dims.push_back(replication);
   Array<int64> new_tile(new_tile_dims);
@@ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
                              const HloSharding& sharding) {
     int64 group_id = 0;
     for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
-      group_id *= to_merge.tile_assignment().dim(i);
+      group_id *= to_merge->tile_assignment().dim(i);
       group_id += tile_indices[i];
     }
     return group_id;
@@ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       [&](absl::Span<const int64> indices, int64 device) {
         old_group_members[get_group_index(indices, old)].insert(device);
       });
-  to_merge.tile_assignment().Each(
+  to_merge->tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
-        new_group_members[get_group_index(indices, to_merge)].insert(device);
+        new_group_members[get_group_index(indices, *to_merge)].insert(device);
       });
   // Try to find the intersection of old and new replication groups, in
   // order to determine the merged tile assignment.
@@ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       if (old.tile_assignment().dim(i) == 1) {
         old_index[i] = 0;
       }
-      if (to_merge.tile_assignment().dim(i) == 1) {
+      if (to_merge->tile_assignment().dim(i) == 1) {
         new_index[i] = 0;
       }
     }
     int64 old_group_id = get_group_index(old_index, old);
-    int64 new_group_id = get_group_index(new_index, to_merge);
+    int64 new_group_id = get_group_index(new_index, *to_merge);
     if (old_group_members[old_group_id].empty() ||
         new_group_members[new_group_id].empty() ||
         *old_group_members[old_group_id].begin() !=
@@ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
     if (replication == 1) {
       new_tile_dims.pop_back();
       new_tile.Reshape(new_tile_dims);
-      return HloSharding::Tile(new_tile);
+      *to_merge = HloSharding::Tile(new_tile);
+    } else {
+      *to_merge = HloSharding::PartialTile(new_tile);
     }
-    return HloSharding::PartialTile(new_tile);
+    return true;
   }
-  return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+  return IsShardingMoreSpecific(*to_merge, old);
 }
 
 // Updates the sharding of the specified instruction with the specified sharding
@@ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
 // been applied. If may_combine_partial_sharding is true, this may combine the
 // new and existing sharding if they are both partial tiling partial
 // replication.
-bool MaybeImproveInstructionSharding(const HloSharding& sharding,
+bool MaybeImproveInstructionSharding(HloSharding sharding,
                                      HloInstruction* instruction,
                                      bool may_combine_partial_sharding) {
   // We don't want to propagate tile maximal shardings.
@@ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding,
   }
   // Any sharding is better then no sharding.
   if (!instruction->has_sharding()) {
-    instruction->set_sharding(sharding);
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
-  auto merged = MergeSharding(instruction->sharding(), sharding,
+  auto merged = MergeSharding(instruction->sharding(), &sharding,
                               may_combine_partial_sharding);
-  if (merged != instruction->sharding()) {
-    instruction->set_sharding(merged);
+  if (merged) {
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
   return false;
@@ -620,7 +622,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding new_sharding = operand->sharding().GetSubSharding(
           operand->shape(), {instruction->tuple_index()});
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kTuple: {
       if (absl::c_none_of(instruction->operands(),
@@ -685,12 +688,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         if (!IsSpatiallyPartitioned(operand)) {
           continue;
         }
-        auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) {
+        auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
          if (instruction->operand_count() == 2) {
            return sharding;
          }
          std::vector<HloSharding> tuple(instruction->operand_count() / 2,
-                                        sharding);
+                                        std::move(sharding));
          return HloSharding::Tuple(instruction->shape(), tuple);
        };
         if (operand->sharding().IsReplicated() ||
@@ -722,7 +725,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
                 after_partial_replication, instruction->dimensions()));
         changed |= MaybeImproveInstructionSharding(
-            new_sharding, instruction,
+            std::move(new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       return changed;
@@ -764,7 +767,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
               ? HloSharding::PartialTile(new_tile_assignment)
               : HloSharding::Tile(new_tile_assignment);
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kConvolution:
       return InferConvolutionShardingFromOperands(
@@ -778,7 +782,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding sharding = hlo_sharding_util::TransposeSharding(
           input->sharding(), instruction->dimensions());
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kReduceWindow: {
       const HloInstruction* lhs = instruction->operand(0);
@@ -831,7 +836,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           instruction->operand(0)->sharding());
       if (new_sharding.has_value()) {
         return MaybeImproveInstructionSharding(
-            new_sharding.value(), instruction,
+            std::move(*new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       return false;
@@ -947,7 +952,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
             instruction->operand(1)->sharding(), instruction);
         changed |= MaybeImproveInstructionSharding(
-            new_sharding, instruction,
+            std::move(new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
@@ -956,7 +961,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
                 instruction->operand(0)->sharding(), *instruction);
         if (maybe_from_data) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_data, instruction,
+              std::move(*maybe_from_data), instruction,
              /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -979,7 +984,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
                 instruction->operand(2)->sharding(), *instruction);
         if (maybe_from_update) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_update, instruction,
+              std::move(*maybe_from_update), instruction,
              /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -998,7 +1003,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             MergeForMoreSpecificSharding(sharding, instruction->sharding());
       }
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     default: {
       if (instruction->IsElementwise() && is_spmd) {
@@ -1089,12 +1095,14 @@ HloSharding InferDotOperandSharding(
       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
           operand_index == 0 ? dim.rhs : dim.lhs;
     }
-    sharding =
-        MergeSharding(sharding,
-                      *hlo_sharding_util::TransposeShardingWithCollapsedDims(
-                          other_operand_dims_replicated, other_to_operand_dims,
-                          operand_to_other_dims),
-                      may_combine_partial_sharding);
+    HloSharding sharding_from_other =
+        *hlo_sharding_util::TransposeShardingWithCollapsedDims(
+            other_operand_dims_replicated, other_to_operand_dims,
+            operand_to_other_dims);
+    if (MergeSharding(sharding, &sharding_from_other,
+                      may_combine_partial_sharding)) {
+      sharding = std::move(sharding_from_other);
+    }
   }
   return sharding;
 }
@@ -1376,7 +1384,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
         GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd);
     if (user_sharding) {
       improved_sharding |= MaybeImproveInstructionSharding(
-          *user_sharding, instruction,
+          std::move(*user_sharding), instruction,
           /*may_combine_partial_sharding=*/is_spmd);
     }
   }
@@ -1648,9 +1656,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
   // indefinitely.
   int64 iterations = 0;
   auto run_to_fix_point = [&](bool aggressive_prop) {
-    bool changed = true;
-    while (changed) {
-      changed = false;
+    absl::flat_hash_set<const HloInstruction*> workset;
+    for (const HloComputation* computation : module->computations()) {
+      for (const HloInstruction* instruction : computation->instructions()) {
+        // Remove the instructions where the sharding was provided from the
+        // outside so we don't modify them.
+        if (!provided_shardings.contains(instruction)) {
+          workset.insert(instruction);
+        }
+      }
+    }
+    while (!workset.empty()) {
       int64 inferred_from_operand_counter = 0;
       int64 inferred_from_user_counter = 0;
       int64 instruction_counter = 0;
@@ -1664,12 +1680,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
        }
 
-        // Remove the instructions where the sharding was provided from the
-        // outside so we don't modify them.
        instructions.erase(
            std::remove_if(instructions.begin(), instructions.end(),
                           [&](HloInstruction* instruction) {
-                            return provided_shardings.contains(instruction);
+                            return !workset.contains(instruction);
                           }),
            instructions.end());
 
@@ -1679,10 +1693,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          if (InferShardingFromOperands(instruction, computation_map, is_spmd_,
                                        aggressive_prop)) {
            ++inferred_from_operand_counter;
-            changed = true;
+            any_changed = true;
            VLOG(2) << "Add sharding (forward-pass): "
                    << instruction->ToString();
            maybe_computation_propagation(instruction);
+            for (auto user : instruction->users()) {
+              if (!provided_shardings.contains(user)) {
+                workset.insert(user);
+              }
+            }
+          } else {
+            workset.erase(instruction);
          }
        }
 
@@ -1692,13 +1713,18 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          if (InferShardingFromUsers(*it, computation_map, aggressive_prop,
                                     is_spmd_)) {
            ++inferred_from_user_counter;
-            changed = true;
+            any_changed = true;
            VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
            maybe_computation_propagation(*it);
+            workset.insert(*it);
+            for (auto operand : (*it)->operands()) {
+              if (!provided_shardings.contains(operand)) {
+                workset.insert(operand);
+              }
+            }
          }
        }
      }
-      any_changed |= changed;
      VLOG(1) << "Sharding propagation iteration " << iterations << ";";
      VLOG(1) << " total instructions: " << instruction_counter;
      VLOG(1) << " instructions already sharded: " << already_sharded_counter;