[XLA] Move SPMD partitioner to third_party

This change moves the SPMD partitioning work that the XLA team has been developing over the past 12 months into third_party.

PiperOrigin-RevId: 311367525
Change-Id: If174527128c222c53736dc8db2ef1ea4177fb476
This commit is contained in:
parent 59239ab499
commit d45abae4e9
tensorflow/compiler/xla/service/BUILD

@@ -460,6 +460,37 @@ cc_library(
    ],
)

cc_library(
    name = "hlo_sharding_util",
    srcs = [
        "hlo_sharding_util.cc",
    ],
    hdrs = [
        "hlo_sharding_util.h",
    ],
    deps = [
        ":hlo",
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/types:optional",
    ],
)

tf_cc_test(
    name = "hlo_sharding_util_test",
    srcs = [
        "hlo_sharding_util_test.cc",
    ],
    deps = [
        ":hlo_sharding_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

tf_cc_test(
    name = "dynamic_parameter_binding_test",
    srcs = ["dynamic_parameter_binding_test.cc"],
574 tensorflow/compiler/xla/service/hlo_sharding_util.cc (new file)

@@ -0,0 +1,574 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include <map>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace hlo_sharding_util {

absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count) {
  int64 device = 0;
  int64 count = 0;
  for (auto& it : device_map) {
    if (it.second > count) {
      count = it.second;
      device = it.first;
    }
  }
  if (top_count != nullptr) {
    *top_count = count;
  }
  return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
}

Status AssignComputationDevice(HloComputation* computation, int64 device) {
  VLOG(4) << "Assigning device " << device << " to " << computation->name()
          << " computation";
  for (HloInstruction* instruction : computation->instructions()) {
    if (!instruction->has_sharding()) {
      VLOG(4) << "Assigning device " << device << " to " << instruction->name();
      instruction->set_device_sharding(device);
    }
  }
  return Status::OK();
}

absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions) {
  std::map<int64, int64> device_map;
  for (HloInstruction* instruction : instructions) {
    if (instruction->has_sharding()) {
      for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
        // The UsedDevices() API returns a map<device, occurrence_count>.
        device_map[it.first] += it.second;
      }
    }
  }
  return SelectDominantDevice(device_map, nullptr);
}

StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor) {
  int64 instruction_count = 0;
  std::map<int64, int64> device_map;
  for (HloComputation* computation : computations) {
    for (HloInstruction* instruction : computation->instructions()) {
      int64 count = 1;
      if (instruction->has_sharding()) {
        for (auto& it : instruction->sharding().UsedDevices(&count)) {
          // The UsedDevices() API returns a map<device, occurrence_count>.
          device_map[it.first] += it.second;
        }
      }
      instruction_count += count;
    }
  }
  int64 count;
  absl::optional<int64> device = SelectDominantDevice(device_map, &count);
  absl::optional<int64> dominant_device;
  if (device) {
    double factor =
        static_cast<double>(count) / static_cast<double>(instruction_count);
    if (factor >= dominant_factor) {
      dominant_device = device;
    }
  }
  return dominant_device;
}

HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  const int64 rank = dimensions.size();
  std::vector<int64> tile_assignment_dim(rank);
  for (int64 i = 0; i < rank; ++i) {
    tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]);
  }
  Array<int64> tile_assignment = sharding.tile_assignment();
  tile_assignment.Reshape(tile_assignment_dim);
  tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
    std::vector<int64> src_indices(indices.size(), -1);
    for (int64 i = 0; i < indices.size(); ++i) {
      src_indices[dimensions[i]] = indices[i];
    }
    *value = sharding.tile_assignment()(src_indices);
  });
  return HloSharding::Tile(tile_assignment);
}
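
// Worked example of the index mapping above: with a 2x2 tile assignment
// {{0, 1}, {2, 3}} and dimensions = {1, 0}, each target index (i, j) reads
// the source at (j, i), so the transposed tile assignment is {{0, 2}, {1, 3}}.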

absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }

  // In case of a tiled sharding, the reshaped sharding will be valid if the
  // reshape is composed from the following operations:
  // * Adding or removing dimensions with size 1.
  // * Merging consecutive dimensions where only the most major is sharded.
  // * Splitting a dimension to consecutive dimensions.
  // * Any reshaping of unsharded dimensions.
  // Note that merge and split can happen consecutively on the same dimension,
  // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be viewed as 1024
  // being split into 128 and 8, with the 8 then merged into 256. We use stacks
  // to make supporting such cases easy.
  const Shape tile_shape = sharding.TileShape(source_shape);
  std::vector<int64> target_tile_assignment_dimensions;
  std::vector<int64> source_dims_stack(source_shape.rank());
  std::vector<int64> target_dims_stack(target_shape.rank());
  std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
  for (int64 i = 0; i < source_shape.rank(); ++i) {
    source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
    sharding_tile_dims_stack[i] =
        sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
  }
  for (int64 i = 0; i < target_shape.rank(); ++i) {
    target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
  }
  while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
    if (target_dims_stack.empty()) {
      if (Product(sharding_tile_dims_stack) != 1) {
        return absl::nullopt;
      }
      break;
    }
    int64 s_size = 1;
    int64 t_size = 1;
    int64 s_partitions = 1;
    if (!source_dims_stack.empty()) {
      s_size = source_dims_stack.back();
      source_dims_stack.pop_back();
      s_partitions = sharding_tile_dims_stack.back();
      sharding_tile_dims_stack.pop_back();
    }
    t_size = target_dims_stack.back();
    target_dims_stack.pop_back();
    if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
      // No more partitions left.
      target_tile_assignment_dimensions.push_back(1);
      continue;
    }
    if (s_size == t_size) {
      // Same dimension.
      target_tile_assignment_dimensions.push_back(s_partitions);
    } else if (t_size == 1) {
      // Trivial dimension added.
      target_tile_assignment_dimensions.push_back(1);
      source_dims_stack.push_back(s_size);
      sharding_tile_dims_stack.push_back(s_partitions);
    } else if (s_size == 1) {
      // Trivial dimension removed.
      if (s_partitions != 1) {
        return absl::nullopt;
      }
      target_dims_stack.push_back(t_size);
    } else if (s_size > t_size) {
      // Dimension split.
      if (s_size % t_size != 0 || t_size % s_partitions != 0) {
        return absl::nullopt;
      }
      target_tile_assignment_dimensions.push_back(s_partitions);
      // We have part of the s_size unprocessed, so put it back on the stack.
      source_dims_stack.push_back(s_size / t_size);
      sharding_tile_dims_stack.push_back(1);
    } else {
      // Dimension merge. Also merge the source dimension with the next, and
      // process it next time.
      if (s_size % s_partitions != 0) {
        return absl::nullopt;
      }
      CHECK(!source_dims_stack.empty());
      if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
        // If the next dimension to combine is sharded, we require that the
        // current dimension's shard size be 1. Otherwise, the new shard
        // would be non-contiguous.
        return absl::nullopt;
      }
      source_dims_stack.back() *= s_size;
      sharding_tile_dims_stack.back() *= s_partitions;
      target_dims_stack.push_back(t_size);
    }
  }
  Array<int64> new_tile_assignment = sharding.tile_assignment();
  new_tile_assignment.Reshape(target_tile_assignment_dimensions);
  return HloSharding::Tile(new_tile_assignment);
}
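
// Worked example of the stack walk above: reshaping f32[16,7] with tile
// assignment {{0}, {1}} (a [2,1] tiling) to f32[4,4,7] hits the "dimension
// split" case for 16 -> 4: it emits 2 partitions for the new major dimension
// and pushes the leftover factor of 4 back onto the source stack, yielding a
// [2,1,1] tiling.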

HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims) {
  CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
  CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
  // We optimize the tile assignment on the single dimension dim in a way to
  // minimize communication among devices caused by the reshard:
  // +---+---+               +---+---+              +-+-+-+-+
  // |   |   |               |   0   |              | | | | |
  // | 0 | 1 |               +-------+              | | | | |
  // |   |   |  reshape on   |   1   |  reshape on  | | | | |
  // +---+---+   dim 0  =>   +-------+   dim 1  =>  |0|2|1|3|
  // |   |   |               |   2   |              | | | | |
  // | 2 | 3 |               +-------+              | | | | |
  // |   |   |               |   3   |              | | | | |
  // +---+---+               +---+---+              +-+-+-+-+

  std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
  // Handle ignore dimensions.
  std::vector<int64> ignore_sizes;
  int64 ignore_size = 1;
  for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (absl::c_find(dims, i) == dims.end()) {
      int64 size = sharding.tile_assignment().dim(i);
      ignore_sizes.push_back(size);
      tile_dims[i] = size;
      ignore_size *= size;
    }
  }

  using Buckets = std::vector<std::vector<int64>>;
  Array<Buckets> buckets(ignore_sizes,
                         Buckets(sharding.tile_assignment().dim(dim)));
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> index, int64 device) {
        std::vector<int64> ignore_index;
        for (int64 i = 0; i < index.size(); ++i) {
          if (absl::c_find(dims, i) == dims.end()) {
            ignore_index.push_back(index[i]);
          }
        }
        buckets(ignore_index)[index[dim]].push_back(device);
      });
  std::vector<int64> devices;
  buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
    for (auto& bucket : buckets) {
      devices.insert(devices.end(), bucket.begin(), bucket.end());
    }
  });
  tile_dims[dim] = devices.size() / ignore_size;
  Array<int64> tile_assignment(tile_dims);
  tile_assignment.SetValues(devices);
  return HloSharding::Tile(tile_assignment);
}

bool ContainsTileSharding(const HloModule& module) {
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->has_sharding() &&
          !instruction->sharding().IsTileMaximal()) {
        return true;
      }
    }
  }
  return false;
}

HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> output_tile_assignment_dims;
  for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.offset_dims(), i)) {
      output_tile_assignment_dims.push_back(1);
    } else {
      output_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(index_dim));
      index_dim++;
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  new_tile_assignment.Reshape(output_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo) {
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      index_tile_assignment_dims.push_back(
          output_sharding.tile_assignment().dim(i));
    }
  }
  Array<int64> new_tile_assignment = output_sharding.tile_assignment();
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
  if (hlo.sharding().IsTileMaximal()) {
    return hlo.sharding();
  }

  const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> tile_assignment_dims(hlo.shape().rank());
  int64 num_elements = 1;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
      num_elements *= hlo.sharding().tile_assignment().dim(i);
    } else {
      tile_assignment_dims[i] = 1;
    }
  }
  if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
    // Output sharding is only on non offset dimensions. We use output sharding
    // to shard this gather op directly.
    return hlo.sharding();
  }

  if (num_elements == 1) {
    // Output sharding is only on offset dimensions. We do not shard this gather
    // op. Return a tile maximal sharding with the first device in output
    // sharding tile assignment.
    return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin());
  }

  // Output sharding is on both offset and non offset dimensions. We shard the
  // gather op only on non offset dimensions.
  // For example:
  // - the gather op has sharding [2,2]{0,1,2,3},
  // - first dimension is non offset dimension,
  // - second dimension is offset dimension,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
      slice_limits(hlo.shape().rank());
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}

HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
      index_tile_assignment_dims.push_back(
          data_sharding.tile_assignment().dim(i));
    }
  }
  if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
    index_tile_assignment_dims.push_back(1);
  }
  Array<int64> new_tile_assignment = data_sharding.tile_assignment();
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> data_tile_assignment_dims;
  for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.update_window_dims(), i)) {
      data_tile_assignment_dims.push_back(1);
    } else {
      data_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(index_dim));
      index_dim++;
    }
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  new_tile_assignment.Reshape(data_tile_assignment_dims);
  return HloSharding::Tile(new_tile_assignment);
}

HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  // Only shard on first "number of scatter_window_dims" dimensions.
  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  int64 num_elements = 1;
  int64 index_dim = 0;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      num_elements *= index_sharding.tile_assignment().dim(index_dim);
      index_dim++;
    }
  }
  if (num_elements == index_sharding.tile_assignment().num_elements()) {
    // Index sharding is only on scatter_window_dims. We use this index sharding
    // directly.
    return index_sharding;
  }

  // Index sharding is only on update_window_dims. We do not shard this scatter
  // op. Return a tile maximal sharding with the first device in index sharding
  // tile assignment.
  if (num_elements == 1) {
    return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin());
  }

  const int64 index_rank = hlo.operand(1)->shape().rank();
  std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
  for (int64 i = 0; i < index_rank; ++i) {
    if (i < index_dim) {
      slice_limits[i] = index_sharding.tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment);
}

HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  const int64 data_rank = hlo.operand(2)->shape().rank();
  std::vector<int64> tile_assignment_dims(data_rank, 1LL);
  int64 num_elements = 1;
  for (int64 i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      CHECK_LT(i, data_rank);
      tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
      num_elements *= data_sharding.tile_assignment().dim(i);
    }
  }
  if (num_elements == data_sharding.tile_assignment().num_elements()) {
    // Data sharding is only on scatter_window_dims. We use this data sharding
    // directly.
    return data_sharding;
  }

  if (num_elements == 1) {
    // Data sharding is only on update_window_dims. We do not shard this
    // scatter op. Return a tile maximal sharding with the first device in
    // data sharding tile assignment.
    return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin());
  }

  // Data sharding is on both update_window_dims and scatter_window_dims. We
  // shard the scatter op only on scatter_window_dims. For example:
  // - the scatter data has sharding [2,2]{0,1,2,3},
  // - first dimension is scatter_window_dims,
  // - second dimension is update_window_dims,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(data_rank, 0LL);
  Array<int64> tile_assignment =
      data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
  return HloSharding::Tile(tile_assignment);
}

StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter) {
  auto computation = scatter.to_apply();
  // We only handle computations with 2 parameters and only 1 calculation.
  if (computation->instruction_count() != 3) {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        "Expected scatter reduce computation with 2 parameters and only 1 "
        "calculation");
  }

  auto root_instruction = computation->root_instruction();
  if (root_instruction->opcode() == HloOpcode::kAdd ||
      root_instruction->opcode() == HloOpcode::kOr) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
             root_instruction->opcode() == HloOpcode::kAnd) {
    return std::make_pair(HloInstruction::CreateConstant(
                              LiteralUtil::One(scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  }

  return Status(tensorflow::error::Code::INVALID_ARGUMENT,
                "Expected scatter reduce computation which is "
                "add/or/multiply/and/min/max");
}

std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices) {
  std::vector<int64> devices;
  if (sharding.IsReplicated()) {
    for (int64 d : available_devices) {
      if (!HloSharding::IsReservedDevice(d)) {
        devices.push_back(d);
      }
    }
    return devices;
  }

  for (int64 i : available_devices) {
    if (sharding.UsesDevice(i)) {
      devices.push_back(i);
    }
  }
  DCHECK(std::all_of(sharding.tile_assignment().begin(),
                     sharding.tile_assignment().end(), [&](int64 device) {
                       return std::find(available_devices.begin(),
                                        available_devices.end(),
                                        device) != available_devices.end();
                     }));
  return devices;
}

}  // namespace hlo_sharding_util
}  // namespace xla
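
A minimal caller-side sketch of the selection helpers above (the map contents are illustrative only): SelectDominantDevice picks the device with the highest occurrence count from a caller-built map, optionally reporting that count.

    std::map<int64, int64> device_map = {{0, 3}, {1, 5}};
    int64 top_count = 0;
    // Yields device 1, since it has the highest count (5).
    absl::optional<int64> dominant =
        hlo_sharding_util::SelectDominantDevice(device_map, &top_count);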
							
								
								
									
143 tensorflow/compiler/xla/service/hlo_sharding_util.h (new file)

@@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_

#include <map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace hlo_sharding_util {

// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count);

// Assigns all the instructions of a computation to a given device.
// This API does not recurse into called computations, and does not assign
// instructions which already have sharding.
Status AssignComputationDevice(HloComputation* computation, int64 device);

// Given an instruction container, returns the device which is most commonly
// occurring among the instructions.
absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions);

// Given a set of computations, tries to extract the dominant device. A device
// is dominant if its combined occurrence among all the instructions of the
// input computations is greater than or equal to dominant_factor (a real
// number from 0 to 1).
// This API does not recurse into called computations.
// If no device exists that satisfies the condition, the returned optional will
// hold no value.
StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor);

// Returns the HloSharding with the tile dimensions and tile assignment
// transposed based on the specified dimension numbers. In case of a tile
// maximal sharding returns the original sharding.
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions);

// Returns the HloSharding with the tile shape reshaped based on the source and
// target shapes and the tile assignment adjusted to correspond to the new tile
// shape, or absl::nullopt if the resulting reshape would create an invalid
// sharding (non contiguous or non uniformly sized tiles). In case of a tile
// maximal sharding returns the original sharding.
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding);

// Returns a sharding tiled on unique dimension dim by reshaping the tile
// assignment of the sharding argument. Only dimensions in the dims span
// argument are considered for reshaping, the others are ignored.
// Assumptions: sharding is tile sharded, and dim must be included in dims.
HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim,
                                   absl::Span<const int64> dims);

// Returns true if the provided module includes one or more instructions with
// a tile sharding.
bool ContainsTileSharding(const HloModule& module);

// Returns the preferred output sharding for a gather op based on the sharding
// of the indices.
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred index sharding for a gather op based on the sharding
// of the output.
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo);

// Returns a new HloSharding for a gather op so that only non offset dimensions
// are sharded. Assume "result" is returned by this function. It is ensured that
// "GetIndexSharding(result, hlo)" will have the same number of elements as
// "result".
HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo);

// Returns the preferred index sharding for a scatter op based on the sharding
// of the data.
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo);

// Returns the preferred data sharding for a scatter op based on the sharding
// of the index.
HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo);

// Returns a new index sharding for a scatter op so that we only shard on the
// first "number of scatter_window_dims" dimensions. Assume "result" is
// returned by this function. It is ensured that
// "ScatterDataSharding(result, hlo)" will have the same number of elements as
// "result".
HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo);

// Returns a new data sharding for a scatter op so that we only shard on
// scatter_window_dims. Assume "result" is returned by this function. It is
// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of
// elements as "result".
HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo);

// Returns an identity value and an HloOpcode for the reduce computation of a
// scatter instruction.
// - If the computation is add/or, returns 0/false with the corresponding op
//   code;
// - If the computation is multiply/and, returns 1/true with the corresponding
//   op code.
// - If the computation is min/max, returns the max value/min value with the
//   corresponding op code.
// - Otherwise, returns an error status.
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter);

// Given a sharding and a list of devices in the topology, returns a
// list of the devices that `sharding` applies to.
std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices);

}  // namespace hlo_sharding_util
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_
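
As a quick orientation to this API, here is a usage sketch mirroring the tests below (the shapes and shardings are illustrative): reshaping a [16,7] tensor tiled 2 ways on dimension 0 into [4,4,7] keeps the split on the new major dimension.

    Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
    Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
    HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
    // Expected to equal HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}})).
    absl::optional<HloSharding> reshaped = hlo_sharding_util::ReshapeSharding(
        input_shape, output_shape, input_sharding);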
							
								
								
									
206 tensorflow/compiler/xla/service/hlo_sharding_util_test.cc (new file)

@@ -0,0 +1,206 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include "tensorflow/compiler/xla/test.h"

namespace xla {
namespace hlo_sharding_util {
namespace {

TEST(HloShardingUtilTest, TransposeShardingReplicated) {
  EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}),
            HloSharding::Replicate());
}

TEST(HloShardingUtilTest, TransposeShardingTiled) {
  HloSharding input = HloSharding::Tile(Array4D<int64>({{{{0, 1}}, {{2, 3}}}}));
  HloSharding output =
      HloSharding::Tile(Array4D<int64>({{{{0}, {2}}}, {{{1}, {3}}}}));
  EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output);
}

TEST(HloShardingUtilTest, ReshapeShardingMaximal) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
  HloSharding sharding = HloSharding::AssignDevice(7);
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  HloSharding output_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7});
  HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3});
  Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14});
  Array<int64> sharding_array({2, 1, 1, 1});
  sharding_array(0, 0, 0, 0) = 0;
  sharding_array(1, 0, 0, 0) = 1;
  HloSharding sharding = HloSharding::Tile(sharding_array);
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTiledTrivialDimensions) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7});
  Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7});
  HloSharding input_sharding =
      HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array4D<int64>({{{{0}}, {{1}}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingTrivialDimensionInsertedToEnd) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16});
  Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1});
  HloSharding input_sharding = HloSharding::Tile(Array2D<int64>({{0}, {1}}));
  HloSharding output_sharding =
      HloSharding::Tile(Array3D<int64>({{{0}}, {{1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, input_sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), output_sharding);
}

TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) {
  Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result = ReshapeSharding(shape, shape, sharding);
  EXPECT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), sharding);
}

TEST(HloShardingUtilTest, ReshapeShardingScalar) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1});
  Shape output_shape = ShapeUtil::MakeShape(F32, {});
  HloSharding sharding = HloSharding::Tile(Array3D<int64>({{{0}, {1}}}));
  absl::optional<HloSharding> result =
      ReshapeSharding(input_shape, output_shape, sharding);
  EXPECT_FALSE(result.has_value());
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) {
  HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1});
  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0}, {1}, {2}, {3}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) {
  HloSharding sharding = HloSharding::Tile(Array2D<int64>({{0, 1}, {2, 3}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1});
  EXPECT_EQ(result.tile_assignment(), Array2D<int64>({{0, 2, 1, 3}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2});
  EXPECT_EQ(
      result.tile_assignment(),
      Array3D<int64>({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2});
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) {
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  HloSharding result =
      ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2});
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0, 2, 4, 6, 1, 3, 5, 7}}}));
}

TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) {
  // Tile sharding in batch dimension, i.e.
  // sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}.
  HloSharding sharding =
      HloSharding::Tile(Array3D<int64>({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}));
  // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0.
  HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2,
                                              /*dims=*/{1, 2});
  // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two
  // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually
  // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}.
  EXPECT_EQ(result.tile_assignment(),
            Array3D<int64>({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}}));
}

}  // namespace
}  // namespace hlo_sharding_util
}  // namespace xla
69 tensorflow/compiler/xla/service/spmd/BUILD (new file)

@@ -0,0 +1,69 @@
# Description: SPMD partitioning pass.

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

package(
    default_visibility = [":friends"],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = [
        "//tensorflow/compiler/xla:friends",
    ],
)

cc_library(
    name = "spmd_partitioner",
    srcs = [
        "spmd_partitioner.cc",
        "spmd_partitioner_util.cc",
    ],
    hdrs = [
        "spmd_partitioner.h",
        "spmd_partitioner_util.h",
    ],
    deps = [
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client/lib:comparators",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
        "//tensorflow/compiler/xla/service:hlo_cse",
        "//tensorflow/compiler/xla/service:hlo_dce",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service:hlo_query",
        "//tensorflow/compiler/xla/service:hlo_sharding_util",
        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/compiler/xla/service:tuple_simplifier",
        "//tensorflow/core/platform:numbers",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

tf_cc_test(
    name = "spmd_partitioner_test",
    srcs = ["spmd_partitioner_test.cc"],
    deps = [
        ":spmd_partitioner",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:hlo_matchers",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)
4655 tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc (new file)

(File diff suppressed because it is too large.)
435 tensorflow/compiler/xla/service/spmd/spmd_partitioner.h (new file)

@@ -0,0 +1,435 @@
							| @ -0,0 +1,435 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ | ||||
| #define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <string> | ||||
| #include <unordered_map> | ||||
| 
 | ||||
| #include "absl/types/optional.h" | ||||
| #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_sharding.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| namespace spmd { | ||||
| 
 | ||||
| struct SpmdPartitionerOptions { | ||||
|   // Always exchange halo on LHS for all convolutions. If false, backprop filter
 | ||||
|   // convolution exchanges halo on RHS.
 | ||||
|   bool conv_halo_exchange_always_on_lhs = true; | ||||
| 
 | ||||
|   // The number of instructions to be reported for the highest memory profile
 | ||||
|   // instructions.
 | ||||
|   int64 report_instruction_count = 5; | ||||
| 
 | ||||
|   // The minimum size in MiB of an einsum operand to be considered using
 | ||||
|   // windowed implementation in an HLO loop.
 | ||||
|   int64 threshold_for_windowed_einsum_mib = 256; | ||||
| 
 | ||||
|   // Whether the entry computations' signature could change after partitioning.
 | ||||
|   bool allow_module_signature_change = false; | ||||
| }; | ||||
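
// A minimal usage sketch (the field values are illustrative only): callers
// can override the defaults above before handing the options to the
// partitioner.
//   SpmdPartitionerOptions options;
//   options.conv_halo_exchange_always_on_lhs = false;
//   options.threshold_for_windowed_einsum_mib = 512;
//   options.allow_module_signature_change = true;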
| 
 | ||||
| // Wraps the computation builder to capture information during the SPMD | ||||
| // transformation. | ||||
| class SpmdBuilder : public HloComputation::Builder { | ||||
|  public: | ||||
|   SpmdBuilder(const std::string& name, HloInstruction* hlo) | ||||
|       : HloComputation::Builder(name) { | ||||
|     visiting_hlo_ = hlo; | ||||
|   } | ||||
|   HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction); | ||||
| 
 | ||||
|   const std::vector<HloInstruction*>& derived_instructions( | ||||
|       HloInstruction* hlo) { | ||||
|     return instructions_.at(hlo); | ||||
|   } | ||||
| 
 | ||||
|   void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; } | ||||
| 
 | ||||
|   HloInstruction* visiting_hlo() const { return visiting_hlo_; } | ||||
| 
 | ||||
|  private: | ||||
|   // The instruction currently being visited. | ||||
|   HloInstruction* visiting_hlo_; | ||||
| 
 | ||||
|   // Map from each visited (old) instruction to the new instructions created | ||||
|   // for it during SPMD partitioning. | ||||
|   HloInstructionMap<std::vector<HloInstruction*>> instructions_; | ||||
| }; | ||||
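| 
| // Illustrative sketch (assumption): typical use is to point the builder at | ||||
| // the (old) HLO being partitioned before emitting new instructions, so that | ||||
| // derived_instructions() can later attribute the new HLOs to it. | ||||
| inline HloInstruction* ExampleAddCopyFor(SpmdBuilder* b, HloInstruction* hlo) { | ||||
|   b->set_visiting_hlo(hlo); | ||||
|   return b->AddInstruction(HloInstruction::CreateUnary( | ||||
|       hlo->shape(), HloOpcode::kCopy, hlo)); | ||||
| } | ||||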
| 
 | ||||
| // A set of functions that create the cross-partition collective ops.
 | ||||
| struct SPMDCollectiveOpsCreator { | ||||
|   // Function used to create a partition ID HLO.
 | ||||
|   std::function<HloInstruction*(SpmdBuilder*)> create_partition_id; | ||||
| 
 | ||||
|   // Function used to create a cross-partition all-reduce HLO.
 | ||||
|   std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand, | ||||
|                                 HloComputation* reduction, int64 channel_id)> | ||||
|       create_cross_partition_all_reduce; | ||||
| 
 | ||||
|   // Function used to create a cross-partition collective-permute HLO.
 | ||||
|   std::function<HloInstruction*( | ||||
|       SpmdBuilder*, HloInstruction* operand, | ||||
|       std::vector<std::pair<int64, int64>>& src_dst_pairs, | ||||
|       int64 next_channel_id)> | ||||
|       create_cross_partition_collective_permute; | ||||
| 
 | ||||
|   // Function used to create a cross-partition all-to-all HLO.
 | ||||
|   std::function<HloInstruction*( | ||||
|       SpmdBuilder*, absl::Span<HloInstruction* const> operands, | ||||
|       const std::vector<ReplicaGroup>& replica_groups, int64 channel_id, | ||||
|       absl::optional<int64> split_dimension)> | ||||
|       create_cross_partition_all_to_all; | ||||
| }; | ||||
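| 
| // Illustrative sketch (assumption, not part of the original change): wiring | ||||
| // up a creator. Only the partition-id function is shown; a real creator would | ||||
| // also populate the collective-op functions. | ||||
| inline SPMDCollectiveOpsCreator MakeExampleCollectiveOpsCreator() { | ||||
|   SPMDCollectiveOpsCreator creator; | ||||
|   creator.create_partition_id = [](SpmdBuilder* b) { | ||||
|     // Assumes HloInstruction::CreatePartitionId() is available. | ||||
|     return b->AddInstruction(HloInstruction::CreatePartitionId()); | ||||
|   }; | ||||
|   return creator; | ||||
| } | ||||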
| 
 | ||||
| // Logger to report memory usage during SPMD partitioning.
 | ||||
| class SpmdLogger { | ||||
|  public: | ||||
|   explicit SpmdLogger(int64 report_instruction_count) | ||||
|       : report_instruction_count_(report_instruction_count) {} | ||||
|   static std::string ReportBeforePartition(const HloModule& module, | ||||
|                                            int64 report_instruction_count); | ||||
|   static std::string ReportAfterPartition(const HloModule& module, | ||||
|                                           int64 report_instruction_count); | ||||
| 
 | ||||
|   // Registers a log entry for the group of instructions created to transform | ||||
|   // the given hlo. | ||||
|   void RegisterLogEntry(HloInstruction* hlo, | ||||
|                         const std::vector<HloInstruction*>& group); | ||||
| 
 | ||||
|   std::string MakeReport(); | ||||
| 
 | ||||
|  private: | ||||
|   template <typename F> | ||||
|   static std::string ReportMemoryUsage(const HloModule& module, const F& filter, | ||||
|                                        int64 report_instruction_count); | ||||
| 
 | ||||
|   // A vector of log entries (one per original HLO instruction), where the | ||||
|   // first element of each pair is the size of the HBM used. | ||||
|   std::vector<std::pair<int64, std::string>> entries_; | ||||
| 
 | ||||
|   int64 report_instruction_count_; | ||||
| }; | ||||
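| 
| // Illustrative sketch (assumption): a typical logging flow around one run of | ||||
| // the partitioner. | ||||
| inline std::string ExampleMemoryReport(const HloModule& module) { | ||||
|   SpmdLogger logger(/*report_instruction_count=*/5); | ||||
|   // Snapshot the memory profile before partitioning. | ||||
|   std::string before = SpmdLogger::ReportBeforePartition( | ||||
|       module, /*report_instruction_count=*/5); | ||||
|   // ... partition, calling logger.RegisterLogEntry() per transformed HLO ... | ||||
|   return before + logger.MakeReport(); | ||||
| } | ||||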
| 
 | ||||
| class SpmdPartitioningVisitor; | ||||
| 
 | ||||
| class SpmdPartitioner : public HloModulePass { | ||||
|  public: | ||||
|   SpmdPartitioner(int64 num_partitions, int64 num_replicas, | ||||
|                   SpmdPartitionerOptions options); | ||||
|   SpmdPartitioner(int64 num_partitions, int64 num_replicas, | ||||
|                   SpmdPartitionerOptions options, | ||||
|                   SPMDCollectiveOpsCreator collective_ops_creator) | ||||
|       : num_partitions_(num_partitions), | ||||
|         num_replicas_(num_replicas), | ||||
|         options_(std::move(options)), | ||||
|         collective_ops_creator_(std::move(collective_ops_creator)) {} | ||||
|   absl::string_view name() const override { return "spmd-partitioning"; } | ||||
|   StatusOr<bool> Run(HloModule* module) override; | ||||
| 
 | ||||
|   // Transforms the given computation with SPMD instructions, replacing it with
 | ||||
|   // a new computation.
 | ||||
|   StatusOr<bool> PartitionComputation(HloComputation* computation, | ||||
|                                       const HloSharding& root_sharding, | ||||
|                                       int64* next_channel_id, | ||||
|                                       SpmdLogger* logger); | ||||
| 
 | ||||
|  protected: | ||||
|   virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor( | ||||
|       HloComputation* computation, int64 num_partitions, int64 num_replicas, | ||||
|       const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|       int64* next_channel_id, SpmdLogger* logger, | ||||
|       SpmdPartitionerOptions options); | ||||
| 
 | ||||
|  private: | ||||
|   // Verifies that the shardings of instructions in the module are valid, and | ||||
|   // fills in missing sharding information. | ||||
|   Status PreprocessSharding(HloModule* module); | ||||
| 
 | ||||
|   const int64 num_partitions_; | ||||
|   const int64 num_replicas_; | ||||
| 
 | ||||
|   SpmdPartitionerOptions options_; | ||||
|   SPMDCollectiveOpsCreator collective_ops_creator_; | ||||
| }; | ||||
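| 
| // Illustrative sketch (assumption): running the pass on a module whose | ||||
| // instructions already carry sharding annotations. | ||||
| inline StatusOr<bool> ExampleRunSpmdPartitioner(HloModule* module) { | ||||
|   SpmdPartitionerOptions options; | ||||
|   options.allow_module_signature_change = true;  // Assumed for this sketch. | ||||
|   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1, | ||||
|                               options); | ||||
|   return partitioner.Run(module); | ||||
| } | ||||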
| 
 | ||||
| // Describes the partition state of the data represented by an HLO created | ||||
| // during the SPMD partitioning pass. | ||||
| // | ||||
| // Data on some devices may include a padding region if the base (full) shape | ||||
| // cannot be evenly partitioned. | ||||
| class PartitionedHlo { | ||||
|  public: | ||||
|   // Return value for ReshardAsWindowedInput, which describes the resharded | ||||
|   // HLO, the window for the user on the shard, and, if necessary, the dynamic | ||||
|   // slice offsets to apply to the output of the op being sharded. | ||||
|   struct WindowedInputShardReturnValue { | ||||
|     HloInstruction* sharded_input; | ||||
|     Window shard_window; | ||||
|     absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output; | ||||
|   }; | ||||
|   // A cache for resharding each partitioned HLO.
 | ||||
|   struct ReshardCache { | ||||
|     struct PerHloCache { | ||||
|       std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache; | ||||
|       std::vector< | ||||
|           std::tuple<HloSharding, Window, WindowedInputShardReturnValue>> | ||||
|           window_reshard_cache; | ||||
|     }; | ||||
|     std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache; | ||||
|   }; | ||||
|   struct PartitioningState { | ||||
|     SpmdBuilder* b; | ||||
|     HloModule* module; | ||||
|     int64 num_replicas; | ||||
|     HloInstruction* partition_id; | ||||
|     SPMDCollectiveOpsCreator collective_ops_creator; | ||||
|     int64* next_channel_id; | ||||
|     ReshardCache* reshard_cache; | ||||
|   }; | ||||
|   PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) | ||||
|       : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { | ||||
|     CHECK(hlo->has_sharding()) | ||||
|         << "PartitionedHlo is missing sharding:" << hlo->ToString(); | ||||
|     // If a tuple-shaped instruction does not have a tuple sharding, reassign | ||||
|     // it to use a tuple sharding; the Reshard() implementation assumes this. | ||||
|     if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) { | ||||
|       hlo_->set_sharding( | ||||
|           hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Reshards the current SPMD instruction to a new sharding. Its only side | ||||
|   // effect is on the reshard cache. | ||||
|   PartitionedHlo Reshard(const HloSharding& target); | ||||
| 
 | ||||
|   // Pads the garbage area of the output with the provided value.
 | ||||
|   PartitionedHlo PadWithValue(HloInstruction* pad_value) const; | ||||
| 
 | ||||
|   // Returns the SPMD instruction.
 | ||||
|   HloInstruction* hlo() const { return hlo_; } | ||||
| 
 | ||||
|   // Returns the sharding of the SPMD instruction.
 | ||||
|   const HloSharding& sharding() const { return hlo_->sharding(); } | ||||
| 
 | ||||
|   // Original full shape of the data.
 | ||||
|   const Shape& base_shape() const { return base_shape_; } | ||||
| 
 | ||||
|   int64 NewChannel() const { return (*state_.next_channel_id)++; } | ||||
| 
 | ||||
|   // Reshards the HLO to a usable partitioned input for a windowed user. Its | ||||
|   // only side effect is on the reshard cache. | ||||
|   absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput( | ||||
|       const Window& window, const HloSharding& target, | ||||
|       HloInstruction* pad_value, bool mask_invalid_region = true); | ||||
| 
 | ||||
|  private: | ||||
|   // Same as Reshard except that it does not explicitly modify the reshard | ||||
|   // cache, although it may modify it indirectly by calling Replicate(). | ||||
|   PartitionedHlo ReshardNoCache(const HloSharding& target); | ||||
| 
 | ||||
|   // Helper function to replicate the data on all devices. Its only side | ||||
|   // effect is on the reshard cache. | ||||
|   PartitionedHlo Replicate(); | ||||
| 
 | ||||
|   // Helper function to broadcast data from a single device to all devices.
 | ||||
|   PartitionedHlo Broadcast() const; | ||||
| 
 | ||||
|   // Helper function to reshard the tensor using AllToAll (instead of the
 | ||||
|   // default of Replicate followed by Slice).
 | ||||
|   PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; | ||||
| 
 | ||||
|   // Helper function to reshard the tensor using CollectivePermute.
 | ||||
|   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; | ||||
| 
 | ||||
|   // SPMD instruction.
 | ||||
|   HloInstruction* hlo_; | ||||
| 
 | ||||
|   // The original shape of the data before SPMD transformation is applied.
 | ||||
|   Shape base_shape_; | ||||
| 
 | ||||
|   PartitioningState state_; | ||||
| }; | ||||
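| 
| // Illustrative sketch (assumption): resharding a partitioned value to a | ||||
| // replicated layout before an op that needs the full data. | ||||
| inline PartitionedHlo ExampleReplicate(PartitionedHlo partitioned) { | ||||
|   return partitioned.Reshard(HloSharding::Replicate()); | ||||
| } | ||||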
| 
 | ||||
| struct DotGeneralDimsMapping { | ||||
|   // The dimension numbers for the operands and output corresponding to a
 | ||||
|   // logical dimension (e.g., batch, contracting, non-contracting). If an
 | ||||
|   // operand or the output doesn't have the logical dimension, it is set to
 | ||||
|   // -1.
 | ||||
|   struct DimsMapping { | ||||
|     int64 lhs; | ||||
|     int64 rhs; | ||||
|     int64 output; | ||||
|   }; | ||||
|   std::vector<DimsMapping> batch_dims; | ||||
|   std::vector<DimsMapping> contracting_dims; | ||||
|   std::vector<DimsMapping> lhs_non_contracting_dims; | ||||
|   std::vector<DimsMapping> rhs_non_contracting_dims; | ||||
| }; | ||||
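| 
| // Illustrative example (assumption): for a plain matmul | ||||
| //   lhs[m, k] dot rhs[k, n] -> output[m, n] | ||||
| // there are no batch dims, k is contracting, and m/n are non-contracting. | ||||
| inline DotGeneralDimsMapping ExampleMatmulDimsMapping() { | ||||
|   DotGeneralDimsMapping mapping; | ||||
|   mapping.contracting_dims.push_back({/*lhs=*/1, /*rhs=*/0, /*output=*/-1}); | ||||
|   mapping.lhs_non_contracting_dims.push_back({/*lhs=*/0, /*rhs=*/-1, | ||||
|                                               /*output=*/0}); | ||||
|   mapping.rhs_non_contracting_dims.push_back({/*lhs=*/-1, /*rhs=*/1, | ||||
|                                               /*output=*/1}); | ||||
|   return mapping; | ||||
| } | ||||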
| 
 | ||||
| class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { | ||||
|  public: | ||||
|   SpmdPartitioningVisitor( | ||||
|       HloComputation* computation, int64 num_partitions, int64 num_replicas, | ||||
|       const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|       int64* next_channel_id, SpmdLogger* logger, | ||||
|       SpmdPartitionerOptions options, SpmdPartitioner* partitioner); | ||||
| 
 | ||||
|   Status DefaultAction(HloInstruction* hlo) override; | ||||
|   Status HandleAllReduce(HloInstruction* hlo) override; | ||||
|   Status HandleBroadcast(HloInstruction* hlo) override; | ||||
|   Status HandleConstant(HloInstruction* hlo) override; | ||||
|   Status HandleCustomCall(HloInstruction* hlo) override; | ||||
|   Status HandleDot(HloInstruction* hlo) override; | ||||
|   Status HandleDynamicSlice(HloInstruction* hlo) override; | ||||
|   Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; | ||||
|   Status HandleGather(HloInstruction* hlo) override; | ||||
|   Status HandleGetTupleElement(HloInstruction* hlo) override; | ||||
|   Status HandleInfeed(HloInstruction* hlo) override; | ||||
|   Status HandleOutfeed(HloInstruction* hlo) override; | ||||
|   Status HandlePad(HloInstruction* hlo) override; | ||||
|   Status HandleParameter(HloInstruction* hlo) override; | ||||
|   Status HandleReduce(HloInstruction* hlo) override; | ||||
|   Status HandleReverse(HloInstruction* hlo) override; | ||||
|   Status HandleWhile(HloInstruction* hlo) override; | ||||
|   Status HandleConditional(HloInstruction* hlo) override; | ||||
|   Status HandleReduceWindow(HloInstruction* hlo) override; | ||||
|   Status HandleSelectAndScatter(HloInstruction* hlo) override; | ||||
|   Status HandleTuple(HloInstruction* hlo) override; | ||||
|   Status HandleRng(HloInstruction* hlo) override; | ||||
|   Status HandleConvolution(HloInstruction* hlo) override; | ||||
|   Status HandleConcatenate(HloInstruction* hlo) override; | ||||
|   Status HandleScatter(HloInstruction* hlo) override; | ||||
|   Status HandleSlice(HloInstruction* hlo) override; | ||||
|   Status HandleSort(HloInstruction* hlo) override; | ||||
|   Status HandleTranspose(HloInstruction* hlo) override; | ||||
|   Status HandleReshape(HloInstruction* hlo) override; | ||||
|   Status HandleIota(HloInstruction* hlo) override; | ||||
|   Status HandlePartitionId(HloInstruction* hlo) override; | ||||
| 
 | ||||
|   // Handles convolution where both LHS and RHS operands are tiled.
 | ||||
|   Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); | ||||
| 
 | ||||
|   // Implementation of dot partitioning given DotGeneralDimsMapping.
 | ||||
|   Status HandleDotHelper( | ||||
|       HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, | ||||
|       const std::function<StatusOr<HloInstruction*>( | ||||
|           HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); | ||||
| 
 | ||||
|   // Common handler for elementwise HLOs. | ||||
|   Status HandleElementwise(HloInstruction* hlo); | ||||
| 
 | ||||
|   // Common handler for HLOs that run on a single device. | ||||
|   Status HandleSingleDevice(const HloInstruction* hlo); | ||||
| 
 | ||||
|   // Returns the PartitionedHlo that corresponds to the original hlo.
 | ||||
|   PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { | ||||
|     CHECK_EQ(partitioned_instructions_.count(hlo), 1); | ||||
|     return partitioned_instructions_.find(hlo)->second; | ||||
|   } | ||||
| 
 | ||||
|   // Sets the PartitionedHlo for the original hlo.
 | ||||
|   void SetPartitionedHlo(const HloInstruction* hlo, | ||||
|                          const PartitionedHlo& partitioned_hlo) { | ||||
|     CHECK_EQ(partitioned_instructions_.count(hlo), 0); | ||||
|     partitioned_instructions_.emplace(hlo, partitioned_hlo); | ||||
|     changed_ = true; | ||||
|   } | ||||
| 
 | ||||
|   // Convenience wrapper that creates a PartitionedHlo from the result of | ||||
|   // func and maps it to the given original hlo. | ||||
|   void SetPartitionedHlo(const HloInstruction* hlo, | ||||
|                          const std::function<HloInstruction*()>& func) { | ||||
|     HloInstruction* new_hlo = func(); | ||||
|     new_hlo->set_sharding(hlo->sharding()); | ||||
|     new_hlo->set_metadata(hlo->metadata()); | ||||
|     SetPartitionedHlo( | ||||
|         hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState())); | ||||
|     changed_ = true; | ||||
|   } | ||||
| 
 | ||||
|   int64 NewChannel() { return (*next_channel_id_)++; } | ||||
| 
 | ||||
|   PartitionedHlo::PartitioningState MakePartitioningState() { | ||||
|     return PartitionedHlo::PartitioningState{ | ||||
|         .b = &b_, | ||||
|         .module = module_, | ||||
|         .num_replicas = num_replicas_, | ||||
|         .partition_id = partition_id_, | ||||
|         .collective_ops_creator = collective_ops_creator_, | ||||
|         .next_channel_id = next_channel_id_, | ||||
|         .reshard_cache = &reshard_cache_}; | ||||
|   } | ||||
| 
 | ||||
|   SpmdBuilder* builder() { return &b_; } | ||||
| 
 | ||||
|   StatusOr<bool> DoPartition(HloComputation* computation, | ||||
|                              const HloSharding& root_sharding); | ||||
| 
 | ||||
|  private: | ||||
|   Status Preprocess(HloInstruction* hlo) override; | ||||
|   Status Postprocess(HloInstruction* hlo) override; | ||||
| 
 | ||||
|   // Performs code motion for windowed dot-general loops in
 | ||||
|   // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
 | ||||
|   // the graph.
 | ||||
|   Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation); | ||||
| 
 | ||||
|   bool changed_; | ||||
|   HloModule* module_; | ||||
|   int64 num_partitions_; | ||||
|   int64 num_replicas_; | ||||
| 
 | ||||
|   SPMDCollectiveOpsCreator collective_ops_creator_; | ||||
| 
 | ||||
|   // Tracks the next channel id to use for cross-partition all-reduce.
 | ||||
|   int64* next_channel_id_; | ||||
|   SpmdBuilder b_; | ||||
| 
 | ||||
|   HloInstruction* partition_id_; | ||||
| 
 | ||||
|   PartitionedHlo::ReshardCache reshard_cache_; | ||||
| 
 | ||||
|   // Mapping from the instruction in the original computation to the new SPMD
 | ||||
|   // partitioned instruction.
 | ||||
|   ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_; | ||||
| 
 | ||||
|   // Information about a loop created for windowed dot-general. Used when
 | ||||
|   // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
 | ||||
|   // finishes traversing the graph.
 | ||||
|   struct WindowedDotGeneralLoop { | ||||
|     HloInstruction* while_loop; | ||||
|     int64 windowed_operand; | ||||
|     bool windowed_in_contracting_dims; | ||||
|     bool windowed_in_batch_dims; | ||||
|   }; | ||||
|   std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_; | ||||
| 
 | ||||
|   HloInstruction* visiting_hlo_; | ||||
|   SpmdLogger* logger_; | ||||
|   const SpmdPartitionerOptions options_; | ||||
|   SpmdPartitioner* partitioner_; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace spmd
 | ||||
| }  // namespace xla
 | ||||
| #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
 | ||||
							
								
								
									
3191  tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc  Normal file  (File diff suppressed because it is too large)

662  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc  Normal file
							| @ -0,0 +1,662 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" | ||||
| 
 | ||||
| #include "absl/types/optional.h" | ||||
| #include "tensorflow/compiler/xla/literal_util.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_sharding.h" | ||||
| #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" | ||||
| #include "tensorflow/compiler/xla/shape_util.h" | ||||
| #include "tensorflow/compiler/xla/xla_data.pb.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| namespace spmd { | ||||
| 
 | ||||
| bool HasReplicatedSharding(const HloSharding& sharding) { | ||||
|   if (sharding.IsTuple()) { | ||||
|     return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); | ||||
|   } | ||||
|   return sharding.IsReplicated(); | ||||
| } | ||||
| 
 | ||||
| HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { | ||||
|   if (shape.IsTuple()) { | ||||
|     std::vector<HloInstruction*> elements; | ||||
|     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { | ||||
|       elements.push_back( | ||||
|           CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b)); | ||||
|     } | ||||
|     return b->AddInstruction(HloInstruction::CreateTuple(elements)); | ||||
|   } | ||||
| 
 | ||||
|   if (shape.IsToken()) { | ||||
|     return b->AddInstruction(HloInstruction::CreateToken()); | ||||
|   } | ||||
|   auto zero = b->AddInstruction( | ||||
|       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); | ||||
|   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); | ||||
| } | ||||
| 
 | ||||
| HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { | ||||
|   HloComputation::Builder sum_b("add"); | ||||
|   auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( | ||||
|       /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); | ||||
|   auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( | ||||
|       /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); | ||||
|   if (type == PRED) { | ||||
|     sum_b.AddInstruction(HloInstruction::CreateBinary( | ||||
|         ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); | ||||
|   } else { | ||||
|     sum_b.AddInstruction(HloInstruction::CreateBinary( | ||||
|         ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); | ||||
|   } | ||||
|   HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); | ||||
|   return reduction; | ||||
| } | ||||
| 
 | ||||
| bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { | ||||
|   if (sharding.IsTuple()) { | ||||
|     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { | ||||
|       if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i), | ||||
|                             sharding.GetSubSharding(shape, {i}))) { | ||||
|         return false; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   if (sharding.IsTileMaximal()) { | ||||
|     return sharding.IsReplicated(); | ||||
|   } | ||||
|   for (int64 i = 0; i < shape.dimensions_size(); ++i) { | ||||
|     if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { | ||||
|   if (sharding.IsTuple()) { | ||||
|     std::vector<Shape> subshapes; | ||||
|     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { | ||||
|       subshapes.push_back( | ||||
|           MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), | ||||
|                                sharding.GetSubSharding(shape, {i}))); | ||||
|     } | ||||
|     return ShapeUtil::MakeTupleShape(subshapes); | ||||
|   } | ||||
|   return sharding.TileShape(shape); | ||||
| } | ||||
| 
 | ||||
| Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, | ||||
|                                           const HloSharding& sharding, | ||||
|                                           int64 partition_id) { | ||||
|   if (sharding.IsTuple()) { | ||||
|     std::vector<Shape> subshapes; | ||||
|     for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { | ||||
|       subshapes.push_back(MakeNonPaddedShapeForGivenPartition( | ||||
|           ShapeUtil::GetTupleElementShape(shape, i), | ||||
|           sharding.GetSubSharding(shape, {i}), partition_id)); | ||||
|     } | ||||
|     return ShapeUtil::MakeTupleShape(subshapes); | ||||
|   } | ||||
| 
 | ||||
|   auto partition_shape = shape; | ||||
|   std::vector<int64> tile_offset = | ||||
|       sharding.TileOffsetForDevice(shape, partition_id); | ||||
|   std::vector<int64> tile_limit = | ||||
|       sharding.TileLimitForDevice(shape, partition_id); | ||||
|   for (int64 i = 0; i < tile_offset.size(); ++i) { | ||||
|     if (sharding.UsesDevice(partition_id)) { | ||||
|       partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]); | ||||
|     } else { | ||||
|       partition_shape.set_dimensions(i, 0); | ||||
|     } | ||||
|   } | ||||
|   return partition_shape; | ||||
| } | ||||
| 
 | ||||
| std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape, | ||||
|                                                   const HloSharding& sharding, | ||||
|                                                   HloInstruction* partition_id, | ||||
|                                                   SpmdBuilder* b) { | ||||
|   CHECK(!shape.IsTuple()); | ||||
| 
 | ||||
|   Array2D<int32> offset_array( | ||||
|       {sharding.tile_assignment().num_elements(), shape.rank()}); | ||||
|   offset_array.Each([&](int64 i, int64 j, int32* value) { | ||||
|     *value = sharding.TileOffsetForDevice(shape, i)[j]; | ||||
|   }); | ||||
|   auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( | ||||
|       LiteralUtil::CreateR2FromArray2D(offset_array))); | ||||
|   std::vector<HloInstruction*> offsets; | ||||
|   for (int64 i = 0; i < shape.rank(); ++i) { | ||||
|     if (sharding.tile_assignment().dim(i) == 1) { | ||||
|       offsets.push_back(b->AddInstruction( | ||||
|           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); | ||||
|     } else { | ||||
|       auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( | ||||
|           ShapeUtil::MakeShape(S32, {1, 1}), offset_table, | ||||
|           {partition_id, b->AddInstruction(HloInstruction::CreateConstant( | ||||
|                              LiteralUtil::CreateR0<uint32>(i)))}, | ||||
|           {1, 1})); | ||||
|       offsets.push_back(b->AddInstruction( | ||||
|           HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); | ||||
|     } | ||||
|   } | ||||
|   return offsets; | ||||
| } | ||||
| 
 | ||||
| std::vector<HloInstruction*> MakeTiledPartitionOrdinals( | ||||
|     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { | ||||
|   CHECK(!sharding.IsTileMaximal()); | ||||
|   auto table_shape = | ||||
|       ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); | ||||
|   return MakePartitionOffsets(table_shape, sharding, partition_id, b); | ||||
| } | ||||
| 
 | ||||
| HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, | ||||
|                            SpmdBuilder* b, HloComputation* computation) { | ||||
|   CHECK(b == nullptr || computation == nullptr); | ||||
|   if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) { | ||||
|     return hlo; | ||||
|   } | ||||
|   PaddingConfig padding_config; | ||||
|   for (int64 i = 0; i < padded_shape.rank(); ++i) { | ||||
|     auto padding_config_dim = padding_config.add_dimensions(); | ||||
|     padding_config_dim->set_edge_padding_low(0); | ||||
|     padding_config_dim->set_interior_padding(0); | ||||
|     padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - | ||||
|                                               hlo->shape().dimensions(i)); | ||||
|   } | ||||
|   auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) { | ||||
|     if (b == nullptr) { | ||||
|       return computation->AddInstruction(std::move(to_add)); | ||||
|     } | ||||
|     return b->AddInstruction(std::move(to_add)); | ||||
|   }; | ||||
|   auto zero = add_hlo(HloInstruction::CreateConstant( | ||||
|       LiteralUtil::Zero(hlo->shape().element_type()))); | ||||
|   return add_hlo( | ||||
|       HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config)); | ||||
| } | ||||
| 
 | ||||
| Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, | ||||
|                                           const HloSharding& sharding) { | ||||
|   if (sharding.IsTileMaximal()) { | ||||
|     return base_shape; | ||||
|   } | ||||
|   if (EvenlyPartitions(base_shape, sharding)) { | ||||
|     return base_shape; | ||||
|   } | ||||
|   auto shard_shape = MakePartitionedShape(base_shape, sharding); | ||||
|   Shape padded_base_shape = base_shape; | ||||
|   for (int64 i = 0; i < padded_base_shape.rank(); ++i) { | ||||
|     padded_base_shape.set_dimensions( | ||||
|         i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); | ||||
|   } | ||||
|   return padded_base_shape; | ||||
| } | ||||
| 
 | ||||
| HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( | ||||
|     HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { | ||||
|   auto padded_base_shape = | ||||
|       GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding); | ||||
|   if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) { | ||||
|     return hlo; | ||||
|   } | ||||
|   return PadToShape(hlo, padded_base_shape, b); | ||||
| } | ||||
| 
 | ||||
| absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) { | ||||
|   if (sharding.IsTileMaximal()) { | ||||
|     return absl::nullopt; | ||||
|   } | ||||
|   int64 dim = -1; | ||||
|   for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { | ||||
|     if (sharding.tile_assignment().dim(i) > 1) { | ||||
|       if (dim != -1) { | ||||
|         return absl::nullopt; | ||||
|       } | ||||
|       dim = i; | ||||
|     } | ||||
|   } | ||||
|   CHECK_NE(dim, -1); | ||||
|   return dim; | ||||
| } | ||||
| 
 | ||||
| MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation( | ||||
|     int64 multiplier, int64 offset, int64 divisor) | ||||
|     : multiplier_(multiplier), offset_(offset), divisor_(divisor) { | ||||
|   CHECK_GT(divisor_, 0); | ||||
|   Simplify(); | ||||
| } | ||||
| 
 | ||||
| OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-( | ||||
|     const MultiplyAddDivideOffsetCalculation& other) const { | ||||
|   if (divisor_ == 1 && other.divisor_ == 1) { | ||||
|     return OffsetCalculation(MultiplyAddDivideOffsetCalculation( | ||||
|         multiplier_ - other.multiplier_, offset_ - other.offset_, 1)); | ||||
|   } | ||||
|   return OffsetCalculation(HloOpcode::kSubtract, *this, other); | ||||
| } | ||||
| 
 | ||||
| void MultiplyAddDivideOffsetCalculation::Simplify() { | ||||
|   // We could simplify the calculation when multiplier_ is a multiple of | ||||
|   // divisor_. However, when offset_ is not a multiple of divisor_, we must | ||||
|   // make sure that offset_ and multiplier_ are both non-negative or both | ||||
|   // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1. | ||||
|   if (divisor_ != 1 && multiplier_ % divisor_ == 0 && | ||||
|       (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) { | ||||
|     multiplier_ /= divisor_; | ||||
|     offset_ /= divisor_; | ||||
|     divisor_ = 1; | ||||
|   } | ||||
| } | ||||
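| 
| // Worked check of the caveat above (illustrative): with C++ truncating | ||||
| // division, (3 * i - 1) / 3 equals i - 1 for i >= 1 (e.g. i = 1 gives | ||||
| // 2 / 3 = 0), but at i = 0 it gives -1 / 3 = 0 while i - 1 = -1. The signs | ||||
| // of multiplier_ and offset_ differ, so Simplify() correctly skips the case. | ||||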
| 
 | ||||
| int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const { | ||||
|   return (shard_ordinal * multiplier_ + offset_) / divisor_; | ||||
| } | ||||
| 
 | ||||
| HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate( | ||||
|     HloInstruction* shard_ordinal, SpmdBuilder* b) const { | ||||
|   auto scalar_shape = ShapeUtil::MakeShape(S32, {}); | ||||
|   if (multiplier_ == 0) { | ||||
|     return b->AddInstruction(HloInstruction::CreateConstant( | ||||
|         LiteralUtil::CreateR0<int32>(offset_ / divisor_))); | ||||
|   } | ||||
|   HloInstruction* result = shard_ordinal; | ||||
|   if (multiplier_ != 1) { | ||||
|     result = b->AddInstruction(HloInstruction::CreateBinary( | ||||
|         scalar_shape, HloOpcode::kMultiply, shard_ordinal, | ||||
|         b->AddInstruction(HloInstruction::CreateConstant( | ||||
|             LiteralUtil::CreateR0<int32>(multiplier_))))); | ||||
|   } | ||||
|   if (offset_ != 0) { | ||||
|     auto offset = b->AddInstruction( | ||||
|         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_))); | ||||
|     result = b->AddInstruction(HloInstruction::CreateBinary( | ||||
|         scalar_shape, HloOpcode::kAdd, result, offset)); | ||||
|   } | ||||
|   if (divisor_ != 1) { | ||||
|     auto divisor = b->AddInstruction( | ||||
|         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_))); | ||||
|     result = b->AddInstruction(HloInstruction::CreateBinary( | ||||
|         scalar_shape, HloOpcode::kDivide, result, divisor)); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| int64 MultiplyAddDivideOffsetCalculation::MaxInRange( | ||||
|     int64 start_ordinal, int64 limit_ordinal) const { | ||||
|   int64 max = Calculate(start_ordinal); | ||||
|   for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { | ||||
|     max = std::max(max, Calculate(i)); | ||||
|   } | ||||
|   return max; | ||||
| } | ||||
| 
 | ||||
| OffsetCalculation& OffsetCalculation::operator=( | ||||
|     const OffsetCalculation& other) { | ||||
|   opcode_ = other.opcode_; | ||||
|   copy_from_ = other.copy_from_; | ||||
|   if (opcode_ != HloOpcode::kCopy) { | ||||
|     lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_); | ||||
|     rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_); | ||||
|   } | ||||
|   return *this; | ||||
| } | ||||
| 
 | ||||
| bool OffsetCalculation::IsConstant() const { | ||||
|   if (opcode_ == HloOpcode::kCopy) { | ||||
|     return copy_from_.IsConstant(); | ||||
|   } | ||||
|   if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) { | ||||
|     return true; | ||||
|   } | ||||
|   return lhs_->IsConstant() && rhs_->IsConstant(); | ||||
| } | ||||
| 
 | ||||
| OffsetCalculation OffsetCalculation::operator-( | ||||
|     const OffsetCalculation& other) const { | ||||
|   if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) { | ||||
|     return copy_from_ - other.copy_from_; | ||||
|   } | ||||
|   return OffsetCalculation(HloOpcode::kSubtract, *this, other); | ||||
| } | ||||
| 
 | ||||
| bool OffsetCalculation::operator==(const OffsetCalculation& other) const { | ||||
|   if (opcode_ != other.opcode_) { | ||||
|     return false; | ||||
|   } | ||||
|   if (opcode_ == HloOpcode::kCopy) { | ||||
|     return copy_from_ == other.copy_from_; | ||||
|   } | ||||
|   return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_; | ||||
| } | ||||
| 
 | ||||
| int64 OffsetCalculation::Calculate(int64 shard_ordinal) const { | ||||
|   switch (opcode_) { | ||||
|     case HloOpcode::kCopy: | ||||
|       return copy_from_.Calculate(shard_ordinal); | ||||
|     case HloOpcode::kSubtract: | ||||
|       return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal); | ||||
|     case HloOpcode::kMultiply: | ||||
|       return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal); | ||||
|     default: | ||||
|       LOG(FATAL) << "Should not happen"; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal, | ||||
|                                              SpmdBuilder* b) const { | ||||
|   if (opcode_ == HloOpcode::kCopy) { | ||||
|     return copy_from_.Calculate(shard_ordinal, b); | ||||
|   } | ||||
|   auto lhs = lhs_->Calculate(shard_ordinal, b); | ||||
|   auto rhs = rhs_->Calculate(shard_ordinal, b); | ||||
|   return b->AddInstruction( | ||||
|       HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs)); | ||||
| } | ||||
| 
 | ||||
| int64 OffsetCalculation::MaxInRange(int64 start_ordinal, | ||||
|                                     int64 limit_ordinal) const { | ||||
|   if (IsConstant()) { | ||||
|     return Calculate(start_ordinal); | ||||
|   } | ||||
|   if (opcode_ == HloOpcode::kCopy) { | ||||
|     return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1)); | ||||
|   } | ||||
|   int64 max = Calculate(start_ordinal); | ||||
|   for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { | ||||
|     max = std::max(max, Calculate(i)); | ||||
|   } | ||||
|   return max; | ||||
| } | ||||
| 
 | ||||
| absl::optional<HloInstruction*> ExchangeHalo( | ||||
|     HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, | ||||
|     const OffsetCalculation& right_halo_size_function, int64 dim, | ||||
|     const HloSharding& target, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b) { | ||||
|   int64 input_shard_size = hlo->shape().dimensions(dim); | ||||
|   int64 shard_count = target.tile_assignment().dim(dim); | ||||
| 
 | ||||
|   std::vector<HloInstruction*> concat_pieces; | ||||
| 
 | ||||
|   int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); | ||||
|   if (max_left_halo_size > input_shard_size) { | ||||
|     VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; | ||||
|     return absl::nullopt; | ||||
|   } | ||||
|   if (max_left_halo_size > 0) { | ||||
|     std::vector<std::pair<int64, int64>> source_target_pairs; | ||||
|     target.tile_assignment().Each( | ||||
|         [&](absl::Span<const int64> indices, int64 device) { | ||||
|           if (indices[dim] > 0) { | ||||
|             std::vector<int64> source_indices(indices.begin(), indices.end()); | ||||
|             source_indices[dim] -= 1; | ||||
|             source_target_pairs.emplace_back( | ||||
|                 target.tile_assignment()(source_indices), device); | ||||
|           } | ||||
|         }); | ||||
|     auto halo_shape = hlo->shape(); | ||||
|     auto source_halo_slice = hlo; | ||||
|     if (max_left_halo_size != hlo->shape().dimensions(dim)) { | ||||
|       halo_shape.set_dimensions(dim, max_left_halo_size); | ||||
|       std::vector<int64> halo_start_indices(halo_shape.rank(), 0); | ||||
|       halo_start_indices[dim] = | ||||
|           hlo->shape().dimensions(dim) - max_left_halo_size; | ||||
|       std::vector<int64> halo_slice_strides(halo_shape.rank(), 1); | ||||
| 
 | ||||
|       source_halo_slice = b->AddInstruction( | ||||
|           hlo->CreateSlice(halo_shape, hlo, halo_start_indices, | ||||
|                            hlo->shape().dimensions(), halo_slice_strides)); | ||||
|     } | ||||
|     auto left_halo = | ||||
|         collective_ops_creator.create_cross_partition_collective_permute( | ||||
|             b, source_halo_slice, source_target_pairs, (*next_channel_id)++); | ||||
|     concat_pieces.push_back(left_halo); | ||||
|   } | ||||
| 
 | ||||
|   concat_pieces.push_back(hlo); | ||||
| 
 | ||||
|   // Right halo.
 | ||||
|   int64 max_right_halo_size = | ||||
|       right_halo_size_function.MaxInRange(0, shard_count - 1); | ||||
|   if (max_right_halo_size > input_shard_size) { | ||||
|     VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; | ||||
|     return absl::nullopt; | ||||
|   } | ||||
|   if (max_right_halo_size > 0) { | ||||
|     std::vector<std::pair<int64, int64>> source_target_pairs; | ||||
|     target.tile_assignment().Each( | ||||
|         [&](absl::Span<const int64> indices, int64 device) { | ||||
|           if (indices[dim] > 0) { | ||||
|             std::vector<int64> target_indices(indices.begin(), indices.end()); | ||||
|             target_indices[dim] -= 1; | ||||
|             source_target_pairs.emplace_back( | ||||
|                 device, target.tile_assignment()(target_indices)); | ||||
|           } | ||||
|         }); | ||||
|     auto halo_shape = hlo->shape(); | ||||
|     halo_shape.set_dimensions(dim, max_right_halo_size); | ||||
|     std::vector<int64> halo_start_indices(halo_shape.rank(), 0); | ||||
|     std::vector<int64> halo_slice_strides(halo_shape.rank(), 1); | ||||
| 
 | ||||
|     auto source_halo_slice = b->AddInstruction( | ||||
|         hlo->CreateSlice(halo_shape, hlo, halo_start_indices, | ||||
|                          halo_shape.dimensions(), halo_slice_strides)); | ||||
|     auto right_halo = | ||||
|         collective_ops_creator.create_cross_partition_collective_permute( | ||||
|             b, source_halo_slice, source_target_pairs, (*next_channel_id)++); | ||||
|     concat_pieces.push_back(right_halo); | ||||
|   } | ||||
| 
 | ||||
|   auto concat = hlo; | ||||
|   // Concat with halos/padding.
 | ||||
|   if (concat_pieces.size() > 1) { | ||||
|     auto concat_shape = hlo->shape(); | ||||
|     int64 concat_dim_size = 0; | ||||
|     for (auto piece : concat_pieces) { | ||||
|       concat_dim_size += piece->shape().dimensions(dim); | ||||
|     } | ||||
|     concat_shape.set_dimensions(dim, concat_dim_size); | ||||
|     concat = b->AddInstruction( | ||||
|         HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim)); | ||||
|   } | ||||
| 
 | ||||
|   return concat; | ||||
| } | ||||
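| 
| // Illustrative sketch (assumption, not part of the original change): a | ||||
| // caller exchanging a constant one-element halo on each side of `dim`. | ||||
| inline absl::optional<HloInstruction*> ExampleExchangeUnitHalo( | ||||
|     HloInstruction* hlo, int64 dim, const HloSharding& target, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b) { | ||||
|   // Constant halo size 1: (shard_ordinal * 0 + 1) / 1. | ||||
|   OffsetCalculation one(MultiplyAddDivideOffsetCalculation(0, 1, 1)); | ||||
|   return ExchangeHalo(hlo, one, one, dim, target, collective_ops_creator, | ||||
|                       next_channel_id, b); | ||||
| } | ||||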
| 
 | ||||
| absl::optional<HloInstruction*> ExchangeHalo( | ||||
|     HloInstruction* hlo, | ||||
|     std::vector<OffsetCalculation> left_halo_size_functions, | ||||
|     std::vector<OffsetCalculation> right_halo_size_functions, | ||||
|     const HloSharding& target, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b) { | ||||
|   CHECK(left_halo_size_functions.size() == hlo->shape().rank()); | ||||
|   CHECK(right_halo_size_functions.size() == hlo->shape().rank()); | ||||
| 
 | ||||
|   HloInstruction* visiting_hlo = hlo; | ||||
|   for (int dim = 0; dim < hlo->shape().rank(); ++dim) { | ||||
|     auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim], | ||||
|                                right_halo_size_functions[dim], dim, target, | ||||
|                                collective_ops_creator, next_channel_id, b); | ||||
|     if (!concat) { | ||||
|       return absl::nullopt; | ||||
|     } | ||||
|     visiting_hlo = *concat; | ||||
|   } | ||||
|   return visiting_hlo; | ||||
| } | ||||
| 
 | ||||
| absl::optional<HloInstruction*> ExchangeHaloAndGetValidData( | ||||
|     HloInstruction* hlo, const Shape& base_shape, | ||||
|     const OffsetCalculation& left_halo_size_function, | ||||
|     const OffsetCalculation& right_halo_size_function, | ||||
|     int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, | ||||
|     int64 shard_size_with_halo, int64 dim, const HloSharding& target, | ||||
|     HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, | ||||
|     HloInstruction* partition_ordinal, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { | ||||
|   auto halo_exchange_result = | ||||
|       ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim, | ||||
|                    target, collective_ops_creator, next_channel_id, b); | ||||
|   if (!halo_exchange_result) { | ||||
|     return absl::nullopt; | ||||
|   } | ||||
|   auto concat = *halo_exchange_result; | ||||
|   int64 shard_count = target.tile_assignment().dim(dim); | ||||
|   int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); | ||||
| 
 | ||||
|   // Now we determine if we need extra padding after the concat.
 | ||||
|   //
 | ||||
|   // The max of the left halo size and the first shard's explicit left | ||||
|   // padding. | ||||
|   int64 max_left_halo_or_padding_size = | ||||
|       std::max(std::max(int64{0}, max_left_halo_size), | ||||
|                explicit_left_padding_on_full_shape); | ||||
|   // The calculation that returns the dynamic slice index for a shard on the
 | ||||
|   // padded concat, which is the difference between
 | ||||
|   // max_left_halo_or_padding_size and its left halo size.
 | ||||
|   auto start_offset_on_padded_concat_calculation = | ||||
|       OffsetCalculation(MultiplyAddDivideOffsetCalculation( | ||||
|           0, max_left_halo_or_padding_size, 1)) - | ||||
|       left_halo_size_function; | ||||
| 
 | ||||
|   // See if we need to pad the concat before dynamic slice.
 | ||||
|   int64 extra_left_padding = | ||||
|       std::max(int64{0}, max_left_halo_or_padding_size - | ||||
|                              std::max(int64{0}, max_left_halo_size)); | ||||
|   int64 extra_right_padding = | ||||
|       start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) + | ||||
|       shard_size_with_halo - concat->shape().dimensions(dim) - | ||||
|       extra_left_padding; | ||||
|   extra_right_padding = std::max(int64{0}, extra_right_padding); | ||||
|   if (extra_left_padding > 0 || extra_right_padding > 0) { | ||||
|     PaddingConfig padding_config; | ||||
|     auto padded_concat_shape = concat->shape(); | ||||
|     for (int64 i = 0; i < base_shape.rank(); ++i) { | ||||
|       auto padding_config_dim = padding_config.add_dimensions(); | ||||
|       padding_config_dim->set_interior_padding(0); | ||||
|       padding_config_dim->set_edge_padding_low(0); | ||||
|       padding_config_dim->set_edge_padding_high(0); | ||||
|       if (i != dim) { | ||||
|         continue; | ||||
|       } | ||||
|       padding_config_dim->set_edge_padding_low(extra_left_padding); | ||||
|       padding_config_dim->set_edge_padding_high(extra_right_padding); | ||||
|       padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) + | ||||
|                                                   extra_left_padding + | ||||
|                                                   extra_right_padding); | ||||
|     } | ||||
|     concat = b->AddInstruction(HloInstruction::CreatePad( | ||||
|         padded_concat_shape, concat, pad_value, padding_config)); | ||||
|   } | ||||
| 
 | ||||
|   auto valid_slice = concat; | ||||
|   if (shard_size_with_halo != concat->shape().dimensions(dim)) { | ||||
|     // Concat is bigger than the shard shape, so we need to slice it down. | ||||
|     CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim)); | ||||
|     auto slice_shape = concat->shape(); | ||||
|     slice_shape.set_dimensions(dim, shard_size_with_halo); | ||||
| 
 | ||||
|     if (left_halo_size_function.IsConstant() && | ||||
|         left_halo_size_function.Calculate(0) == | ||||
|             explicit_left_padding_on_full_shape) { | ||||
|       std::vector<int64> start_indices(slice_shape.rank(), 0); | ||||
|       std::vector<int64> strides(slice_shape.rank(), 1); | ||||
|       valid_slice = b->AddInstruction( | ||||
|           HloInstruction::CreateSlice(slice_shape, concat, start_indices, | ||||
|                                       slice_shape.dimensions(), strides)); | ||||
|     } else { | ||||
|       auto zero = b->AddInstruction( | ||||
|           HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); | ||||
|       std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero); | ||||
|       slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( | ||||
|           partition_ordinal, b); | ||||
|       valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice( | ||||
|           slice_shape, concat, slice_offsets, slice_shape.dimensions())); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   if (!mask_invalid_region) { | ||||
|     return valid_slice; | ||||
|   } | ||||
| 
 | ||||
|   int64 total_right_padding = padded_full_shape_size - | ||||
|                               base_shape.dimensions(dim) - | ||||
|                               explicit_left_padding_on_full_shape; | ||||
|   // Mask off garbage data due to uneven partition or low/high padding.
 | ||||
|   if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) { | ||||
|     auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); | ||||
|     auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); | ||||
|     auto broadcast_start_index_in_padded_shape = | ||||
|         b->AddInstruction(HloInstruction::CreateBroadcast( | ||||
|             index_shape, offset_on_padded_shape, {})); | ||||
|     auto index_in_padded_shape = b->AddInstruction( | ||||
|         HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, | ||||
|                                      broadcast_start_index_in_padded_shape)); | ||||
|     auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); | ||||
|     std::vector<HloInstruction*> predicates; | ||||
|     if (explicit_left_padding_on_full_shape > 0) { | ||||
|       auto valid_index_start = | ||||
|           b->AddInstruction(HloInstruction::CreateBroadcast( | ||||
|               index_shape, | ||||
|               b->AddInstruction( | ||||
|                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>( | ||||
|                       explicit_left_padding_on_full_shape))), | ||||
|               {})); | ||||
|       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( | ||||
|           mask_shape, index_in_padded_shape, valid_index_start, | ||||
|           ComparisonDirection::kGe))); | ||||
|     } | ||||
|     if (total_right_padding > 0) { | ||||
|       auto valid_index_limit = | ||||
|           b->AddInstruction(HloInstruction::CreateBroadcast( | ||||
|               index_shape, | ||||
|               b->AddInstruction( | ||||
|                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>( | ||||
|                       base_shape.dimensions(dim) + | ||||
|                       explicit_left_padding_on_full_shape))), | ||||
|               {})); | ||||
|       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( | ||||
|           mask_shape, index_in_padded_shape, valid_index_limit, | ||||
|           ComparisonDirection::kLt))); | ||||
|     } | ||||
|     CHECK(!predicates.empty()); | ||||
|     auto is_valid = | ||||
|         predicates.size() == 2 | ||||
|             ? b->AddInstruction(HloInstruction::CreateBinary( | ||||
|                   mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) | ||||
|             : predicates[0]; | ||||
|     auto masking_value = b->AddInstruction( | ||||
|         HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {})); | ||||
|     valid_slice = b->AddInstruction( | ||||
|         HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect, | ||||
|                                       is_valid, valid_slice, masking_value)); | ||||
|   } | ||||
|   return valid_slice; | ||||
| } | ||||
| 
 | ||||
| }  // namespace spmd
 | ||||
| }  // namespace xla
 | ||||
							
								
								
									
229  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h  Normal file
							| @ -0,0 +1,229 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ | ||||
| #define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <string> | ||||
| 
 | ||||
| #include "absl/types/optional.h" | ||||
| #include "tensorflow/compiler/xla/literal_util.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_computation.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_module.h" | ||||
| #include "tensorflow/compiler/xla/service/hlo_sharding.h" | ||||
| #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" | ||||
| 
 | ||||
| namespace xla { | ||||
| namespace spmd { | ||||
| 
 | ||||
| // Returns true if the given sharding contains any replicated sharding.
 | ||||
| bool HasReplicatedSharding(const HloSharding& sharding); | ||||
| 
 | ||||
| // Creates a zero-valued instruction (or tuple of zeros) of the given shape. | ||||
| HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); | ||||
| 
 | ||||
| template <typename NativeT> | ||||
| HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, | ||||
|                                  SpmdBuilder* b) { | ||||
|   auto literal = LiteralUtil::CreateR0(value) | ||||
|                      .ConvertToShape(ShapeUtil::MakeShape(type, {})) | ||||
|                      .ValueOrDie(); | ||||
|   return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); | ||||
| } | ||||
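| 
| // Example (illustrative): CreateR0WithType(BF16, 1.0f, b) builds a BF16 | ||||
| // scalar constant; the float value is converted to the requested type. | ||||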
| 
 | ||||
| // Creates a binary add computation of the given type and adds it to the | ||||
| // module. | ||||
| HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module); | ||||
| 
 | ||||
| // Returns true if the shape can be evenly partitioned for the given sharding. | ||||
| // All tile-sharded dimensions must be evenly divisible, and there must be no | ||||
| // single-device sharding. Replicated sharding is considered an even partition. | ||||
| bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding); | ||||
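| 
| // Example (illustrative): F32[8, 9] with a [2, 1] tile assignment evenly | ||||
| // partitions (8 % 2 == 0), while a [1, 2] tile assignment does not | ||||
| // (9 % 2 != 0). | ||||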
| 
 | ||||
| // Returns the shard shape of the given shape when it is partitioned for the
 | ||||
| // target sharding.
 | ||||
| Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding); | ||||
| 
 | ||||
| // Returns the shard shape for a partition without padding due to uneven
 | ||||
| // sharding.
 | ||||
| Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, | ||||
|                                           const HloSharding& sharding, | ||||
|                                           int64 partition_id); | ||||
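| 
| // Example (illustrative): F32[8, 9] tiled 2 ways along dimension 1 has shard | ||||
| // shape F32[8, 5] (padded); the non-padded shape is F32[8, 5] for partition 0 | ||||
| // and F32[8, 4] for partition 1. | ||||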
| 
 | ||||
| // Generates the HLO instructions that represent the dimension offsets on any
 | ||||
| // device. The size of the returned vector is the rank of the given shape.
 | ||||
| std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape, | ||||
|                                                   const HloSharding& sharding, | ||||
|                                                   HloInstruction* partition_id, | ||||
|                                                   SpmdBuilder* b); | ||||
| 
 | ||||
| // Returns the per-dimension ordinals of the partition in the tile assignment. | ||||
| std::vector<HloInstruction*> MakeTiledPartitionOrdinals( | ||||
|     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b); | ||||
| 
 | ||||
| // Pads hlo to the desired shape using high padding. Either a builder or a | ||||
| // computation needs to be supplied, but not both. | ||||
| HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, | ||||
|                            SpmdBuilder* b, | ||||
|                            HloComputation* computation = nullptr); | ||||
| 
 | ||||
| // Returns the padded shape when combining all partitions. | ||||
| Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, | ||||
|                                           const HloSharding& sharding); | ||||
| 
 | ||||
| // Pads the HLO (which has the base shape) under an uneven tiled sharding so | ||||
| // that it becomes evenly partitionable. | ||||
| HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( | ||||
|     HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b); | ||||
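| // This is roughly (a sketch, not necessarily the exact implementation): | ||||
| //   PadToShape(hlo, | ||||
| //              GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding), b) | ||||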
| 
 | ||||
| // Returns the index of the unique tile dimension. Returns absl::nullopt if | ||||
| // the given sharding is not tiled or tiled along multiple dimensions. | ||||
| absl::optional<int64> UniqueTiledDim(const HloSharding& sharding); | ||||
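| // E.g., a tile assignment of shape [4,1] gives dimension 0, while a [2,2] | ||||
| // assignment or a replicated sharding gives absl::nullopt. | ||||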
| 
 | ||||
| // Utilities for symbolic offset calculation and halo exchange. | ||||
| class OffsetCalculation; | ||||
| 
 | ||||
| // Represents a calculation over integers: | ||||
| //   (shard_ordinal * multiplier + offset) / divisor | ||||
| class MultiplyAddDivideOffsetCalculation { | ||||
|  public: | ||||
|   MultiplyAddDivideOffsetCalculation() | ||||
|       : multiplier_(0), offset_(0), divisor_(1) {} | ||||
|   MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset, | ||||
|                                      int64 divisor); | ||||
| 
 | ||||
|   OffsetCalculation operator-( | ||||
|       const MultiplyAddDivideOffsetCalculation& other) const; | ||||
| 
 | ||||
|   bool operator==(const MultiplyAddDivideOffsetCalculation& other) const { | ||||
|     return multiplier_ == other.multiplier_ && offset_ == other.offset_ && | ||||
|            divisor_ == other.divisor_; | ||||
|   } | ||||
| 
 | ||||
|   bool IsConstant() const { return multiplier_ == 0; } | ||||
|   void Simplify(); | ||||
|   int64 Calculate(int64 shard_ordinal) const; | ||||
|   HloInstruction* Calculate(HloInstruction* shard_ordinal, | ||||
|                             SpmdBuilder* b) const; | ||||
| 
 | ||||
|   // Returns the maximum result for shard ordinals in the range | ||||
|   // [start_ordinal, limit_ordinal). | ||||
|   int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; | ||||
| 
 | ||||
|  private: | ||||
|   int64 multiplier_; | ||||
|   int64 offset_; | ||||
|   int64 divisor_; | ||||
| }; | ||||
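| // For example, multiplier = 2, offset = 1, divisor = 3 evaluates to | ||||
| // (2 * shard_ordinal + 1) / 3, i.e. 0, 1, 1, 2 for shard ordinals 0..3 | ||||
| // (integer division), so MaxInRange(0, 4) returns 2. | ||||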
| 
 | ||||
| // Represents a calculation over integers that combines the results of other | ||||
| // calculations with an opcode. If the opcode is kCopy, it simply wraps a | ||||
| // MultiplyAddDivideOffsetCalculation. | ||||
| class OffsetCalculation { | ||||
|  public: | ||||
|   OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {} | ||||
|   explicit OffsetCalculation( | ||||
|       const MultiplyAddDivideOffsetCalculation& copy_from) | ||||
|       : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {} | ||||
|   OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; } | ||||
|   OffsetCalculation(HloOpcode opcode, | ||||
|                     const MultiplyAddDivideOffsetCalculation& lhs, | ||||
|                     const MultiplyAddDivideOffsetCalculation& rhs) | ||||
|       : opcode_(opcode), | ||||
|         lhs_(absl::make_unique<OffsetCalculation>(lhs)), | ||||
|         rhs_(absl::make_unique<OffsetCalculation>(rhs)) {} | ||||
|   OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs, | ||||
|                     const OffsetCalculation& rhs) | ||||
|       : opcode_(opcode), | ||||
|         lhs_(absl::make_unique<OffsetCalculation>(lhs)), | ||||
|         rhs_(absl::make_unique<OffsetCalculation>(rhs)) {} | ||||
| 
 | ||||
|   OffsetCalculation& operator=(const OffsetCalculation& other); | ||||
| 
 | ||||
|   // Returns whether the calculation returns the same value for all shards. | ||||
|   // This is conservative and could return false even if it is actually | ||||
|   // constant. | ||||
|   bool IsConstant() const; | ||||
| 
 | ||||
|   OffsetCalculation operator-(const OffsetCalculation& other) const; | ||||
|   bool operator==(const OffsetCalculation& other) const; | ||||
|   int64 Calculate(int64 shard_ordinal) const; | ||||
|   HloInstruction* Calculate(HloInstruction* shard_ordinal, | ||||
|                             SpmdBuilder* b) const; | ||||
| 
 | ||||
|   // Returns the maximum result for shard ordinals in the range | ||||
|   // [start_ordinal, limit_ordinal). | ||||
|   int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; | ||||
| 
 | ||||
|  private: | ||||
|   HloOpcode opcode_; | ||||
|   std::unique_ptr<OffsetCalculation> lhs_; | ||||
|   std::unique_ptr<OffsetCalculation> rhs_; | ||||
|   MultiplyAddDivideOffsetCalculation copy_from_; | ||||
| }; | ||||
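| // E.g., a subtraction that cannot be folded into a single multiply-add-divide | ||||
| // form can be represented as opcode kSubtract with the two operand | ||||
| // calculations as children. | ||||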
| 
 | ||||
| // Performs halo exchange on the given dimension based on the provided | ||||
| // left/right halo size functions. Returns nullopt if the halo is beyond the | ||||
| // direct neighbor of the shard. | ||||
| absl::optional<HloInstruction*> ExchangeHalo( | ||||
|     HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, | ||||
|     const OffsetCalculation& right_halo_size_function, int64 dim, | ||||
|     const HloSharding& target, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b); | ||||
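| // E.g., a size-3 window sliding over a tiled dimension needs one element from | ||||
| // each neighboring shard, i.e. constant left/right halo size functions that | ||||
| // return 1 for every shard ordinal. | ||||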
| 
 | ||||
| // Exchanges halos on all dimensions of the HLO. Returns nullopt if any one of | ||||
| // the dimensions fails to exchange halo (halo is beyond the neighbor shard). | ||||
| absl::optional<HloInstruction*> ExchangeHalo( | ||||
|     HloInstruction* hlo, | ||||
|     std::vector<OffsetCalculation> left_halo_size_functions, | ||||
|     std::vector<OffsetCalculation> right_halo_size_functions, | ||||
|     const HloSharding& target, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b); | ||||
| 
 | ||||
| // Exchanges halos and performs pad/dynamic-slice on the concatenated data such | ||||
| // that the result starts with the first needed element on each shard. It also | ||||
| // masks off invalid data due to padding. | ||||
| // Arguments: | ||||
| //  hlo: the HLO op before halo exchange | ||||
| //  explicit_left_padding_on_full_shape: the amount of left padding to be added | ||||
| //   explicitly by this function on the base shape before partitioning. Without | ||||
| //   base dilation, this is usually set to the window's padding_low so that the | ||||
| //   sharded op does not need to add padding_low on the window; however, with | ||||
| //   base dilation, this could only be set to a custom size. | ||||
| //  padded_full_shape_size: the size of the padded full shape on the given | ||||
| //   dimension, which includes explicit_left_padding_on_full_shape and required | ||||
| //   right padding to make the shape evenly shardable. | ||||
| //  shard_size_with_halo: the shard size on the dimension after halo exchange. | ||||
| //   If different shards have different sizes, use the maximum size. | ||||
| //  offset_on_padded_shape: the offset HLO (S32) that represents the start of | ||||
| //   each shard on the padded full shape. | ||||
| //  pad_value: the padding value used on the full shape. | ||||
| absl::optional<HloInstruction*> ExchangeHaloAndGetValidData( | ||||
|     HloInstruction* hlo, const Shape& base_shape, | ||||
|     const OffsetCalculation& left_halo_size_function, | ||||
|     const OffsetCalculation& right_halo_size_function, | ||||
|     int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, | ||||
|     int64 shard_size_with_halo, int64 dim, const HloSharding& target, | ||||
|     HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, | ||||
|     HloInstruction* partition_ordinal, | ||||
|     const SPMDCollectiveOpsCreator& collective_ops_creator, | ||||
|     int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true); | ||||
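| // Illustrative numbers: with a size-3 window, padding_low = 1, and a | ||||
| // dimension of size 15 tiled 4 ways, explicit_left_padding_on_full_shape may | ||||
| // be 1 and padded_full_shape_size 16 (15 + 1 left + 0 right), so each shard | ||||
| // of size 4 starts at offset partition_ordinal * 4 on the padded shape. | ||||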
| 
 | ||||
| }  // namespace spmd | ||||
| }  // namespace xla | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ | ||||