diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4703a968991..f48ba3d0909 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -533,6 +533,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index feb17632439..7011dd79b1e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -34,6 +34,166 @@ limitations under the License. namespace xla { namespace hlo_sharding_util { +bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { + CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()); + if (lhs.IsTuple()) { + // For tuples we consider lhs to have a better sharding if none of the + // elements are worse and at least one element is better then in rhs + // sharding. + const auto& lhs_shardings = lhs.tuple_elements(); + const auto& rhs_shardings = rhs.tuple_elements(); + CHECK_EQ(lhs_shardings.size(), rhs_shardings.size()); + bool is_better = false; + for (int64 i = 0; i < lhs_shardings.size(); ++i) { + if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) { + return false; + } + if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) { + is_better = true; + } + } + return is_better; + } + if (!rhs.IsTileMaximal()) { + return lhs.NumTiles() > rhs.NumTiles(); + } else if (!rhs.IsReplicated()) { + // If we are not replicated then only tiled (not tile maximal) shardings + // can improve us. + return !lhs.IsTileMaximal(); + } else { + // If we are replicated then any non-replicated sharding can improve us. + return !lhs.IsReplicated(); + } +} + +bool MergeSharding(const HloSharding& old, HloSharding* to_merge, + bool may_combine_partial_sharding) { + if (old.IsTuple()) { + CHECK(to_merge->IsTuple()); + bool changed = false; + for (int64 i = 0; i < old.tuple_elements().size(); ++i) { + changed |= + MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i], + may_combine_partial_sharding); + } + return changed; + } + if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() || + !to_merge->ReplicateOnLastTileDim() || + old.tile_assignment().num_elements() != + to_merge->tile_assignment().num_elements()) { + return IsShardingMoreSpecific(*to_merge, old); + } + // Combine the tile dimension sizes from new and old. 
+  int64 num_devices = old.tile_assignment().num_elements();
+  std::vector<int64> new_tile_dims;
+  bool compatible = true;
+  new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
+  for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
+    int64 new_dim = to_merge->tile_assignment().dim(i);
+    int64 old_dim = old.tile_assignment().dim(i);
+    if (new_dim == 1) {
+      new_tile_dims.push_back(old_dim);
+    } else if (old_dim == 1) {
+      new_tile_dims.push_back(new_dim);
+    } else if (new_dim == old_dim) {
+      new_tile_dims.push_back(new_dim);
+    } else {
+      compatible = false;
+      break;
+    }
+  }
+  int64 replication = num_devices / Product(new_tile_dims);
+  if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
+      replication >= old.tile_assignment().dimensions().back()) {
+    return IsShardingMoreSpecific(*to_merge, old);
+  }
+  new_tile_dims.push_back(replication);
+  Array<int64> new_tile(new_tile_dims);
+  // Maps from replication group ID to sorted members.
+  absl::flat_hash_map<int64, std::set<int64>> old_group_members;
+  absl::flat_hash_map<int64, std::set<int64>> new_group_members;
+  auto get_group_index = [&](absl::Span<const int64> tile_indices,
+                             const HloSharding& sharding) {
+    int64 group_id = 0;
+    for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
+      group_id *= to_merge->tile_assignment().dim(i);
+      group_id += tile_indices[i];
+    }
+    return group_id;
+  };
+  old.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        old_group_members[get_group_index(indices, old)].insert(device);
+      });
+  to_merge->tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        new_group_members[get_group_index(indices, *to_merge)].insert(device);
+      });
+  // Try to find the intersection of old and new replication groups, in
+  // order to determine the merged tile assignment.
+  new_tile.Each([&](absl::Span<const int64> indices, int64* device) {
+    if (!compatible) {
+      return;
+    }
+    std::vector<int64> old_index(indices.begin(), indices.end());
+    std::vector<int64> new_index = old_index;
+    for (int64 i = 0; i < indices.size() - 1; ++i) {
+      if (old.tile_assignment().dim(i) == 1) {
+        old_index[i] = 0;
+      }
+      if (to_merge->tile_assignment().dim(i) == 1) {
+        new_index[i] = 0;
+      }
+    }
+    int64 old_group_id = get_group_index(old_index, old);
+    int64 new_group_id = get_group_index(new_index, *to_merge);
+    if (old_group_members[old_group_id].empty() ||
+        new_group_members[new_group_id].empty()) {
+      compatible = false;
+      return;
+    }
+
+    int64 smallest_old = *old_group_members[old_group_id].begin();
+    int64 smallest_new = *new_group_members[new_group_id].begin();
+    if (smallest_old < smallest_new) {
+      if (old_group_members[old_group_id].count(smallest_new) == 0) {
+        compatible = false;
+        return;
+      }
+      *device = smallest_new;
+    } else {
+      if (new_group_members[new_group_id].count(smallest_old) == 0) {
+        compatible = false;
+        return;
+      }
+      *device = smallest_old;
+    }
+    old_group_members[old_group_id].erase(*device);
+    new_group_members[new_group_id].erase(*device);
+  });
+  if (compatible) {
+    std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
+    merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
+    const absl::flat_hash_set
+        metadata_set(merged_metadata.begin(), merged_metadata.end());
+    absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
+                    [&metadata_set](const OpMetadata& data) {
+                      return !ContainsKey(metadata_set, data);
+                    });
+    if (replication == 1) {
+      new_tile_dims.pop_back();
+      new_tile.Reshape(new_tile_dims);
+      *to_merge = HloSharding::Tile(new_tile, merged_metadata);
+    } else {
+      *to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
+    }
+    return true;
+  }
+  return IsShardingMoreSpecific(*to_merge, old);
+}
+
 absl::optional<int64> SelectDominantDevice(
     const std::map<int64, int64>& device_map, int64* top_count) {
   int64 device = 0;
@@ -397,18 +557,28 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
         index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
   }
+  int64 partial_replication_size = 1;
   if (output_sharding.ReplicateOnLastTileDim()) {
-    index_tile_assignment_dims.push_back(
-        output_sharding.tile_assignment().dimensions().back());
+    partial_replication_size *=
+        output_sharding.tile_assignment().dimensions().back();
   }
   Array<int64> new_tile_assignment = output_sharding.tile_assignment();
-  if (new_tile_assignment.num_elements() !=
-      Product(index_tile_assignment_dims)) {
-    return HloSharding::Replicate(output_sharding.metadata());
+  const int64 index_tile_elements =
+      Product(index_tile_assignment_dims) * partial_replication_size;
+  if (new_tile_assignment.num_elements() != index_tile_elements) {
+    if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
+      partial_replication_size *=
+          (new_tile_assignment.num_elements() / index_tile_elements);
+    } else {
+      return HloSharding::Replicate(output_sharding.metadata());
+    }
+  }
+  if (partial_replication_size > 1) {
+    index_tile_assignment_dims.push_back(partial_replication_size);
   }
   new_tile_assignment.Reshape(index_tile_assignment_dims);
-  return output_sharding.ReplicateOnLastTileDim()
+  return partial_replication_size > 1
              ? HloSharding::PartialTile(new_tile_assignment,
                                         output_sharding.metadata())
              : HloSharding::Tile(new_tile_assignment,
@@ -722,17 +892,14 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
 // Collect data operand sharding for a gather with parallel dimensions from
 // the output.
 absl::optional<HloSharding> GatherParallelDataOperandSharding(
-    const HloSharding& output_sharding, const HloInstruction& gather) {
+    const HloSharding& output_sharding, const HloInstruction& gather,
+    const GatherParallelDims& parallel_dims) {
   if (output_sharding.IsTileMaximal()) {
     return output_sharding;
   }
-  auto parallel_dims = GetGatherBatchParallelDims(gather);
-  if (!parallel_dims) {
-    return absl::nullopt;
-  }
-  auto output_parallel_dims = GatherParallelOutputDims(gather, *parallel_dims);
+  auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
   auto output_aligned_operand_parallel_dims =
-      GatherOutputAlignedOperandParallelDims(gather, *parallel_dims);
+      GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
   const Shape gather_shape = gather.shape();
   CHECK_EQ(output_parallel_dims.size(),
            output_aligned_operand_parallel_dims.size());
@@ -741,11 +908,6 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
   for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
     if (parallel_idx >= output_parallel_dims.size() ||
        output_parallel_dims[parallel_idx] != i) {
-      // Support only the case where the output dimensions are sharded only
-      // across parallel dimensions.
-      if (output_sharding.tile_assignment().dim(i) != 1) {
-        return absl::nullopt;
-      }
       continue;
     }
     const int64 operand_dim =
@@ -753,16 +915,27 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
     operand_tile_assignment[operand_dim] =
         output_sharding.tile_assignment().dim(i);
   }
+  int64 partially_replicated_size = 1;
   if (output_sharding.ReplicateOnLastTileDim()) {
-    operand_tile_assignment.push_back(
-        output_sharding.tile_assignment().dimensions().back());
+    partially_replicated_size *=
+        output_sharding.tile_assignment().dimensions().back();
   }
   Array<int64> tile_assignment = output_sharding.tile_assignment();
-  if (tile_assignment.num_elements() != Product(operand_tile_assignment)) {
-    return absl::nullopt;
+  const int64 operand_tile_elements =
+      Product(operand_tile_assignment) * partially_replicated_size;
+  if (tile_assignment.num_elements() != operand_tile_elements) {
+    if (tile_assignment.num_elements() % operand_tile_elements == 0) {
+      partially_replicated_size *=
+          (tile_assignment.num_elements() / operand_tile_elements);
+    } else {
+      return absl::nullopt;
+    }
+  }
+  if (partially_replicated_size > 1) {
+    operand_tile_assignment.push_back(partially_replicated_size);
   }
   tile_assignment.Reshape(operand_tile_assignment);
-  return output_sharding.ReplicateOnLastTileDim()
+  return partially_replicated_size > 1
              ? HloSharding::PartialTile(tile_assignment,
                                         output_sharding.metadata())
              : HloSharding::Tile(tile_assignment, output_sharding.metadata());
@@ -771,7 +944,8 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
 }  // namespace
 
 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
-    const HloSharding& data_operand_sharding, const HloInstruction& hlo) {
+    const HloSharding& data_operand_sharding, const HloInstruction& hlo,
+    const Shape& output_shape, const Shape& operand_shape) {
   const auto& dnums = hlo.gather_dimension_numbers();
   std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
                                           dnums.collapsed_slice_dims().end());
@@ -780,9 +954,8 @@ absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
   std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                  dnums.offset_dims().end());
   return PassthroughOperandToGatherOutputOrScatterUpdate(
-      hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(),
-      collapsed_slice_dims, start_index_map, offset_dims,
-      hlo.gather_slice_sizes());
+      operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
+      start_index_map, offset_dims, hlo.gather_slice_sizes());
 }
 
 absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
@@ -794,15 +967,41 @@ absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
                                      dnums.start_index_map().end());
   std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                  dnums.offset_dims().end());
-  // Prioritize parallel sharding first as this is how it is in
-  // spmd_partitioner.
-  if (auto parallel_sharding =
-          GatherParallelDataOperandSharding(hlo.sharding(), hlo)) {
+
+  absl::optional<HloSharding> parallel_sharding;
+  auto parallel_dims = GetGatherBatchParallelDims(hlo);
+  absl::Span<const int64> operand_parallel_dims;
+  if (parallel_dims) {
+    // Prioritize parallel sharding first as this is how it is in
+    // spmd_partitioner.
+    parallel_sharding =
+        GatherParallelDataOperandSharding(hlo.sharding(), hlo, *parallel_dims);
+    operand_parallel_dims = parallel_dims->operand_parallel_dims;
+  }
+  HloSharding filtered_output_sharding = PartiallyReplicateTiledShardingOnDims(
+      output_sharding, operand_parallel_dims);
+  absl::optional<HloSharding> passthrough_sharding =
+      PassthroughGatherOutputOrScatterUpdateToOperand(
+          hlo.operand(0)->shape(), filtered_output_sharding,
+          collapsed_slice_dims, start_index_map, offset_dims,
+          hlo.gather_slice_sizes());
+  // Try to merge the two shardings, or return the one that is present if only
+  // one of the two is.
+  if (!passthrough_sharding) {
     return parallel_sharding;
   }
-  return PassthroughGatherOutputOrScatterUpdateToOperand(
-      hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
-      start_index_map, offset_dims, hlo.gather_slice_sizes());
+  if (!parallel_sharding) {
+    return passthrough_sharding;
+  }
+  if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
+                    /*may_combine_partial_sharding=*/true)) {
+    return passthrough_sharding;
+  }
+  if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
+                    /*may_combine_partial_sharding=*/true)) {
+    return parallel_sharding;
+  }
+  return absl::nullopt;
 }
 
 absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
@@ -941,7 +1140,7 @@ std::vector<int64> DevicesForSharding(
 }
 
 HloSharding PartiallyReplicateTiledShardingOnDims(
-    const HloSharding& sharding, const std::vector<int64>& dims_to_replicate) {
+    const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
   if (sharding.IsTileMaximal()) {
     return sharding;
   }
@@ -1119,8 +1318,13 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
     if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
      return absl::nullopt;
     }
-    indices_parallel_dims.push_back(index_parallel_dim);
-    operand_parallel_dims.push_back(dnums.start_index_map(i));
+    // Considered parallel only if the slice is of size 1 over the operand.
+    if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
+      indices_parallel_dims.push_back(index_parallel_dim);
+      operand_parallel_dims.push_back(dnums.start_index_map(i));
+    } else {
+      index_parallel_in_dim[i] = -1;
+    }
   }
   absl::c_sort(indices_parallel_dims);
   if (!indices_parallel_dims.empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h
index b159ceb192a..c9e2c4c635e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h
@@ -36,6 +36,19 @@ struct GatherParallelDims {
   std::vector<int64> index_parallel_in_dim;
 };
 
+// Returns true if the lhs sharding is preferable over the rhs sharding.
+// The most specific sharding is tile maximal followed by single device tile
+// maximal and finally replicated. This order aims to primarily reduce memory
+// usage and secondly reduce total compute.
+// Note: This does NOT provide a total ordering, as we can have two different
+// shardings with the same preference level.
+bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);
+
+// Tries to refine `to_merge` by combining it with `old`. Returns true if the
+// final `to_merge` is more specific than `old`.
+bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
+                   bool may_combine_partial_sharding);
+
 // Given a <device, occurrence_count> map, selects the device with higher
 // occurrence count (if any). If top_count in not nullptr, it will receive the
 // count of the dominant device returned.
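As a hypothetical, self-contained illustration (not part of the patch), the two helpers now exported from `hlo_sharding_util` can be exercised directly. The device layout below is made up; the expected result follows the merge rules of `MergeSharding` shown above.

```cpp
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void MergePartialShardingsExample() {
  // Four devices. `old_sharding`: dim 0 split in two, replicated across two
  // devices, i.e. {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}.
  HloSharding old_sharding =
      HloSharding::PartialTile(Array<int64>({{{0, 1}}, {{2, 3}}}));
  // `to_merge`: dim 1 split in two, replicated across two devices,
  // i.e. {devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}.
  HloSharding to_merge =
      HloSharding::PartialTile(Array<int64>({{{0, 2}, {1, 3}}}));

  // A tiled (even partially replicated) sharding is more specific than a
  // fully replicated one.
  CHECK(hlo_sharding_util::IsShardingMoreSpecific(old_sharding,
                                                  HloSharding::Replicate()));

  // With may_combine_partial_sharding=true the two partially replicated
  // shardings are compatible and combine into the fully tiled
  // {devices=[2,2]0,1,2,3}; MergeSharding rewrites `to_merge` in place.
  if (hlo_sharding_util::MergeSharding(old_sharding, &to_merge,
                                       /*may_combine_partial_sharding=*/true)) {
    CHECK(to_merge == HloSharding::Tile(Array<int64>({{0, 1}, {2, 3}})));
  }
}

}  // namespace xla
```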
@@ -137,7 +150,8 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, // Returns an output sharding of gather by passing through the data operand's // sharding. absl::optional GatherOutputShardingFromDataOperand( - const HloSharding& data_operand_sharding, const HloInstruction& hlo); + const HloSharding& data_operand_sharding, const HloInstruction& hlo, + const Shape& output_shape, const Shape& operand_shape); // Returns a data operand sharding of gather by passing through the output's // sharding. @@ -173,7 +187,7 @@ std::vector DevicesForSharding( // Returns a sharding that replicates data across devices along the given // dimensions in the original sharding. HloSharding PartiallyReplicateTiledShardingOnDims( - const HloSharding& sharding, const std::vector& dims_to_replicate); + const HloSharding& sharding, absl::Span dims_to_replicate); // Returns a sharding the removes given tile dimensions. // diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index af338f5e1fa..c7ece50d202 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_split.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -65,174 +66,6 @@ bool IsSpatiallyPartitioned(const HloInstruction* hlo) { return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding()); } -// Returns true if the lhs sharding is preferable over the rhs sharding. -// The most specific sharding is tile maximal followed by single device tile -// maximal and finally replicated. This order aims to primarily reduce memory -// usage and secondly reduce total compute. -// Note: This does NOT provide a total ordering as we can have 2 different -// sharding with same preference level. -bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { - CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()); - if (lhs.IsTuple()) { - // For tuples we consider lhs to have a better sharding if none of the - // elements are worse and at least one element is better then in rhs - // sharding. - const auto& lhs_shardings = lhs.tuple_elements(); - const auto& rhs_shardings = rhs.tuple_elements(); - CHECK_EQ(lhs_shardings.size(), rhs_shardings.size()); - bool is_better = false; - for (int64 i = 0; i < lhs_shardings.size(); ++i) { - if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) { - return false; - } - if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) { - is_better = true; - } - } - return is_better; - } - if (!rhs.IsTileMaximal()) { - return lhs.NumTiles() > rhs.NumTiles(); - } else if (!rhs.IsReplicated()) { - // If we are not replicated then only tiled (not tile maximal) shardings - // can improve us. - return !lhs.IsTileMaximal(); - } else { - // If we are replicated then any non-replicated sharding can improve us. - return !lhs.IsReplicated(); - } -} - -// Tries to refine `to_merge` by combining with `old`. Returns if the final -// `to_merge` is more specific than `old`. 
-bool MergeSharding(const HloSharding& old, HloSharding* to_merge, - bool may_combine_partial_sharding) { - if (old.IsTuple()) { - CHECK(to_merge->IsTuple()); - bool changed = false; - for (int64 i = 0; i < old.tuple_elements().size(); ++i) { - changed |= - MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i], - may_combine_partial_sharding); - } - return changed; - } - if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() || - !to_merge->ReplicateOnLastTileDim() || - old.tile_assignment().num_elements() != - to_merge->tile_assignment().num_elements()) { - return IsShardingMoreSpecific(*to_merge, old); - } - // Combine the tile dimension sizes from new and old. - int64 num_devices = old.tile_assignment().num_elements(); - std::vector new_tile_dims; - bool compatible = true; - new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions()); - for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) { - int64 new_dim = to_merge->tile_assignment().dim(i); - int64 old_dim = old.tile_assignment().dim(i); - if (new_dim == 1) { - new_tile_dims.push_back(old_dim); - } else if (old_dim == 1) { - new_tile_dims.push_back(new_dim); - } else if (new_dim == old_dim) { - new_tile_dims.push_back(new_dim); - } else { - compatible = false; - break; - } - } - int64 replication = num_devices / Product(new_tile_dims); - if (!compatible || num_devices % Product(new_tile_dims) != 0 || - replication >= old.tile_assignment().dimensions().back()) { - return IsShardingMoreSpecific(*to_merge, old); - } - new_tile_dims.push_back(replication); - Array new_tile(new_tile_dims); - // Maps from replication group ID to sorted members. - absl::flat_hash_map> old_group_members; - absl::flat_hash_map> new_group_members; - auto get_group_index = [&](absl::Span tile_indices, - const HloSharding& sharding) { - int64 group_id = 0; - for (int64 i = 0; i < tile_indices.size() - 1; ++i) { - group_id *= to_merge->tile_assignment().dim(i); - group_id += tile_indices[i]; - } - return group_id; - }; - old.tile_assignment().Each( - [&](absl::Span indices, int64 device) { - old_group_members[get_group_index(indices, old)].insert(device); - }); - to_merge->tile_assignment().Each( - [&](absl::Span indices, int64 device) { - new_group_members[get_group_index(indices, *to_merge)].insert(device); - }); - // Try to find the intersection of old and new replication groups, in - // order to determine the merged tile assignment. 
- new_tile.Each([&](absl::Span indices, int64* device) { - if (!compatible) { - return; - } - std::vector old_index(indices.begin(), indices.end()); - std::vector new_index = old_index; - for (int64 i = 0; i < indices.size() - 1; ++i) { - if (old.tile_assignment().dim(i) == 1) { - old_index[i] = 0; - } - if (to_merge->tile_assignment().dim(i) == 1) { - new_index[i] = 0; - } - } - int64 old_group_id = get_group_index(old_index, old); - int64 new_group_id = get_group_index(new_index, *to_merge); - if (old_group_members[old_group_id].empty() || - new_group_members[new_group_id].empty()) { - compatible = false; - return; - } - - int64 smallest_old = *old_group_members[old_group_id].begin(); - int64 smallest_new = *new_group_members[new_group_id].begin(); - if (smallest_old < smallest_new) { - if (old_group_members[old_group_id].count(smallest_new) == 0) { - compatible = false; - return; - } - *device = smallest_new; - } else { - if (new_group_members[new_group_id].count(smallest_old) == 0) { - compatible = false; - return; - } - *device = smallest_old; - } - old_group_members[old_group_id].erase(*device); - new_group_members[new_group_id].erase(*device); - }); - if (compatible) { - std::vector merged_metadata(std::move(to_merge->metadata())); - merged_metadata.reserve(merged_metadata.size() + old.metadata().size()); - const absl::flat_hash_set - metadata_set(merged_metadata.begin(), merged_metadata.end()); - absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata), - [&metadata_set](const OpMetadata& data) { - return !ContainsKey(metadata_set, data); - }); - if (replication == 1) { - new_tile_dims.pop_back(); - new_tile.Reshape(new_tile_dims); - *to_merge = HloSharding::Tile(new_tile, merged_metadata); - } else { - *to_merge = HloSharding::PartialTile(new_tile, merged_metadata); - } - return true; - } - return IsShardingMoreSpecific(*to_merge, old); -} - // Updates the sharding of the specified instruction with the specified sharding // if it is better than the current one and returns true if a new sharding have // been applied. If may_combine_partial_sharding is true, this may combine the @@ -251,8 +84,8 @@ bool MaybeImproveInstructionSharding(HloSharding sharding, return true; } int64 sharding_tiles = sharding.NumTiles(); - if (MergeSharding(instruction->sharding(), &sharding, - may_combine_partial_sharding)) { + if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding, + may_combine_partial_sharding)) { // Override existing tiled sharding only when the new sharding is compatible // with the existing one. This avoids unexpected resharding when `sharding` // just has more tiles than existing sharding but they are not mergeable. 
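For orientation only, here is a minimal hypothetical sketch of the call pattern the hunk above switches to: the candidate sharding is merged into the instruction's existing sharding through the now-shared `hlo_sharding_util::MergeSharding`, and adopted only when the merge succeeds without losing tiles. The helper name and the exact acceptance condition are assumptions, not the patch's code.

```cpp
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

namespace xla {

// Hypothetical helper mirroring the shape of MaybeImproveInstructionSharding's
// call into the shared utility; details are illustrative only.
bool TryImproveSharding(HloSharding candidate, HloInstruction* instruction,
                        bool may_combine_partial_sharding) {
  if (!instruction->has_sharding()) {
    instruction->set_sharding(std::move(candidate));
    return true;
  }
  const int64 candidate_tiles = candidate.NumTiles();
  // MergeSharding may refine `candidate` with the existing sharding; only
  // override when the merged result is at least as tiled as the proposal.
  if (hlo_sharding_util::MergeSharding(instruction->sharding(), &candidate,
                                       may_combine_partial_sharding) &&
      candidate.NumTiles() >= candidate_tiles) {
    instruction->set_sharding(std::move(candidate));
    return true;
  }
  return false;
}

}  // namespace xla
```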
@@ -385,8 +218,8 @@ const HloInstruction* PickRepresentativeOperand( for (const HloInstruction* operand : instruction->operands()) { if (operand->has_sharding() && (best_operand == nullptr || - IsShardingMoreSpecific(operand->sharding(), - best_operand->sharding()))) { + hlo_sharding_util::IsShardingMoreSpecific( + operand->sharding(), best_operand->sharding()))) { best_operand = operand; } } @@ -823,7 +656,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (operand->shape().IsTuple()) { for (int64 i = 0, e = ShapeUtil::GetLeafCount(operand->shape()); i < e; ++i) { - if (IsShardingMoreSpecific( + if (hlo_sharding_util::IsShardingMoreSpecific( operand->sharding().tuple_elements()[i], sub_shardings[sub_sharding_index + i])) { sub_shardings[sub_sharding_index + i] = @@ -831,8 +664,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } } } else { - if (IsShardingMoreSpecific(operand->sharding(), - sub_shardings[sub_sharding_index])) { + if (hlo_sharding_util::IsShardingMoreSpecific( + operand->sharding(), sub_shardings[sub_sharding_index])) { sub_shardings[sub_sharding_index] = operand->sharding(); } } @@ -1120,6 +953,12 @@ bool InferShardingFromOperands(HloInstruction* instruction, } case HloOpcode::kGather: { bool changed = false; + if (IsSpatiallyPartitioned(instruction->operand(1))) { + HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( + instruction->operand(1)->sharding(), instruction); + changed |= MaybeImproveInstructionSharding( + std::move(new_sharding), instruction, may_combine_partial_sharding); + } if (is_spmd) { auto gather_parallel_dims = hlo_sharding_util::GetGatherBatchParallelDims(*instruction); @@ -1127,21 +966,24 @@ bool InferShardingFromOperands(HloInstruction* instruction, changed |= InferGatherParallelShardingFromOperands( instruction, *gather_parallel_dims, may_combine_partial_sharding); } - } - if (IsSpatiallyPartitioned(instruction->operand(1))) { - HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( - instruction->operand(1)->sharding(), instruction); - changed |= MaybeImproveInstructionSharding( - std::move(new_sharding), instruction, may_combine_partial_sharding); - } - if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { - auto maybe_from_data = - hlo_sharding_util::GatherOutputShardingFromDataOperand( - instruction->operand(0)->sharding(), *instruction); - if (maybe_from_data) { - changed |= MaybeImproveInstructionSharding( - std::move(*maybe_from_data), instruction, - may_combine_partial_sharding); + if (IsSpatiallyPartitioned(instruction->operand(0))) { + absl::Span operand_parallel_dims; + if (gather_parallel_dims) { + operand_parallel_dims = absl::MakeConstSpan( + gather_parallel_dims->operand_parallel_dims); + } + HloSharding filtered_operand_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + instruction->operand(0)->sharding(), operand_parallel_dims); + auto maybe_from_data = + hlo_sharding_util::GatherOutputShardingFromDataOperand( + filtered_operand_sharding, *instruction, instruction->shape(), + instruction->operand(0)->shape()); + if (maybe_from_data) { + changed |= MaybeImproveInstructionSharding( + std::move(*maybe_from_data), instruction, + may_combine_partial_sharding); + } } } return changed; @@ -1177,8 +1019,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } auto sharding = instruction->operand(0)->sharding(); if (instruction->has_sharding()) { - MergeSharding(instruction->sharding(), &sharding, - may_combine_partial_sharding); + 
hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding, + may_combine_partial_sharding); } return MaybeImproveInstructionSharding(std::move(sharding), instruction, may_combine_partial_sharding); @@ -1274,8 +1116,8 @@ HloSharding InferDotOperandSharding( *hlo_sharding_util::TransposeShardingWithCollapsedDims( other_operand_dims_replicated, other_to_operand_dims, operand_to_other_dims); - if (MergeSharding(sharding, &sharding_from_other, - may_combine_partial_sharding)) { + if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { sharding = std::move(sharding_from_other); } } diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 85190ac41b4..cc425a24468 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -4573,5 +4573,164 @@ ENTRY %module { } } +TEST_P(ParameterizedMetadataTest, GatherParallelAndPassthroughMerged) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY %module { + %arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0) + %arg1 = s32[4]{0} parameter(1) + %input = s32[4,8,2,2]{3,2,1,0} copy(%arg0), + sharding={devices=[2,1,2,1]0,1,4,5 metadata={op_name="a"}} + %seq_size = s32[4]{0} copy(s32[4]{0} %arg1) + %seq_b = s32[1,4,8]{2,1,0} broadcast(s32[4]{0} %seq_size + ), dimensions={1} + %iota.11 = s32[1,4,8]{2,1,0} iota(), iota_dimension=1 + %concatenate = s32[2,4,8]{2,1,0} concatenate(s32[1,4,8]{2,1,0} %iota.11, + s32[1,4,8]{2,1,0} %seq_b), dimensions={0} + %gather = s32[4,8,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input, + s32[2,4,8]{2,1,0} %concatenate), offset_dims={2,3}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0, + slice_sizes={1,1,2,2} + ROOT %copy = s32[4,8,2,2]{3,2,1,0} copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + EXPECT_TRUE(changed); + const HloInstruction* input = FindInstruction(module.get(), "input"); + ASSERT_NE(input, nullptr); + EXPECT_THAT(input, op::Sharding("{devices=[2,1,2,1]0,1,4,5 }")); + const HloInstruction* concatenate = + FindInstruction(module.get(), "concatenate"); + ASSERT_NE(concatenate, nullptr); + EXPECT_THAT( + concatenate, + op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}")); + const HloInstruction* gather = FindInstruction(module.get(), "gather"); + ASSERT_NE(gather, nullptr); + EXPECT_THAT(gather, op::Sharding("{devices=[2,1,2,1]0,1,4,5}")); + + for (const HloInstruction* instruction : {input, gather}) { + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(instruction->sharding(), + ShardingMetadata({CreateMetadata("a")})); + } else { + EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); + } + } +} + +TEST_P(ParameterizedMetadataTest, GatherParallelAndTrivialMerged) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY %module { + %arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0) + %arg1 = s32[4]{0} parameter(1) + %input = s32[4,8,2,2]{3,2,1,0} copy(%arg0), + sharding={devices=[2,2,1,1]0,1,4,5 metadata={op_name="a"}} + %seq_size = s32[4]{0} copy(s32[4]{0} %arg1) + %seq_b = s32[1,4,1]{2,1,0} broadcast(s32[4]{0} %seq_size), dimensions={1} + %iota.11 = 
s32[1,4,1]{2,1,0} iota(), iota_dimension=1 + %concatenate = s32[2,4,1]{2,1,0} concatenate(s32[1,4,1]{2,1,0} %iota.11, + s32[1,4,1]{2,1,0} %seq_b), dimensions={0} + %gather = s32[4,1,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input, + s32[2,4,1]{2,1,0} %concatenate), offset_dims={2,3}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0, + slice_sizes={1,1,2,2} + ROOT %copy = s32[4,1,2,2]{3,2,1,0} copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + EXPECT_TRUE(changed); + const HloInstruction* input = FindInstruction(module.get(), "input"); + ASSERT_NE(input, nullptr); + EXPECT_THAT(input, op::Sharding("{devices=[2,2,1,1]0,1,4,5}")); + const HloInstruction* concatenate = + FindInstruction(module.get(), "concatenate"); + ASSERT_NE(concatenate, nullptr); + EXPECT_THAT( + concatenate, + op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}")); + const HloInstruction* gather = FindInstruction(module.get(), "gather"); + ASSERT_NE(gather, nullptr); + EXPECT_THAT( + gather, + op::Sharding("{devices=[2,1,1,1,2]0,1,4,5 last_tile_dim_replicate}")); + for (const HloInstruction* instruction : {input, gather}) { + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(instruction->sharding(), + ShardingMetadata({CreateMetadata("a")})); + } else { + EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); + } + } +} + +TEST_P(ParameterizedMetadataTest, + GatherParallelAndPassthroughMergedBackwardPass) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY %module { + %arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0) + %arg1 = s32[4]{0} parameter(1) + %input = s32[4,8,2,2]{3,2,1,0} copy(%arg0) + %seq_size = s32[4]{0} copy(s32[4]{0} %arg1) + %seq_b = s32[1,4,8]{2,1,0} broadcast(s32[4]{0} %seq_size + ), dimensions={1} + %iota.11 = s32[1,4,8]{2,1,0} iota(), iota_dimension=1 + %concatenate = s32[2,4,8]{2,1,0} concatenate(s32[1,4,8]{2,1,0} %iota.11, + s32[1,4,8]{2,1,0} %seq_b), dimensions={0} + %gather = s32[4,8,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input, + s32[2,4,8]{2,1,0} %concatenate), offset_dims={2,3}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0, + slice_sizes={1,1,2,2}, + sharding={devices=[2,1,2,1]0,1,4,5 metadata={op_name="a"}} + ROOT %copy = s32[4,8,2,2]{3,2,1,0} copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + if (GetParam().clear_metadata) { + ClearMetadata(module.get()); + } + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata) + .Run(module.get())); + EXPECT_TRUE(changed); + const HloInstruction* input = FindInstruction(module.get(), "input"); + ASSERT_NE(input, nullptr); + EXPECT_THAT(input, op::Sharding("{devices=[2,1,2,1]0,1,4,5 }")); + const HloInstruction* concatenate = + FindInstruction(module.get(), "concatenate"); + ASSERT_NE(concatenate, nullptr); + EXPECT_THAT( + concatenate, + op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}")); + const HloInstruction* gather = FindInstruction(module.get(), "gather"); + ASSERT_NE(gather, nullptr); + EXPECT_THAT(gather, op::Sharding("{devices=[2,1,2,1]0,1,4,5}")); + if (GetParam().propagate_metadata && !GetParam().clear_metadata) { + EXPECT_THAT(gather->sharding(), 
ShardingMetadata({CreateMetadata("a")})); + } else { + EXPECT_THAT(gather->sharding(), ShardingMetadata({})); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index 167db723435..7cd065b232b 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -55,6 +55,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/core:lib", "//tensorflow/core/platform:numbers", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc index af0c4c88642..ca1b0385ad1 100644 --- a/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/gather_scatter_handler.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { namespace spmd { @@ -118,6 +121,386 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( return {broadcast_min, broadcast_max}; } +// Function that tries to perform recursive partitioning of Gather. +StatusOr PartitionGather(const HloGatherInstruction* gather, + PartitionedHlo& operand, + PartitionedHlo& indices, + const Shape& output_shape, + const HloSharding& output_sharding, + absl::Span batch_dims, + SpmdPartitioningVisitor* visitor); + +// Perform partitioning of Gather when the indices are partitioned and +// the operand is replicated. +StatusOr PartitionIndexOnlyPartition( + const HloGatherInstruction* gather, absl::Span batch_dims, + PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) { + GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + if (operand.sharding().IsTileMaximal()) { + if (!indices.sharding().IsTileMaximal() && + (dnums.index_vector_dim() == indices.base_shape().rank() || + indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == + 1)) { + auto replicated_operand = operand.Replicate(); + TF_ASSIGN_OR_RETURN( + Shape partitioned_output_shape, + ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(), + indices.hlo()->shape(), dnums, + gather->gather_slice_sizes())); + auto pgather = b->AddInstruction(gather->CloneWithNewOperands( + partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()})); + std::vector output_dim_to_index_dim(pgather->shape().rank(), -1); + std::vector index_dim_to_output_dim(indices.base_shape().rank(), + -1); + for (int64 i = 0; i < batch_dims.size(); ++i) { + int64 indices_batch_dim = i < dnums.index_vector_dim() ? 
i : i + 1;
+      output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
+      index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
+    }
+    auto pgather_sharding =
+        hlo_sharding_util::TransposeShardingWithCollapsedDims(
+            indices.sharding(), index_dim_to_output_dim,
+            output_dim_to_index_dim);
+    CHECK(pgather_sharding.has_value());
+    pgather->set_sharding(*pgather_sharding);
+    VLOG(5) << "[Gather partitioning]: Partitioned as index only";
+    return pgather;
+    }
+  }
+  return nullptr;
+}
+
+// Perform partitioning of Gather when the operand is split in an offset
+// dimension that is passed through (the slice size is the same as the size of
+// the operand dimension).
+StatusOr<HloInstruction*> ParititonPassthroughOperand(
+    const HloGatherInstruction* gather, Shape output_shape,
+    const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
+    PartitionedHlo& operand, PartitionedHlo& indices,
+    SpmdPartitioningVisitor* visitor) {
+  SpmdBuilder* b = visitor->builder();
+  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
+  if (auto maybe_passthrough =
+          hlo_sharding_util::GatherOutputShardingFromDataOperand(
+              operand.sharding(), *gather, output_shape,
+              operand.base_shape())) {
+    indices = indices.Reshard(HloSharding::Replicate());
+    auto pshape = MakePartitionedShape(output_shape, *maybe_passthrough);
+    std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
+                                    gather->gather_slice_sizes().end());
+    for (int64 i = 0; i < pslice_sizes.size(); ++i) {
+      if (operand.sharding().tile_assignment().dim(i) > 1) {
+        pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
+      }
+    }
+    auto pgather = b->AddInstruction(HloInstruction::CreateGather(
+        pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
+        gather->indices_are_sorted()));
+    pgather->set_sharding(*maybe_passthrough);
+    VLOG(5) << "[Gather partitioning]: Partitioned as operand passthrough "
+               "offset_dim";
+    return PartitionedHlo(pgather, output_shape, operand.state())
+        .Reshard(output_sharding)
+        .hlo();
+  }
+  return nullptr;
+}
+
+// Partition a Gather when it is sliced in a dimension in the operand that is
+// trivially sliced (sliced with slice size of 1).
+StatusOr<HloInstruction*> ParititonTrivialIndexedOperandDimension(
+    const HloGatherInstruction* gather, Shape output_shape,
+    const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
+    PartitionedHlo& operand, PartitionedHlo& indices,
+    SpmdPartitioningVisitor* visitor) {
+  SpmdBuilder* b = visitor->builder();
+  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
+  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
+                                     dnums.start_index_map().end());
+  if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
+          operand, start_index_map, gather->gather_slice_sizes()) &&
+      ShapeSizeInBytes(output_shape) < ShapeSizeInBytes(operand.base_shape())) {
+    indices = indices.Reshard(HloSharding::Replicate());
+    // Now the operand is partitioned in trivial slice dimensions, and the
+    // indices are replicated. We execute a gather on partitioned operand,
+    // with full number of indices, where out-of-bounds indices are clamped,
+    // and masked out with 0 in the result; then we use all-reduce to combine
+    // results. Although gather will not get faster, we avoided the need to
+    // replicate the operand.
+    HloInstruction* indices_min;
+    HloInstruction* indices_max;
+    std::tie(indices_min, indices_max) =
+        IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
+            operand, indices, operand.state().partition_id, start_index_map,
+            dnums.index_vector_dim(), b);
+    // Clamp the indices.
+    auto adjusted_indices = b->AddInstruction(
+        HloInstruction::CreateTernary(indices.base_shape(), HloOpcode::kClamp,
+                                      indices_min, indices.hlo(), indices_max));
+    // Adjust the indices by subtracting the offset.
+    adjusted_indices = b->AddInstruction(
+        HloInstruction::CreateBinary(indices.base_shape(), HloOpcode::kSubtract,
+                                     adjusted_indices, indices_min));
+    // Gather on adjusted indices.
+    auto pgather = b->AddInstruction(HloInstruction::CreateGather(
+        output_shape, operand.hlo(), adjusted_indices, dnums,
+        gather->gather_slice_sizes(), gather->indices_are_sorted()));
+    // Mask out invalid results.
+    auto filter = b->AddInstruction(HloInstruction::CreateCompare(
+        ShapeUtil::ChangeElementType(indices.base_shape(), PRED), indices.hlo(),
+        indices_min, ComparisonDirection::kLt));
+    filter = b->AddInstruction(HloInstruction::CreateBinary(
+        filter->shape(), HloOpcode::kOr, filter,
+        b->AddInstruction(HloInstruction::CreateCompare(
+            ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
+            indices.hlo(), indices_max, ComparisonDirection::kGt))));
+    if (dnums.index_vector_dim() < indices.base_shape().rank()) {
+      std::vector<int64> reduced_filter_dims;
+      for (int64 i = 0; i < filter->shape().rank(); ++i) {
+        if (i != dnums.index_vector_dim()) {
+          reduced_filter_dims.push_back(filter->shape().dimensions(i));
+        }
+      }
+      filter = b->AddInstruction(HloInstruction::CreateReduce(
+          ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
+          CreateR0WithType(PRED, false, b), {dnums.index_vector_dim()},
+          MakeBinaryAdd(PRED, indices.state().module)));
+    }
+    std::vector<int64> batch_dims;
+    for (int64 i = 0; i < pgather->shape().rank(); ++i) {
+      if (!absl::c_linear_search(dnums.offset_dims(), i)) {
+        batch_dims.push_back(i);
+      }
+    }
+    auto broadcast_filter = b->AddInstruction(HloInstruction::CreateBroadcast(
+        ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
+        batch_dims));
+
+    auto filtered = b->AddInstruction(HloInstruction::CreateTernary(
+        pgather->shape(), HloOpcode::kSelect, broadcast_filter,
+        CreateZero(pgather->shape(), b), pgather));
+    // Combine from different partitions.
+    absl::InlinedVector<int64, 1> replicated_dim;
+    if (operand.sharding().ReplicateOnLastTileDim()) {
+      replicated_dim.push_back(
+          operand.sharding().tile_assignment().num_dimensions() - 1);
+    }
+    auto sharding_grouped =
+        GroupShardingOnDims(operand.sharding(), replicated_dim);
+    auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+        operand.state(), sharding_grouped.device_groups, b);
+    auto collective_ops_creator =
+        per_group_partitioner_state.collective_ops_creator;
+    auto ar = collective_ops_creator.create_cross_partition_all_reduce(
+        b, filtered,
+        MakeBinaryAdd(filtered->shape().element_type(),
+                      per_group_partitioner_state.module),
+        {}, visitor->NewChannel());
+    VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand "
+               "batch_dim slice";
+    ar->set_sharding(HloSharding::Replicate());
+    return PartitionedHlo(ar, output_shape, operand.state())
+        .Reshard(output_sharding)
+        .hlo();
+  }
+  return nullptr;
+}
+
+// Partition a gather over indices dimensions that are considered parallel
+// (which means that the indices access the operand in a monotonically
+// increasing way across the respective operand dimension referenced by the
+// index).
+StatusOr<HloInstruction*> PartitionIndexParallelDimensions(
+    const HloGatherInstruction* gather, Shape output_shape,
+    const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
+    PartitionedHlo& operand, PartitionedHlo& indices,
+    SpmdPartitioningVisitor* visitor) {
+  absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2>
+      top_level_sharding_to_reset;
+  auto cleaner = MakeCleanup([&top_level_sharding_to_reset] {
+    for (auto& to_reset : top_level_sharding_to_reset) {
+      to_reset.first->set_sharding(to_reset.second);
+    }
+  });
+  SpmdBuilder* b = visitor->builder();
+  GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
+  // Handle the case where operand is tile maximal. In this case we check if
+  // the index is not TileMaximal and in this case we use the index sharding
+  // to drive the output sharding.
+  if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
+          hlo_sharding_util::GetGatherBatchParallelDims(*gather)) {
+    if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
+            *operand.hlo(), *indices.hlo(), *parallel_dims)) {
+      auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
+      auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
+      auto output_parallel_dims =
+          hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims);
+      HloSharding indices_sharding = gather_sharding->indices_sharding;
+      HloSharding operand_sharding = gather_sharding->operand_sharding;
+      GroupedSharding grouped_indices =
+          GroupShardingOnDims(indices_sharding, indices_parallel_dims);
+      GroupedSharding grouped_operand =
+          GroupShardingOnDims(operand_sharding, operand_parallel_dims);
+      int index_dim = dnums.index_vector_dim();
+      // Construct the required sharding for the new gather we are gonna form.
+      absl::InlinedVector<int64, 4> output_tiling(
+          output_shape.dimensions_size(), 1);
+      for (int i = 0, num_output_parallel_dims = output_parallel_dims.size();
+           i < num_output_parallel_dims; ++i) {
+        int output_idx = output_parallel_dims[i];
+        int indices_idx = indices_parallel_dims[i];
+        output_tiling[output_idx] =
+            indices_sharding.tile_assignment().dim(indices_idx);
+      }
+      operand = operand.Reshard(operand_sharding);
+      indices = indices.Reshard(indices_sharding);
+      if (indices_sharding.ReplicateOnLastTileDim()) {
+        output_tiling.push_back(
+            indices_sharding.tile_assignment().dimensions().back());
+      }
+      Array<int64> output_tile_assignment = indices_sharding.tile_assignment();
+      output_tile_assignment.Reshape(output_tiling);
+      // New gather tiling.
+      HloSharding gather_output_sharding =
+          indices_sharding.ReplicateOnLastTileDim()
+              ? HloSharding::PartialTile(output_tile_assignment)
+              : HloSharding::Tile(output_tile_assignment);
+      // Shape of the partitioned gather.
+      Shape pshape = MakePartitionedShape(output_shape, gather_output_sharding);
+      // Construct the offsets for the operand sharding to be used to adjust
+      // the indices. Because we know the only dimensions partitioned are the
+      // parallel ones and because the partitioning is the same across indices
+      // and operands, we can apply the operand offsets to the indices.
+      std::vector<HloInstruction*> operand_offsets = MakePartitionOffsets(
+          operand.base_shape(), operand_sharding, operand.state().partition_id,
+          b, operand_parallel_dims);
+      absl::InlinedVector<HloInstruction*, 4> index_offsets;
+      for (int start_idx = 0; start_idx < dnums.start_index_map_size();
+           ++start_idx) {
+        HloInstruction* index_offset =
+            indices.base_shape().dimensions_size() > index_dim
+                ? b->AddInstruction(HloInstruction::CreateReshape(
+                      ShapeUtil::MakeShape(S32, {1}),
+                      operand_offsets[dnums.start_index_map(start_idx)]))
+                : operand_offsets[dnums.start_index_map(start_idx)];
+        index_offsets.push_back(index_offset);
+      }
+      HloInstruction* adjusted_indices = nullptr;
+      if (indices.base_shape().dimensions_size() > index_dim) {
+        // Concatenate the offsets for the parallel dimensions to subtract.
+        adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate(
+            ShapeUtil::MakeShape(S32,
+                                 {indices.base_shape().dimensions(index_dim)}),
+            index_offsets, 0));
+      } else {
+        CHECK_EQ(index_offsets.size(), 1);
+        adjusted_indices = index_offsets[0];
+      }
+      if (indices.hlo()->shape().element_type() != PrimitiveType::S32) {
+        adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert(
+            ShapeUtil::ChangeElementType(adjusted_indices->shape(),
+                                         indices.hlo()->shape().element_type()),
+            adjusted_indices));
+      }
+      if (adjusted_indices->shape().rank() == 0) {
+        adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
+            indices.hlo()->shape(), adjusted_indices, {}));
+      } else {
+        adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
+            indices.hlo()->shape(), adjusted_indices, {index_dim}));
+      }
+      // Adjust indices by subtracting the offsets based on the partition id.
+      adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary(
+          indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
+          adjusted_indices));
+      auto per_group_partitioner_state = CreatePerGroupPartitioningState(
+          operand.state(), grouped_operand.device_groups, b);
+      top_level_sharding_to_reset.emplace_back(operand.hlo(),
+                                               operand.sharding());
+      adjusted_indices->set_sharding(grouped_indices.sharding);
+      operand.hlo()->set_sharding(grouped_operand.sharding);
+      VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim";
+      HloInstruction* pgather;
+      if (operand_sharding.NumTiles() ==
+              operand_sharding.NumTiles(operand_parallel_dims) &&
+          indices_sharding.NumTiles() ==
+              indices_sharding.NumTiles(indices_parallel_dims)) {
+        pgather = b->AddInstruction(HloInstruction::CreateGather(
+            pshape, operand.hlo(), adjusted_indices, dnums,
+            gather->gather_slice_sizes(), gather->indices_are_sorted()));
+      } else {
+        PartitionedHlo per_group_operand(
+            operand.hlo(),
+            GetPerGroupBaseShape(grouped_operand, operand.base_shape()),
+            per_group_partitioner_state);
+        PartitionedHlo per_group_indices(
+            adjusted_indices,
+            GetPerGroupBaseShape(grouped_indices, indices.base_shape()),
+            per_group_partitioner_state);
+        GroupedSharding grouped_output =
+            GroupShardingOnDims(gather_output_sharding, output_parallel_dims);
+        TF_ASSIGN_OR_RETURN(pgather, PartitionGather(gather, per_group_operand,
+                                                     per_group_indices, pshape,
+                                                     grouped_output.sharding,
+                                                     batch_dims, visitor));
+      }
+      if (pgather) {
+        pgather->set_sharding(gather_output_sharding);
+        return PartitionedHlo(pgather, output_shape, operand.state())
+            .Reshard(output_sharding)
+            .hlo();
+      }
+    }
+  }
+  return nullptr;
+}
+
+StatusOr<HloInstruction*> PartitionGather(const HloGatherInstruction* gather,
+                                          PartitionedHlo& operand,
+                                          PartitionedHlo& indices,
+                                          const Shape& output_shape,
+                                          const HloSharding& output_sharding,
+                                          absl::Span<const int64> batch_dims,
+                                          SpmdPartitioningVisitor* visitor) {
+  absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2>
+      top_level_sharding_to_reset;
+  auto cleaner = MakeCleanup([&top_level_sharding_to_reset] {
+    for (auto& to_reset : top_level_sharding_to_reset) {
+      to_reset.first->set_sharding(to_reset.second);
+    }
+  });
+  HloInstruction* partitioned_gather;
+  // Check if we identify some of the dimensions of the gather as parallel and
+  // if we have sharded the operand and indices across those dimensions.
+  // If that's the case then we can partition the gather across such dimensions
+  // by adjusting the offsets.
+  TF_ASSIGN_OR_RETURN(
+      partitioned_gather,
+      PartitionIndexParallelDimensions(gather, output_shape, output_sharding,
+                                       batch_dims, operand, indices, visitor));
+  if (partitioned_gather) {
+    return partitioned_gather;
+  }
+  // Perform passthrough and trivial slice partitioning of the Gather.
+  if (!operand.sharding().IsTileMaximal()) {
+    TF_ASSIGN_OR_RETURN(
+        partitioned_gather,
+        ParititonPassthroughOperand(gather, output_shape, output_sharding,
+                                    batch_dims, operand, indices, visitor));
+    if (partitioned_gather) {
+      return partitioned_gather;
+    }
+    TF_ASSIGN_OR_RETURN(partitioned_gather,
+                        ParititonTrivialIndexedOperandDimension(
+                            gather, output_shape, output_sharding, batch_dims,
+                            operand, indices, visitor));
+    if (partitioned_gather) {
+      return partitioned_gather;
+    }
+  }
+  return nullptr;
+}
+
 }  // namespace
 
 Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
@@ -288,266 +671,33 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
   const auto& dnums = gather->gather_dimension_numbers();
   auto operand = GetPartitionedHlo(gather->operand(0));
   auto indices = GetPartitionedHlo(gather->operand(1));
-  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
-                                     dnums.start_index_map().end());
   std::vector<int64> batch_dims;
   for (int64 i = 0; i < gather->shape().rank(); ++i) {
     if (!absl::c_linear_search(dnums.offset_dims(), i)) {
       batch_dims.push_back(i);
     }
   }
-  // Check if we identify some of the dimensions of the gather as parallel and
-  // if we have sharded the operand and indices across those dimensions.
-  // If that's the case then we can partition the gather across such dimensions
-  // by adjusting the offsets.
-  if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
-          hlo_sharding_util::GetGatherBatchParallelDims(*hlo)) {
-    if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
-            *operand.hlo(), *indices.hlo(), *parallel_dims)) {
-      auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
-      auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
-      auto output_parallel_dims =
-          hlo_sharding_util::GatherParallelOutputDims(*hlo, *parallel_dims);
-      HloSharding indices_sharding = gather_sharding->indices_sharding;
-      HloSharding operand_sharding = gather_sharding->operand_sharding;
-      if (operand_sharding.NumTiles() ==
-              operand_sharding.NumTiles(operand_parallel_dims) &&
-          indices_sharding.NumTiles() ==
-              indices_sharding.NumTiles(indices_parallel_dims)) {
-        int index_dim = dnums.index_vector_dim();
-        // Construct the required sharding for the new gather we are gonna form.
- absl::InlinedVector output_tiling( - hlo->shape().dimensions_size(), 1); - for (int i = 0, num_output_parallel_dims = output_parallel_dims.size(); - i < num_output_parallel_dims; ++i) { - int output_idx = output_parallel_dims[i]; - int indices_idx = indices_parallel_dims[i]; - output_tiling[output_idx] = - indices_sharding.tile_assignment().dim(indices_idx); - } - operand = operand.Reshard(operand_sharding); - indices = indices.Reshard(indices_sharding); - if (indices_sharding.ReplicateOnLastTileDim()) { - output_tiling.push_back( - indices_sharding.tile_assignment().dimensions().back()); - } - Array output_tile_assignment = - indices_sharding.tile_assignment(); - output_tile_assignment.Reshape(output_tiling); - // New gather tiling. - HloSharding output_sharding = - indices_sharding.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(output_tile_assignment) - : HloSharding::Tile(output_tile_assignment); - // Shape of the partitioned gather - Shape pshape = MakePartitionedShape(gather->shape(), output_sharding); - // Construct the offsets for the operand sharding to be used to adjust - // the indices. Because we know the only dimensions partitioned are the - // parallel ones and because the partitioning is the same across indices - // and operands we can apply the offsets on the operands on the indices. - std::vector operand_offsets = MakePartitionOffsets( - operand.base_shape(), operand_sharding, partition_id_, &b_); - absl::InlinedVector index_offsets; - for (int start_idx = 0; start_idx < dnums.start_index_map_size(); - ++start_idx) { - HloInstruction* index_offset = - indices.base_shape().dimensions_size() > index_dim - ? b_.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - operand_offsets[dnums.start_index_map(start_idx)])) - : operand_offsets[dnums.start_index_map(start_idx)]; - index_offsets.push_back(index_offset); - } - HloInstruction* adjusted_indices = nullptr; - if (indices.base_shape().dimensions_size() > index_dim) { - // Concatenate the offsets for the parallel dimensions to subtract. - adjusted_indices = - b_.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape( - S32, {indices.base_shape().dimensions(index_dim)}), - index_offsets, 0)); - } else { - CHECK_EQ(index_offsets.size(), 1); - adjusted_indices = index_offsets[0]; - } - if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { - adjusted_indices = b_.AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType( - adjusted_indices->shape(), - indices.hlo()->shape().element_type()), - adjusted_indices)); - } - if (adjusted_indices->shape().rank() == 0) { - adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {})); - } else { - adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {index_dim})); - } - // Adjust indices by subtracting the offsets based on the partition id. 
-        adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
-            indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
-            adjusted_indices));
-        HloInstruction* pgather =
-            b_.AddInstruction(HloInstruction::CreateGather(
-                pshape, operand.hlo(), adjusted_indices, dnums,
-                gather->gather_slice_sizes(), gather->indices_are_sorted()));
-        pgather->set_sharding(output_sharding);
-        SetPartitionedHlo(hlo, [&]() {
-          return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
-              .Reshard(hlo->sharding())
-              .hlo();
-        });
-        return Status::OK();
-      }
-    }
+
+  HloInstruction* pgather;
+  TF_ASSIGN_OR_RETURN(pgather,
+                      PartitionGather(gather, operand, indices, gather->shape(),
+                                      gather->sharding(),
+                                      absl::MakeConstSpan(batch_dims), this));
+  if (pgather) {
+    SetPartitionedHlo(gather, [pgather] { return pgather; });
+    return Status::OK();
   }
-  if (operand.sharding().IsTileMaximal()) {
-    if (!indices.sharding().IsTileMaximal() &&
-        (dnums.index_vector_dim() == indices.base_shape().rank() ||
-         indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
-             1)) {
-      auto replicated_operand = operand.Replicate();
-      TF_ASSIGN_OR_RETURN(
-          Shape partitioned_output_shape,
-          ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(),
-                                           indices.hlo()->shape(), dnums,
-                                           gather->gather_slice_sizes()));
-      auto pgather = b_.AddInstruction(gather->CloneWithNewOperands(
-          partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()}));
-      std::vector<int64> output_dim_to_index_dim(pgather->shape().rank(), -1);
-      std::vector<int64> index_dim_to_output_dim(indices.base_shape().rank(),
-                                                 -1);
-      for (int64 i = 0; i < batch_dims.size(); ++i) {
-        int64 indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1;
-        output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
-        index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
-      }
-      auto pgather_sharding =
-          hlo_sharding_util::TransposeShardingWithCollapsedDims(
-              indices.sharding(), index_dim_to_output_dim,
-              output_dim_to_index_dim);
-      CHECK(pgather_sharding.has_value());
-      pgather->set_sharding(*pgather_sharding);
-      SetPartitionedHlo(hlo, [&]() {
-        return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
-            .Reshard(hlo->sharding())
-            .hlo();
-      });
-      return Status::OK();
-    }
-  } else {
-    auto maybe_passthrough =
-        hlo_sharding_util::GatherOutputShardingFromDataOperand(
-            operand.sharding(), *hlo);
-    if (maybe_passthrough.has_value()) {
-      indices = indices.Reshard(HloSharding::Replicate());
-      auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);
-      std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
-                                      gather->gather_slice_sizes().end());
-      for (int64 i = 0; i < pslice_sizes.size(); ++i) {
-        if (operand.sharding().tile_assignment().dim(i) > 1) {
-          pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
-        }
-      }
-      auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
-          pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
-          gather->indices_are_sorted()));
-      pgather->set_sharding(*maybe_passthrough);
-      SetPartitionedHlo(hlo, [&]() {
-        return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
-            .Reshard(hlo->sharding())
-            .hlo();
-      });
-      return Status::OK();
-    }
-    if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
-            operand, start_index_map, gather->gather_slice_sizes()) &&
-        ShapeSizeInBytes(gather->shape()) <
-            ShapeSizeInBytes(gather->operand(0)->shape())) {
-      indices = indices.Reshard(HloSharding::Replicate());
-      // Now the operand is partitioned in trivial slice dimensions, and the
-      // indices are replicated. We execute a gather on partitioned operand,
-      // with full number of indices, where out-of-bounds indices are clamped,
-      // and masked out with 0 in the result; then we use all-reduce to combine
-      // results. Although gather will not get faster, we avoided the need to
-      // replicate the operand.
-      HloInstruction* indices_min;
-      HloInstruction* indices_max;
-      std::tie(indices_min, indices_max) =
-          IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
-              operand, indices, partition_id_, start_index_map,
-              dnums.index_vector_dim(), &b_);
-      // Clamp the indices.
-      auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary(
-          indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(),
-          indices_max));
-      // Adjust the indices by subtracting the offset.
-      adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
-          indices.base_shape(), HloOpcode::kSubtract, adjusted_indices,
-          indices_min));
-      // Gather on adjusted indices.
-      auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
-          gather->shape(), operand.hlo(), adjusted_indices, dnums,
-          gather->gather_slice_sizes(), gather->indices_are_sorted()));
-      // Mask out invalid results.
-      auto filter = b_.AddInstruction(HloInstruction::CreateCompare(
-          ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
-          indices.hlo(), indices_min, ComparisonDirection::kLt));
-      filter = b_.AddInstruction(HloInstruction::CreateBinary(
-          filter->shape(), HloOpcode::kOr, filter,
-          b_.AddInstruction(HloInstruction::CreateCompare(
-              ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
-              indices.hlo(), indices_max, ComparisonDirection::kGt))));
-      if (dnums.index_vector_dim() < indices.base_shape().rank()) {
-        std::vector<int64> reduced_filter_dims;
-        for (int64 i = 0; i < filter->shape().rank(); ++i) {
-          if (i != dnums.index_vector_dim()) {
-            reduced_filter_dims.push_back(filter->shape().dimensions(i));
-          }
-        }
-        filter = b_.AddInstruction(HloInstruction::CreateReduce(
-            ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
-            CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()},
-            MakeBinaryAdd(PRED, module_)));
-      }
-      std::vector<int64> batch_dims;
-      for (int64 i = 0; i < pgather->shape().rank(); ++i) {
-        if (!absl::c_linear_search(dnums.offset_dims(), i)) {
-          batch_dims.push_back(i);
-        }
-      }
-      auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast(
-          ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
-          batch_dims));
-      auto filtered = b_.AddInstruction(HloInstruction::CreateTernary(
-          pgather->shape(), HloOpcode::kSelect, broadcast_filter,
-          CreateZero(pgather->shape(), &b_), pgather));
-      // Combine from different partitions.
-      auto collective_ops_creator = collective_ops_creator_;
-      if (operand.sharding().ReplicateOnLastTileDim()) {
-        auto sharding_grouped = GroupShardingOnDims(
-            operand.sharding(),
-            {operand.sharding().tile_assignment().num_dimensions() - 1});
-        auto per_group_partitioner_state = CreatePerGroupPartitioningState(
-            operand.state(), sharding_grouped.device_groups, &b_);
-        collective_ops_creator =
-            per_group_partitioner_state.collective_ops_creator;
-      }
-      auto ar = collective_ops_creator.create_cross_partition_all_reduce(
-          &b_, filtered,
-          MakeBinaryAdd(filtered->shape().element_type(), module_), {},
-          NewChannel());
-      ar->set_sharding(HloSharding::Replicate());
-      SetPartitionedHlo(hlo, [&]() {
-        return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
-            .Reshard(hlo->sharding())
-            .hlo();
-      });
-      return Status::OK();
-    }
+  // Handle the case where the operand is tile maximal. If the indices are not
+  // tile maximal, we use the indices sharding to drive the output sharding.
+  TF_ASSIGN_OR_RETURN(pgather, PartitionIndexOnlyPartition(
+                                   gather, absl::MakeConstSpan(batch_dims),
+                                   operand, indices, &b_));
+  if (pgather) {
+    SetPartitionedHlo(gather, [pgather] { return pgather; });
+    return Status::OK();
   }
-  return DefaultAction(hlo);
+  return DefaultAction(gather);
 }

 }  // namespace spmd
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 849da445faf..eba37ae142c 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -6913,6 +6913,69 @@ ENTRY %module {
               op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
 }

+TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexPassthrough) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
+    sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  auto root = module->entry_computation()->root_instruction();
+  auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::DynamicSlice());
+  auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
+  auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
+  EXPECT_THAT(
+      root, op::AllReduce(op::DynamicUpdateSlice(
+                _, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)),
+                _, _, _, _)));
+}
+
+TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
+  absl::string_view hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
+    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
+  %parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  %iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  %concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
+    s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
+    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
+  ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}, sharding={replicated}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  auto root = module->entry_computation()->root_instruction();
+  auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter());
+  auto indices = AllOf(op::Shape("s32[2,2,1]"), op::Subtract());
+  auto gather = AllOf(op::Shape("s32[2,1,2,2]"), op::Gather(operand, indices));
+  EXPECT_THAT(root,
+              op::AllReduce(op::DynamicUpdateSlice(
+                  _, op::AllReduce(op::Select(_, _, gather)), _, _, _, _)));
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla