From ea0a469bdd5ba225e23e3979ce20f18a5cedd049 Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu
Date: Fri, 14 Aug 2020 18:53:21 -0700
Subject: [PATCH] [XLA] Make sharding propagation faster

1. Avoid copy in MergeSharding
2. Keep track of a workset to avoid unnecessary computing.

PiperOrigin-RevId: 326768403
Change-Id: Iea3f1ff3c448864a06f4ebb14c37f73a16ebea1e
---
 .../xla/service/sharding_propagation.cc      | 142 +++++++++++-------
 1 file changed, 84 insertions(+), 58 deletions(-)

diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 7293bd9770d..408fdfb7612 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
   return IsShardingMoreSpecific(a, b) ? a : b;
 }
 
-// Returns a sharding that is refined by merging old and to_merge. May combine
-// partial sharding in addition to MergeForMoreSpecificSharding().
-HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
-                          bool may_combine_partial_sharding) {
+// Tries to refine `to_merge` by combining with `old`. Returns if the final
+// `to_merge` is more specific than `old`. May combine partial sharding in
+// addition to MergeForMoreSpecificSharding().
+bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
+                   bool may_combine_partial_sharding) {
   if (old.IsTuple()) {
-    HloSharding result = old;
-    CHECK(to_merge.IsTuple());
-    CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size());
-    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
-      result.tuple_elements()[i] =
-          MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i],
+    CHECK(to_merge->IsTuple());
+    bool changed = false;
+    for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
+      changed |=
+          MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
                         may_combine_partial_sharding);
     }
-    return result;
+    return changed;
   }
   if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
-      !to_merge.ReplicateOnLastTileDim() ||
+      !to_merge->ReplicateOnLastTileDim() ||
       old.tile_assignment().num_elements() !=
-          to_merge.tile_assignment().num_elements()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+          to_merge->tile_assignment().num_elements()) {
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   // Combine the tile dimension sizes from new and old.
   int64 num_devices = old.tile_assignment().num_elements();
   std::vector<int64> new_tile_dims;
   bool compatible = true;
-  new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions());
-  for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) {
-    int64 new_dim = to_merge.tile_assignment().dim(i);
+  new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
+  for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
+    int64 new_dim = to_merge->tile_assignment().dim(i);
     int64 old_dim = old.tile_assignment().dim(i);
     if (new_dim == 1) {
       new_tile_dims.push_back(old_dim);
@@ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
   int64 replication = num_devices / Product(new_tile_dims);
   if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
       replication >= old.tile_assignment().dimensions().back()) {
-    return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+    return IsShardingMoreSpecific(*to_merge, old);
   }
   new_tile_dims.push_back(replication);
   Array<int64> new_tile(new_tile_dims);
@@ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
                              const HloSharding& sharding) {
     int64 group_id = 0;
     for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
-      group_id *= to_merge.tile_assignment().dim(i);
+      group_id *= to_merge->tile_assignment().dim(i);
       group_id += tile_indices[i];
     }
     return group_id;
@@ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       [&](absl::Span<const int64> indices, int64 device) {
         old_group_members[get_group_index(indices, old)].insert(device);
       });
-  to_merge.tile_assignment().Each(
+  to_merge->tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 device) {
-        new_group_members[get_group_index(indices, to_merge)].insert(device);
+        new_group_members[get_group_index(indices, *to_merge)].insert(device);
       });
   // Try to find the intersection of old and new replication groups, in
   // order to determine the merged tile assignment.
@@ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
       if (old.tile_assignment().dim(i) == 1) {
        old_index[i] = 0;
      }
-      if (to_merge.tile_assignment().dim(i) == 1) {
+      if (to_merge->tile_assignment().dim(i) == 1) {
        new_index[i] = 0;
      }
    }
    int64 old_group_id = get_group_index(old_index, old);
-    int64 new_group_id = get_group_index(new_index, to_merge);
+    int64 new_group_id = get_group_index(new_index, *to_merge);
    if (old_group_members[old_group_id].empty() ||
        new_group_members[new_group_id].empty() ||
        *old_group_members[old_group_id].begin() !=
@@ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
     if (replication == 1) {
       new_tile_dims.pop_back();
       new_tile.Reshape(new_tile_dims);
-      return HloSharding::Tile(new_tile);
+      *to_merge = HloSharding::Tile(new_tile);
+    } else {
+      *to_merge = HloSharding::PartialTile(new_tile);
     }
-    return HloSharding::PartialTile(new_tile);
+    return true;
   }
-  return IsShardingMoreSpecific(to_merge, old) ? to_merge : old;
+  return IsShardingMoreSpecific(*to_merge, old);
 }
 
 // Updates the sharding of the specified instruction with the specified sharding
@@ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge,
 // been applied. If may_combine_partial_sharding is true, this may combine the
 // new and existing sharding if they are both partial tiling partial
 // replication.
-bool MaybeImproveInstructionSharding(const HloSharding& sharding,
+bool MaybeImproveInstructionSharding(HloSharding sharding,
                                      HloInstruction* instruction,
                                      bool may_combine_partial_sharding) {
   // We don't want to propagate tile maximal shardings.
@@ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding,
   // Any sharding is better then no sharding.
   if (!instruction->has_sharding()) {
-    instruction->set_sharding(sharding);
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
-  auto merged = MergeSharding(instruction->sharding(), sharding,
+  auto merged = MergeSharding(instruction->sharding(), &sharding,
                               may_combine_partial_sharding);
-  if (merged != instruction->sharding()) {
-    instruction->set_sharding(merged);
+  if (merged) {
+    instruction->set_sharding(std::move(sharding));
     return true;
   }
   return false;
 }
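
The hunks above change the contract, not just the implementation: MergeSharding now refines *to_merge in place and returns whether the result improves on `old`, and MaybeImproveInstructionSharding takes its HloSharding by value so callers can std::move temporaries in and the function can move the winner onto the instruction. The following standalone toy is only a minimal sketch of that calling convention, not XLA code; the Sharding and Node types, NumTiles(), and the "more tiles means more specific" rule are invented for illustration.

// Toy illustration (not part of this patch): merge-in-place with a bool
// result, plus a by-value parameter that is moved onto the target node.
#include <functional>
#include <iostream>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

using Sharding = std::vector<int>;  // stand-in for HloSharding (tile dims)

int NumTiles(const Sharding& s) {
  return std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
}

// Refines *to_merge in place by borrowing tile sizes from `old_sharding`
// along dimensions where *to_merge is unsharded, then reports whether the
// result is now "more specific" (here: more tiles) than `old_sharding`.
// Returning a bool instead of a merged copy is the pattern the patch adopts.
bool MergeSharding(const Sharding& old_sharding, Sharding* to_merge) {
  for (size_t i = 0; i < to_merge->size() && i < old_sharding.size(); ++i) {
    if ((*to_merge)[i] == 1) {
      (*to_merge)[i] = old_sharding[i];
    }
  }
  return NumTiles(*to_merge) > NumTiles(old_sharding);
}

struct Node {
  std::optional<Sharding> sharding;
};

// Takes `candidate` by value: callers hand over a temporary with std::move,
// and the value is moved onto the node only when it is an improvement.
bool MaybeImproveSharding(Sharding candidate, Node* node) {
  if (!node->sharding) {
    node->sharding = std::move(candidate);
    return true;
  }
  if (MergeSharding(*node->sharding, &candidate)) {
    node->sharding = std::move(candidate);
    return true;
  }
  return false;
}

int main() {
  Node n;
  MaybeImproveSharding({2, 1}, &n);                  // first annotation wins
  bool improved = MaybeImproveSharding({1, 4}, &n);  // merged in place: {2, 4}
  std::cout << improved << " " << NumTiles(*n.sharding) << "\n";  // prints 1 8
}

This mirrors why the call sites in the diff switch from passing a returned `merged` sharding to moving the refined `sharding` itself once MergeSharding reports an improvement.
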
@@ -620,7 +622,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding new_sharding = operand->sharding().GetSubSharding(
           operand->shape(), {instruction->tuple_index()});
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kTuple: {
       if (absl::c_none_of(instruction->operands(),
@@ -685,12 +688,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         if (!IsSpatiallyPartitioned(operand)) {
           continue;
         }
-        auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) {
+        auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
           if (instruction->operand_count() == 2) {
             return sharding;
           }
           std::vector<HloSharding> tuple(instruction->operand_count() / 2,
-                                         sharding);
+                                         std::move(sharding));
           return HloSharding::Tuple(instruction->shape(), tuple);
         };
         if (operand->sharding().IsReplicated() ||
@@ -722,7 +725,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
               after_partial_replication, instruction->dimensions()));
       changed |= MaybeImproveInstructionSharding(
-          new_sharding, instruction,
+          std::move(new_sharding), instruction,
           /*may_combine_partial_sharding=*/is_spmd);
     }
     return changed;
@@ -764,7 +767,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
               ? HloSharding::PartialTile(new_tile_assignment)
               : HloSharding::Tile(new_tile_assignment);
       return MaybeImproveInstructionSharding(
-          new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(new_sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kConvolution:
       return InferConvolutionShardingFromOperands(
@@ -778,7 +782,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
       HloSharding sharding = hlo_sharding_util::TransposeSharding(
           input->sharding(), instruction->dimensions());
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     case HloOpcode::kReduceWindow: {
       const HloInstruction* lhs = instruction->operand(0);
@@ -831,7 +836,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
           instruction->operand(0)->sharding());
       if (new_sharding.has_value()) {
         return MaybeImproveInstructionSharding(
-            new_sharding.value(), instruction,
+            std::move(*new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       return false;
@@ -947,7 +952,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
         HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
             instruction->operand(1)->sharding(), instruction);
         changed |= MaybeImproveInstructionSharding(
-            new_sharding, instruction,
+            std::move(new_sharding), instruction,
             /*may_combine_partial_sharding=*/is_spmd);
       }
       if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
@@ -956,7 +961,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
                 instruction->operand(0)->sharding(), *instruction);
         if (maybe_from_data) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_data, instruction,
+              std::move(*maybe_from_data), instruction,
               /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -979,7 +984,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
                 instruction->operand(2)->sharding(), *instruction);
         if (maybe_from_update) {
           changed |= MaybeImproveInstructionSharding(
-              *maybe_from_update, instruction,
+              std::move(*maybe_from_update), instruction,
              /*may_combine_partial_sharding=*/is_spmd);
         }
       }
@@ -998,7 +1003,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
             MergeForMoreSpecificSharding(sharding, instruction->sharding());
       }
       return MaybeImproveInstructionSharding(
-          sharding, instruction, /*may_combine_partial_sharding=*/is_spmd);
+          std::move(sharding), instruction,
+          /*may_combine_partial_sharding=*/is_spmd);
     }
     default: {
       if (instruction->IsElementwise() && is_spmd) {
@@ -1089,12 +1095,14 @@ HloSharding InferDotOperandSharding(
       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
          operand_index == 0 ? dim.rhs : dim.lhs;
     }
-    sharding =
-        MergeSharding(sharding,
-                      *hlo_sharding_util::TransposeShardingWithCollapsedDims(
-                          other_operand_dims_replicated, other_to_operand_dims,
-                          operand_to_other_dims),
-                      may_combine_partial_sharding);
+    HloSharding sharding_from_other =
+        *hlo_sharding_util::TransposeShardingWithCollapsedDims(
+            other_operand_dims_replicated, other_to_operand_dims,
+            operand_to_other_dims);
+    if (MergeSharding(sharding, &sharding_from_other,
+                      may_combine_partial_sharding)) {
+      sharding = std::move(sharding_from_other);
+    }
   }
   return sharding;
 }
@@ -1376,7 +1384,7 @@ bool InferShardingFromUsers(HloInstruction* instruction,
         GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd);
     if (user_sharding) {
       improved_sharding |= MaybeImproveInstructionSharding(
-          *user_sharding, instruction,
+          std::move(*user_sharding), instruction,
           /*may_combine_partial_sharding=*/is_spmd);
     }
   }
@@ -1648,9 +1656,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
   // indefinitely.
   int64 iterations = 0;
   auto run_to_fix_point = [&](bool aggressive_prop) {
-    bool changed = true;
-    while (changed) {
-      changed = false;
+    absl::flat_hash_set<const HloInstruction*> workset;
+    for (const HloComputation* computation : module->computations()) {
+      for (const HloInstruction* instruction : computation->instructions()) {
+        // Remove the instructions where the sharding was provided from the
+        // outside so we don't modify them.
+        if (!provided_shardings.contains(instruction)) {
+          workset.insert(instruction);
+        }
+      }
+    }
+    while (!workset.empty()) {
       int64 inferred_from_operand_counter = 0;
       int64 inferred_from_user_counter = 0;
       int64 instruction_counter = 0;
@@ -1664,12 +1680,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          already_sharded_counter += (instruction->has_sharding() ? 1 : 0);
        }
 
-        // Remove the instructions where the sharding was provided from the
-        // outside so we don't modify them.
        instructions.erase(
            std::remove_if(instructions.begin(), instructions.end(),
                           [&](HloInstruction* instruction) {
-                             return provided_shardings.contains(instruction);
+                             return !workset.contains(instruction);
                           }),
            instructions.end());
 
@@ -1679,10 +1693,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          if (InferShardingFromOperands(instruction, computation_map, is_spmd_,
                                        aggressive_prop)) {
            ++inferred_from_operand_counter;
-            changed = true;
+            any_changed = true;
            VLOG(2) << "Add sharding (forward-pass): "
                    << instruction->ToString();
            maybe_computation_propagation(instruction);
+            for (auto user : instruction->users()) {
+              if (!provided_shardings.contains(user)) {
+                workset.insert(user);
+              }
+            }
+          } else {
+            workset.erase(instruction);
          }
        }
 
@@ -1692,13 +1713,18 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
          if (InferShardingFromUsers(*it, computation_map, aggressive_prop,
                                     is_spmd_)) {
            ++inferred_from_user_counter;
-            changed = true;
+            any_changed = true;
            VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
            maybe_computation_propagation(*it);
+            workset.insert(*it);
+            for (auto operand : (*it)->operands()) {
+              if (!provided_shardings.contains(operand)) {
+                workset.insert(operand);
+              }
+            }
          }
        }
      }
-      any_changed |= changed;
      VLOG(1) << "Sharding propagation iteration " << iterations << ";";
      VLOG(1) << "  total instructions: " << instruction_counter;
      VLOG(1) << "  instructions already sharded: " << already_sharded_counter;
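
For context on the second optimization, the final hunks replace the "sweep everything while anything changed" loop with a workset: every instruction without an externally provided sharding starts in the set, an instruction whose sharding could not be improved is dropped, and a successful inference re-enqueues only its users (forward pass) or itself plus its operands (backward pass). The sketch below is a minimal standalone analogue, not XLA code; Node, Infer(), and the integer "value" standing in for a sharding annotation are assumptions made for illustration, with Infer() monotone so the loop reaches a fixed point.

// Standalone sketch (not part of this patch): a workset-driven fixed point in
// the spirit of the new run_to_fix_point. Only nodes whose neighbours changed
// are revisited; nodes with externally provided annotations are never touched.
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include <vector>

struct Node {
  int id = 0;
  int value = 0;          // stand-in for a sharding annotation
  bool provided = false;  // annotation came from outside; never modify
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

// Stand-in for InferShardingFromOperands/Users: returns true if the node's
// annotation improved.
bool Infer(Node* n) {
  if (n->provided) return false;
  int best = n->value;
  for (Node* op : n->operands) best = std::max(best, op->value);
  for (Node* u : n->users) best = std::max(best, u->value - 1);
  if (best > n->value) {
    n->value = best;
    return true;
  }
  return false;
}

void RunToFixPoint(const std::vector<Node*>& post_order) {
  std::unordered_set<Node*> workset;
  for (Node* n : post_order) {
    if (!n->provided) workset.insert(n);  // seed with every modifiable node
  }
  while (!workset.empty()) {
    for (Node* n : post_order) {
      if (!workset.count(n)) continue;
      if (Infer(n)) {
        // A change here can only enable further inference at the neighbours.
        for (Node* u : n->users) {
          if (!u->provided) workset.insert(u);
        }
        for (Node* op : n->operands) {
          if (!op->provided) workset.insert(op);
        }
      } else {
        workset.erase(n);  // nothing new to learn until a neighbour changes
      }
    }
  }
}

int main() {
  Node a{0}, b{1}, c{2};
  a.value = 3;
  a.provided = true;  // like an instruction with a user-provided sharding
  b.operands = {&a};
  a.users = {&b};
  c.operands = {&b};
  b.users = {&c};
  std::vector<Node*> order = {&a, &b, &c};
  RunToFixPoint(order);
  std::cout << a.value << " " << b.value << " " << c.value << "\n";  // 3 3 3
}

The patch keeps the same overall shape but splits the two directions: the forward pass re-enqueues users, the backward pass re-enqueues the instruction and its operands, and provided_shardings is what keeps externally annotated instructions out of the workset.
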