[XLA] Make sharding propagation faster
1. Avoid copy in MergeSharding 2. Keep track of a workset to avoid unnecessary computing. PiperOrigin-RevId: 326768403 Change-Id: Iea3f1ff3c448864a06f4ebb14c37f73a16ebea1e
This commit is contained in:
		
							parent
							
								
									a74a993732
								
							
						
					
					
						commit
						ea0a469bdd
					
				| @ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a, | ||||
|   return IsShardingMoreSpecific(a, b) ? a : b; | ||||
| } | ||||
| 
 | ||||
| // Returns a sharding that is refined by merging old and to_merge. May combine
 | ||||
| // partial sharding in addition to MergeForMoreSpecificSharding().
 | ||||
| HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
| // Tries to refine `to_merge` by combining with `old`. Returns if the final
 | ||||
| // `to_merge` is more specific than `old`. May combine partial sharding in
 | ||||
| // addition to MergeForMoreSpecificSharding().
 | ||||
| bool MergeSharding(const HloSharding& old, HloSharding* to_merge, | ||||
|                    bool may_combine_partial_sharding) { | ||||
|   if (old.IsTuple()) { | ||||
|     HloSharding result = old; | ||||
|     CHECK(to_merge.IsTuple()); | ||||
|     CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size()); | ||||
|     for (int64 i = 0; i < result.tuple_elements().size(); ++i) { | ||||
|       result.tuple_elements()[i] = | ||||
|           MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i], | ||||
|     CHECK(to_merge->IsTuple()); | ||||
|     bool changed = false; | ||||
|     for (int64 i = 0; i < old.tuple_elements().size(); ++i) { | ||||
|       changed |= | ||||
|           MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i], | ||||
|                         may_combine_partial_sharding); | ||||
|     } | ||||
|     return result; | ||||
|     return changed; | ||||
|   } | ||||
|   if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() || | ||||
|       !to_merge.ReplicateOnLastTileDim() || | ||||
|       !to_merge->ReplicateOnLastTileDim() || | ||||
|       old.tile_assignment().num_elements() != | ||||
|           to_merge.tile_assignment().num_elements()) { | ||||
|     return IsShardingMoreSpecific(to_merge, old) ? to_merge : old; | ||||
|           to_merge->tile_assignment().num_elements()) { | ||||
|     return IsShardingMoreSpecific(*to_merge, old); | ||||
|   } | ||||
|   // Combine the tile dimension sizes from new and old.
 | ||||
|   int64 num_devices = old.tile_assignment().num_elements(); | ||||
|   std::vector<int64> new_tile_dims; | ||||
|   bool compatible = true; | ||||
|   new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions()); | ||||
|   for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) { | ||||
|     int64 new_dim = to_merge.tile_assignment().dim(i); | ||||
|   new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions()); | ||||
|   for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) { | ||||
|     int64 new_dim = to_merge->tile_assignment().dim(i); | ||||
|     int64 old_dim = old.tile_assignment().dim(i); | ||||
|     if (new_dim == 1) { | ||||
|       new_tile_dims.push_back(old_dim); | ||||
| @ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
|   int64 replication = num_devices / Product(new_tile_dims); | ||||
|   if (!compatible || num_devices % Product(new_tile_dims) != 0 || | ||||
|       replication >= old.tile_assignment().dimensions().back()) { | ||||
|     return IsShardingMoreSpecific(to_merge, old) ? to_merge : old; | ||||
|     return IsShardingMoreSpecific(*to_merge, old); | ||||
|   } | ||||
|   new_tile_dims.push_back(replication); | ||||
|   Array<int64> new_tile(new_tile_dims); | ||||
| @ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
|                              const HloSharding& sharding) { | ||||
|     int64 group_id = 0; | ||||
|     for (int64 i = 0; i < tile_indices.size() - 1; ++i) { | ||||
|       group_id *= to_merge.tile_assignment().dim(i); | ||||
|       group_id *= to_merge->tile_assignment().dim(i); | ||||
|       group_id += tile_indices[i]; | ||||
|     } | ||||
|     return group_id; | ||||
| @ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
|       [&](absl::Span<const int64> indices, int64 device) { | ||||
|         old_group_members[get_group_index(indices, old)].insert(device); | ||||
|       }); | ||||
|   to_merge.tile_assignment().Each( | ||||
|   to_merge->tile_assignment().Each( | ||||
|       [&](absl::Span<const int64> indices, int64 device) { | ||||
|         new_group_members[get_group_index(indices, to_merge)].insert(device); | ||||
|         new_group_members[get_group_index(indices, *to_merge)].insert(device); | ||||
|       }); | ||||
|   // Try to find the intersection of old and new replication groups, in
 | ||||
|   // order to determine the merged tile assignment.
 | ||||
| @ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
|       if (old.tile_assignment().dim(i) == 1) { | ||||
|         old_index[i] = 0; | ||||
|       } | ||||
|       if (to_merge.tile_assignment().dim(i) == 1) { | ||||
|       if (to_merge->tile_assignment().dim(i) == 1) { | ||||
|         new_index[i] = 0; | ||||
|       } | ||||
|     } | ||||
|     int64 old_group_id = get_group_index(old_index, old); | ||||
|     int64 new_group_id = get_group_index(new_index, to_merge); | ||||
|     int64 new_group_id = get_group_index(new_index, *to_merge); | ||||
|     if (old_group_members[old_group_id].empty() || | ||||
|         new_group_members[new_group_id].empty() || | ||||
|         *old_group_members[old_group_id].begin() != | ||||
| @ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
|     if (replication == 1) { | ||||
|       new_tile_dims.pop_back(); | ||||
|       new_tile.Reshape(new_tile_dims); | ||||
|       return HloSharding::Tile(new_tile); | ||||
|       *to_merge = HloSharding::Tile(new_tile); | ||||
|     } else { | ||||
|       *to_merge = HloSharding::PartialTile(new_tile); | ||||
|     } | ||||
|     return HloSharding::PartialTile(new_tile); | ||||
|     return true; | ||||
|   } | ||||
|   return IsShardingMoreSpecific(to_merge, old) ? to_merge : old; | ||||
|   return IsShardingMoreSpecific(*to_merge, old); | ||||
| } | ||||
| 
 | ||||
| // Updates the sharding of the specified instruction with the specified sharding
 | ||||
| @ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, | ||||
| // been applied. If may_combine_partial_sharding is true, this may combine the
 | ||||
| // new and existing sharding if they are both partial tiling partial
 | ||||
| // replication.
 | ||||
| bool MaybeImproveInstructionSharding(const HloSharding& sharding, | ||||
| bool MaybeImproveInstructionSharding(HloSharding sharding, | ||||
|                                      HloInstruction* instruction, | ||||
|                                      bool may_combine_partial_sharding) { | ||||
|   // We don't want to propagate tile maximal shardings.
 | ||||
| @ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding, | ||||
|   } | ||||
|   // Any sharding is better then no sharding.
 | ||||
|   if (!instruction->has_sharding()) { | ||||
|     instruction->set_sharding(sharding); | ||||
|     instruction->set_sharding(std::move(sharding)); | ||||
|     return true; | ||||
|   } | ||||
|   auto merged = MergeSharding(instruction->sharding(), sharding, | ||||
|   auto merged = MergeSharding(instruction->sharding(), &sharding, | ||||
|                               may_combine_partial_sharding); | ||||
|   if (merged != instruction->sharding()) { | ||||
|     instruction->set_sharding(merged); | ||||
|   if (merged) { | ||||
|     instruction->set_sharding(std::move(sharding)); | ||||
|     return true; | ||||
|   } | ||||
|   return false; | ||||
| @ -620,7 +622,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|       HloSharding new_sharding = operand->sharding().GetSubSharding( | ||||
|           operand->shape(), {instruction->tuple_index()}); | ||||
|       return MaybeImproveInstructionSharding( | ||||
|           new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); | ||||
|           std::move(new_sharding), instruction, | ||||
|           /*may_combine_partial_sharding=*/is_spmd); | ||||
|     } | ||||
|     case HloOpcode::kTuple: { | ||||
|       if (absl::c_none_of(instruction->operands(), | ||||
| @ -685,12 +688,12 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|         if (!IsSpatiallyPartitioned(operand)) { | ||||
|           continue; | ||||
|         } | ||||
|         auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) { | ||||
|         auto get_maybe_tuple_sharding = [&](HloSharding sharding) { | ||||
|           if (instruction->operand_count() == 2) { | ||||
|             return sharding; | ||||
|           } | ||||
|           std::vector<HloSharding> tuple(instruction->operand_count() / 2, | ||||
|                                          sharding); | ||||
|                                          std::move(sharding)); | ||||
|           return HloSharding::Tuple(instruction->shape(), tuple); | ||||
|         }; | ||||
|         if (operand->sharding().IsReplicated() || | ||||
| @ -722,7 +725,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|             get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions( | ||||
|                 after_partial_replication, instruction->dimensions())); | ||||
|         changed |= MaybeImproveInstructionSharding( | ||||
|             new_sharding, instruction, | ||||
|             std::move(new_sharding), instruction, | ||||
|             /*may_combine_partial_sharding=*/is_spmd); | ||||
|       } | ||||
|       return changed; | ||||
| @ -764,7 +767,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|               ? HloSharding::PartialTile(new_tile_assignment) | ||||
|               : HloSharding::Tile(new_tile_assignment); | ||||
|       return MaybeImproveInstructionSharding( | ||||
|           new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); | ||||
|           std::move(new_sharding), instruction, | ||||
|           /*may_combine_partial_sharding=*/is_spmd); | ||||
|     } | ||||
|     case HloOpcode::kConvolution: | ||||
|       return InferConvolutionShardingFromOperands( | ||||
| @ -778,7 +782,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|       HloSharding sharding = hlo_sharding_util::TransposeSharding( | ||||
|           input->sharding(), instruction->dimensions()); | ||||
|       return MaybeImproveInstructionSharding( | ||||
|           sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); | ||||
|           std::move(sharding), instruction, | ||||
|           /*may_combine_partial_sharding=*/is_spmd); | ||||
|     } | ||||
|     case HloOpcode::kReduceWindow: { | ||||
|       const HloInstruction* lhs = instruction->operand(0); | ||||
| @ -831,7 +836,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|               instruction->operand(0)->sharding()); | ||||
|       if (new_sharding.has_value()) { | ||||
|         return MaybeImproveInstructionSharding( | ||||
|             new_sharding.value(), instruction, | ||||
|             std::move(*new_sharding), instruction, | ||||
|             /*may_combine_partial_sharding=*/is_spmd); | ||||
|       } | ||||
|       return false; | ||||
| @ -947,7 +952,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|         HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( | ||||
|             instruction->operand(1)->sharding(), instruction); | ||||
|         changed |= MaybeImproveInstructionSharding( | ||||
|             new_sharding, instruction, | ||||
|             std::move(new_sharding), instruction, | ||||
|             /*may_combine_partial_sharding=*/is_spmd); | ||||
|       } | ||||
|       if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { | ||||
| @ -956,7 +961,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|                 instruction->operand(0)->sharding(), *instruction); | ||||
|         if (maybe_from_data) { | ||||
|           changed |= MaybeImproveInstructionSharding( | ||||
|               *maybe_from_data, instruction, | ||||
|               std::move(*maybe_from_data), instruction, | ||||
|               /*may_combine_partial_sharding=*/is_spmd); | ||||
|         } | ||||
|       } | ||||
| @ -979,7 +984,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|                 instruction->operand(2)->sharding(), *instruction); | ||||
|         if (maybe_from_update) { | ||||
|           changed |= MaybeImproveInstructionSharding( | ||||
|               *maybe_from_update, instruction, | ||||
|               std::move(*maybe_from_update), instruction, | ||||
|               /*may_combine_partial_sharding=*/is_spmd); | ||||
|         } | ||||
|       } | ||||
| @ -998,7 +1003,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, | ||||
|             MergeForMoreSpecificSharding(sharding, instruction->sharding()); | ||||
|       } | ||||
|       return MaybeImproveInstructionSharding( | ||||
|           sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); | ||||
|           std::move(sharding), instruction, | ||||
|           /*may_combine_partial_sharding=*/is_spmd); | ||||
|     } | ||||
|     default: { | ||||
|       if (instruction->IsElementwise() && is_spmd) { | ||||
| @ -1089,12 +1095,14 @@ HloSharding InferDotOperandSharding( | ||||
|       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = | ||||
|           operand_index == 0 ? dim.rhs : dim.lhs; | ||||
|     } | ||||
|     sharding = | ||||
|         MergeSharding(sharding, | ||||
|     HloSharding sharding_from_other = | ||||
|         *hlo_sharding_util::TransposeShardingWithCollapsedDims( | ||||
|             other_operand_dims_replicated, other_to_operand_dims, | ||||
|                           operand_to_other_dims), | ||||
|                       may_combine_partial_sharding); | ||||
|             operand_to_other_dims); | ||||
|     if (MergeSharding(sharding, &sharding_from_other, | ||||
|                       may_combine_partial_sharding)) { | ||||
|       sharding = std::move(sharding_from_other); | ||||
|     } | ||||
|   } | ||||
|   return sharding; | ||||
| } | ||||
| @ -1376,7 +1384,7 @@ bool InferShardingFromUsers(HloInstruction* instruction, | ||||
|         GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd); | ||||
|     if (user_sharding) { | ||||
|       improved_sharding |= MaybeImproveInstructionSharding( | ||||
|           *user_sharding, instruction, | ||||
|           std::move(*user_sharding), instruction, | ||||
|           /*may_combine_partial_sharding=*/is_spmd); | ||||
|     } | ||||
|   } | ||||
| @ -1648,9 +1656,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) { | ||||
|   // indefinitely.
 | ||||
|   int64 iterations = 0; | ||||
|   auto run_to_fix_point = [&](bool aggressive_prop) { | ||||
|     bool changed = true; | ||||
|     while (changed) { | ||||
|       changed = false; | ||||
|     absl::flat_hash_set<const HloInstruction*> workset; | ||||
|     for (const HloComputation* computation : module->computations()) { | ||||
|       for (const HloInstruction* instruction : computation->instructions()) { | ||||
|         // Remove the instructions where the sharding was provided from the
 | ||||
|         // outside so we don't modify them.
 | ||||
|         if (!provided_shardings.contains(instruction)) { | ||||
|           workset.insert(instruction); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     while (!workset.empty()) { | ||||
|       int64 inferred_from_operand_counter = 0; | ||||
|       int64 inferred_from_user_counter = 0; | ||||
|       int64 instruction_counter = 0; | ||||
| @ -1664,12 +1680,10 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) { | ||||
|           already_sharded_counter += (instruction->has_sharding() ? 1 : 0); | ||||
|         } | ||||
| 
 | ||||
|         // Remove the instructions where the sharding was provided from the
 | ||||
|         // outside so we don't modify them.
 | ||||
|         instructions.erase( | ||||
|             std::remove_if(instructions.begin(), instructions.end(), | ||||
|                            [&](HloInstruction* instruction) { | ||||
|                              return provided_shardings.contains(instruction); | ||||
|                              return !workset.contains(instruction); | ||||
|                            }), | ||||
|             instructions.end()); | ||||
| 
 | ||||
| @ -1679,10 +1693,17 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) { | ||||
|           if (InferShardingFromOperands(instruction, computation_map, is_spmd_, | ||||
|                                         aggressive_prop)) { | ||||
|             ++inferred_from_operand_counter; | ||||
|             changed = true; | ||||
|             any_changed = true; | ||||
|             VLOG(2) << "Add sharding (forward-pass): " | ||||
|                     << instruction->ToString(); | ||||
|             maybe_computation_propagation(instruction); | ||||
|             for (auto user : instruction->users()) { | ||||
|               if (!provided_shardings.contains(user)) { | ||||
|                 workset.insert(user); | ||||
|               } | ||||
|             } | ||||
|           } else { | ||||
|             workset.erase(instruction); | ||||
|           } | ||||
|         } | ||||
| 
 | ||||
| @ -1692,13 +1713,18 @@ StatusOr<bool> ShardingPropagation::Run(HloModule* module) { | ||||
|           if (InferShardingFromUsers(*it, computation_map, aggressive_prop, | ||||
|                                      is_spmd_)) { | ||||
|             ++inferred_from_user_counter; | ||||
|             changed = true; | ||||
|             any_changed = true; | ||||
|             VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); | ||||
|             maybe_computation_propagation(*it); | ||||
|             workset.insert(*it); | ||||
|             for (auto operand : (*it)->operands()) { | ||||
|               if (!provided_shardings.contains(operand)) { | ||||
|                 workset.insert(operand); | ||||
|               } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       any_changed |= changed; | ||||
|       VLOG(1) << "Sharding propagation iteration " << iterations << ";"; | ||||
|       VLOG(1) << "  total instructions: " << instruction_counter; | ||||
|       VLOG(1) << "  instructions already sharded: " << already_sharded_counter; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user