[XLA] Make sharding propagation faster

1. Avoid a copy in MergeSharding.
2. Keep track of a workset of instructions to avoid unnecessary recomputation (see the sketch below).
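
As a rough illustration of the second change, here is a minimal standalone worklist sketch (C++; Node, Infer, and RunToFixPoint are illustrative names, not from the XLA sources). Instead of re-scanning every instruction until a whole pass makes no progress, only instructions whose neighbors changed are revisited:

#include <functional>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

// Runs `infer` to a fixed point. `infer` returns true when it changed
// `node`; only the neighbors of changed nodes are re-enqueued, so stable
// regions of the graph are never examined again.
void RunToFixPoint(const std::vector<Node*>& nodes,
                   const std::function<bool(Node*)>& infer) {
  std::unordered_set<Node*> workset(nodes.begin(), nodes.end());
  while (!workset.empty()) {
    Node* node = *workset.begin();
    workset.erase(workset.begin());
    if (infer(node)) {
      // A new result here may enable further inference on neighbors.
      for (Node* user : node->users) workset.insert(user);
      for (Node* operand : node->operands) workset.insert(operand);
    }
  }
}

The real pass in the diff below additionally keeps instructions whose sharding was provided from the outside out of the workset.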

PiperOrigin-RevId: 326768403
Change-Id: Iea3f1ff3c448864a06f4ebb14c37f73a16ebea1e
Author: Yuanzhong Xu, 2020-08-14 18:53:21 -07:00; committed by TensorFlower Gardener
parent a74a993732
commit ea0a469bdd


@@ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
return IsShardingMoreSpecific(a, b) ? a : b;
}
// Returns a sharding that is refined by merging old and to_merge. May combine
// partial sharding in addition to MergeForMoreSpecificSharding().
HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
bool may_combine_partial_sharding) {
// Tries to refine `to_merge` by combining with `old`. Returns whether the final
// `to_merge` is more specific than `old`. May combine partial sharding in
// addition to MergeForMoreSpecificSharding().
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding) {
if (old.IsTuple()) {
HloSharding result = old;
CHECK(to_merge.IsTuple());
CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size());
for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
result.tuple_elements()[i] =
MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i],
CHECK(to_merge->IsTuple());
bool changed = false;
for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
changed |=
MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
may_combine_partial_sharding);
}
return result;
return changed;
}
if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
!to_merge.ReplicateOnLastTileDim() ||
!to_merge->ReplicateOnLastTileDim() ||
old.tile_assignment().num_elements() !=
to_merge.tile_assignment().num_elements()) {
return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
to_merge->tile_assignment().num_elements()) {
return IsShardingMoreSpecific(*to_merge, old);
}
// Combine the tile dimension sizes from new and old.
int64 num_devices = old.tile_assignment().num_elements();
std::vector<int64> new_tile_dims;
bool compatible = true;
new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions());
for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) {
int64 new_dim = to_merge.tile_assignment().dim(i);
new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
int64 new_dim = to_merge->tile_assignment().dim(i);
int64 old_dim = old.tile_assignment().dim(i);
if (new_dim == 1) {
new_tile_dims.push_back(old_dim);
@@ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
int64 replication = num_devices / Product(new_tile_dims);
if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
replication >= old.tile_assignment().dimensions().back()) {
return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
return IsShardingMoreSpecific(*to_merge, old);
}
new_tile_dims.push_back(replication);
Array<int64> new_tile(new_tile_dims);
@@ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
const HloSharding& sharding) {
int64 group_id = 0;
for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
group_id *= to_merge.tile_assignment().dim(i);
group_id *= to_merge->tile_assignment().dim(i);
group_id += tile_indices[i];
}
return group_id;
@@ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
[&](absl::Span<const int64> indices, int64 device) {
old_group_members[get_group_index(indices, old)].insert(device);
});
to_merge.tile_assignment().Each(
to_merge->tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
new_group_members[get_group_index(indices, to_merge)].insert(device);
new_group_members[get_group_index(indices, *to_merge)].insert(device);
});
// Try to find the intersection of old and new replication groups, in
// order to determine the merged tile assignment.
@@ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
if (old.tile_assignment().dim(i) == 1) {
old_index[i] = 0;
}
if (to_merge.tile_assignment().dim(i) == 1) {
if (to_merge->tile_assignment().dim(i) == 1) {
new_index[i] = 0;
}
}
int64 old_group_id = get_group_index(old_index, old);
int64 new_group_id = get_group_index(new_index, to_merge);
int64 new_group_id = get_group_index(new_index, *to_merge);
if (old_group_members[old_group_id].empty() ||
new_group_members[new_group_id].empty() ||
*old_group_members[old_group_id].begin() !=
@@ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
if (replication == 1) {
new_tile_dims.pop_back();
new_tile.Reshape(new_tile_dims);
return HloSharding::Tile(new_tile);
*to_merge = HloSharding::Tile(new_tile);
} else {
*to_merge = HloSharding::PartialTile(new_tile);
}
return HloSharding::PartialTile(new_tile);
return true;
}
return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
return IsShardingMoreSpecific(*to_merge, old);
}
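
An editorial aside on the first change, with a toy Sharding type (hypothetical and standalone; the real HloSharding carries a tile-assignment array whose copy the new in-out signature avoids): the old API returned a merged HloSharding by value even when nothing improved, while the new one refines the caller's value in place and returns only whether it is now more specific, so the caller can move the result into place:

#include <utility>
#include <vector>

// Toy stand-in for HloSharding; copying drags the tile assignment along.
struct Sharding {
  std::vector<long long> tile_assignment;
  bool IsMoreSpecificThan(const Sharding& other) const {
    return tile_assignment.size() > other.tile_assignment.size();
  }
};

// New shape of the contract: refine *to_merge in place and report
// improvement. A real merge would rewrite *to_merge (e.g. combine partial
// sharding); the toy version only reports which side is more specific.
// No Sharding is materialized or copied on any path.
bool MergeInPlace(const Sharding& old_sharding, Sharding* to_merge) {
  return to_merge->IsMoreSpecificThan(old_sharding);
}

int main() {
  Sharding current{{0, 1}};
  Sharding candidate{{0, 1, 2, 3}};
  if (MergeInPlace(current, &candidate)) {
    current = std::move(candidate);  // adopt the refined sharding, no copy
  }
}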
// Updates the sharding of the specified instruction with the specified sharding
@@ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
// been applied. If may_combine_partial_sharding is true, this may combine the
// new and existing sharding if they are both partial tiling partial
// replication.
bool MaybeImproveInstructionSharding(const HloSharding& sharding,
bool MaybeImproveInstructionSharding(HloSharding sharding,
HloInstruction* instruction,
bool may_combine_partial_sharding) {
// We don't want to propagate tile maximal shardings.
@@ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding,
}
// Any sharding is better than no sharding.
if (!instruction->has_sharding()) {
instruction->set_sharding(sharding);
instruction->set_sharding(std::move(sharding));
return true;
}
auto merged = MergeSharding(instruction->sharding(), sharding,
auto merged = MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding);
if (merged != instruction->sharding()) {
instruction->set_sharding(merged);
if (merged) {
instruction->set_sharding(std::move(sharding));
return true;
}
return false;
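
Call sites benefit through the companion change visible in the following hunks: MaybeImproveInstructionSharding now takes its HloSharding by value, the standard C++ sink-parameter idiom, so callers that are done with their sharding pass it with std::move. A standalone sketch with toy types (MaybeImprove, Instruction, and the string stand-in are illustrative only):

#include <string>
#include <utility>

struct Instruction {
  std::string sharding;  // toy stand-in for an HloSharding
  bool has_sharding() const { return !sharding.empty(); }
  void set_sharding(std::string s) { sharding = std::move(s); }
};

// Sink parameter: the value lands here by move when the caller uses
// std::move, and is moved again into the instruction, so no copy is made.
bool MaybeImprove(std::string sharding, Instruction* instruction) {
  if (!instruction->has_sharding()) {
    instruction->set_sharding(std::move(sharding));
    return true;
  }
  return false;
}

int main() {
  Instruction instr;
  std::string computed = "{devices=[2,2]0,1,2,3}";
  MaybeImprove(std::move(computed), &instr);  // moved, never copied
}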
@@ -620,7 +622,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding new_sharding = operand->sharding().GetSubSharding(
operand->shape(), {instruction->tuple_index()});
return MaybeImproveInstructionSharding(
new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
std::move(new_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
case HloOpcode::kTuple: {
if (absl::c_none_of(instruction->operands(),
@@ -685,12 +688,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
if (!IsSpatiallyPartitioned(operand)) {
continue;
}
auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) {
auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
if (instruction->operand_count() == 2) {
return sharding;
}
std::vector<HloSharding> tuple(instruction->operand_count() / 2,
sharding);
std::move(sharding));
return HloSharding::Tuple(instruction->shape(), tuple);
};
if (operand->sharding().IsReplicated() ||
@@ -722,7 +725,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
after_partial_replication, instruction->dimensions()));
changed |= MaybeImproveInstructionSharding(
new_sharding, instruction,
std::move(new_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
return changed;
@@ -764,7 +767,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
? HloSharding::PartialTile(new_tile_assignment)
: HloSharding::Tile(new_tile_assignment);
return MaybeImproveInstructionSharding(
new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
std::move(new_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(
@@ -778,7 +782,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding sharding = hlo_sharding_util::TransposeSharding(
input->sharding(), instruction->dimensions());
return MaybeImproveInstructionSharding(
sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
std::move(sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
case HloOpcode::kReduceWindow: {
const HloInstruction* lhs = instruction->operand(0);
@@ -831,7 +836,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
instruction->operand(0)->sharding());
if (new_sharding.has_value()) {
return MaybeImproveInstructionSharding(
new_sharding.value(), instruction,
std::move(*new_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
return false;
@@ -947,7 +952,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
instruction->operand(1)->sharding(), instruction);
changed |= MaybeImproveInstructionSharding(
new_sharding, instruction,
std::move(new_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
@@ -956,7 +961,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
instruction->operand(0)->sharding(), *instruction);
if (maybe_from_data) {
changed |= MaybeImproveInstructionSharding(
*maybe_from_data, instruction,
std::move(*maybe_from_data), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
}
@@ -979,7 +984,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
instruction->operand(2)->sharding(), *instruction);
if (maybe_from_update) {
changed |= MaybeImproveInstructionSharding(
*maybe_from_update, instruction,
std::move(*maybe_from_update), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
}
@@ -998,7 +1003,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
MergeForMoreSpecificSharding(sharding, instruction->sharding());
}
return MaybeImproveInstructionSharding(
sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
std::move(sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
default: {
if (instruction->IsElementwise() && is_spmd) {
@@ -1089,12 +1095,14 @@ HloSharding InferDotOperandSharding(
operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
operand_index == 0 ? dim.rhs : dim.lhs;
}
sharding =
MergeSharding(sharding,
*hlo_sharding_util::TransposeShardingWithCollapsedDims(
other_operand_dims_replicated, other_to_operand_dims,
operand_to_other_dims),
may_combine_partial_sharding);
HloSharding sharding_from_other =
*hlo_sharding_util::TransposeShardingWithCollapsedDims(
other_operand_dims_replicated, other_to_operand_dims,
operand_to_other_dims);
if (MergeSharding(sharding, &sharding_from_other,
may_combine_partial_sharding)) {
sharding = std::move(sharding_from_other);
}
}
return sharding;
}
@@ -1376,7 +1384,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd);
if (user_sharding) {
improved_sharding |= MaybeImproveInstructionSharding(
*user_sharding, instruction,
std::move(*user_sharding), instruction,
/*may_combine_partial_sharding=*/is_spmd);
}
}
@@ -1648,9 +1656,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
// indefinitely.
int64 iterations = 0;
auto run_to_fix_point = [&](bool aggressive_prop) {
bool changed = true;
while (changed) {
changed = false;
absl::flat_hash_set<const HloInstruction*> workset;
for (const HloComputation* computation : module->computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
// Skip instructions whose sharding was provided from the outside so
// we don't modify them.
if (!provided_shardings.contains(instruction)) {
workset.insert(instruction);
}
}
}
while (!workset.empty()) {
int64 inferred_from_operand_counter = 0;
int64 inferred_from_user_counter = 0;
int64 instruction_counter = 0;
@@ -1664,12 +1680,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
}
// Remove the instructions where the sharding was provided from the
// outside so we don't modify them.
instructions.erase(
std::remove_if(instructions.begin(), instructions.end(),
[&](HloInstruction* instruction) {
return provided_shardings.contains(instruction);
return !workset.contains(instruction);
}),
instructions.end());
@@ -1679,10 +1693,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
if (InferShardingFromOperands(instruction, computation_map, is_spmd_,
aggressive_prop)) {
++inferred_from_operand_counter;
changed = true;
any_changed = true;
VLOG(2) << "Add sharding (forward-pass): "
<< instruction->ToString();
maybe_computation_propagation(instruction);
for (auto user : instruction->users()) {
if (!provided_shardings.contains(user)) {
workset.insert(user);
}
}
} else {
workset.erase(instruction);
}
}
@@ -1692,13 +1713,18 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
if (InferShardingFromUsers(*it, computation_map, aggressive_prop,
is_spmd_)) {
++inferred_from_user_counter;
changed = true;
any_changed = true;
VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
maybe_computation_propagation(*it);
workset.insert(*it);
for (auto operand : (*it)->operands()) {
if (!provided_shardings.contains(operand)) {
workset.insert(operand);
}
}
}
}
}
any_changed |= changed;
VLOG(1) << "Sharding propagation iteration " << iterations << ";";
VLOG(1) << " total instructions: " << instruction_counter;
VLOG(1) << " instructions already sharded: " << already_sharded_counter;