[XLA] Add support for concurrent partitioning of gather with multiple strategies.

Recursively repartition gather when parallel dimensions are split to take advantage
of additional partitioning support (like passthrough operand or trivial slice).

PiperOrigin-RevId: 354157012
Change-Id: I3815a4f0e56e8373cd213b28320a9479adb26c67
This commit is contained in:
Marcello Maggioni 2021-01-27 13:15:13 -08:00 committed by TensorFlower Gardener
parent 0d53639121
commit 56ecc5e516
8 changed files with 919 additions and 485 deletions

View File

@ -533,6 +533,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -34,6 +34,166 @@ limitations under the License.
namespace xla {
namespace hlo_sharding_util {
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
CHECK_EQ(lhs.IsTuple(), rhs.IsTuple());
if (lhs.IsTuple()) {
// For tuples we consider lhs to have a better sharding if none of the
// elements are worse and at least one element is better then in rhs
// sharding.
const auto& lhs_shardings = lhs.tuple_elements();
const auto& rhs_shardings = rhs.tuple_elements();
CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
bool is_better = false;
for (int64 i = 0; i < lhs_shardings.size(); ++i) {
if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
return false;
}
if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
is_better = true;
}
}
return is_better;
}
if (!rhs.IsTileMaximal()) {
return lhs.NumTiles() > rhs.NumTiles();
} else if (!rhs.IsReplicated()) {
// If we are not replicated then only tiled (not tile maximal) shardings
// can improve us.
return !lhs.IsTileMaximal();
} else {
// If we are replicated then any non-replicated sharding can improve us.
return !lhs.IsReplicated();
}
}
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding) {
if (old.IsTuple()) {
CHECK(to_merge->IsTuple());
bool changed = false;
for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
changed |=
MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
may_combine_partial_sharding);
}
return changed;
}
if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
!to_merge->ReplicateOnLastTileDim() ||
old.tile_assignment().num_elements() !=
to_merge->tile_assignment().num_elements()) {
return IsShardingMoreSpecific(*to_merge, old);
}
// Combine the tile dimension sizes from new and old.
int64 num_devices = old.tile_assignment().num_elements();
std::vector<int64> new_tile_dims;
bool compatible = true;
new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
int64 new_dim = to_merge->tile_assignment().dim(i);
int64 old_dim = old.tile_assignment().dim(i);
if (new_dim == 1) {
new_tile_dims.push_back(old_dim);
} else if (old_dim == 1) {
new_tile_dims.push_back(new_dim);
} else if (new_dim == old_dim) {
new_tile_dims.push_back(new_dim);
} else {
compatible = false;
break;
}
}
int64 replication = num_devices / Product(new_tile_dims);
if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
replication >= old.tile_assignment().dimensions().back()) {
return IsShardingMoreSpecific(*to_merge, old);
}
new_tile_dims.push_back(replication);
Array<int64> new_tile(new_tile_dims);
// Maps from replication group ID to sorted members.
absl::flat_hash_map<int64, std::set<int64>> old_group_members;
absl::flat_hash_map<int64, std::set<int64>> new_group_members;
auto get_group_index = [&](absl::Span<const int64> tile_indices,
const HloSharding& sharding) {
int64 group_id = 0;
for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
group_id *= to_merge->tile_assignment().dim(i);
group_id += tile_indices[i];
}
return group_id;
};
old.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
old_group_members[get_group_index(indices, old)].insert(device);
});
to_merge->tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
new_group_members[get_group_index(indices, *to_merge)].insert(device);
});
// Try to find the intersection of old and new replication groups, in
// order to determine the merged tile assignment.
new_tile.Each([&](absl::Span<const int64> indices, int64* device) {
if (!compatible) {
return;
}
std::vector<int64> old_index(indices.begin(), indices.end());
std::vector<int64> new_index = old_index;
for (int64 i = 0; i < indices.size() - 1; ++i) {
if (old.tile_assignment().dim(i) == 1) {
old_index[i] = 0;
}
if (to_merge->tile_assignment().dim(i) == 1) {
new_index[i] = 0;
}
}
int64 old_group_id = get_group_index(old_index, old);
int64 new_group_id = get_group_index(new_index, *to_merge);
if (old_group_members[old_group_id].empty() ||
new_group_members[new_group_id].empty()) {
compatible = false;
return;
}
int64 smallest_old = *old_group_members[old_group_id].begin();
int64 smallest_new = *new_group_members[new_group_id].begin();
if (smallest_old < smallest_new) {
if (old_group_members[old_group_id].count(smallest_new) == 0) {
compatible = false;
return;
}
*device = smallest_new;
} else {
if (new_group_members[new_group_id].count(smallest_old) == 0) {
compatible = false;
return;
}
*device = smallest_old;
}
old_group_members[old_group_id].erase(*device);
new_group_members[new_group_id].erase(*device);
});
if (compatible) {
std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
protobuf_util::ProtobufEqualsWrapper>
metadata_set(merged_metadata.begin(), merged_metadata.end());
absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
[&metadata_set](const OpMetadata& data) {
return !ContainsKey(metadata_set, data);
});
if (replication == 1) {
new_tile_dims.pop_back();
new_tile.Reshape(new_tile_dims);
*to_merge = HloSharding::Tile(new_tile, merged_metadata);
} else {
*to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
}
return true;
}
return IsShardingMoreSpecific(*to_merge, old);
}
absl::optional<int64> SelectDominantDevice(
const std::map<int64, int64>& device_map, int64* top_count) {
int64 device = 0;
@ -397,18 +557,28 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding,
index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
}
int64 partial_replication_size = 1;
if (output_sharding.ReplicateOnLastTileDim()) {
index_tile_assignment_dims.push_back(
output_sharding.tile_assignment().dimensions().back());
partial_replication_size *=
output_sharding.tile_assignment().dimensions().back();
}
Array<int64> new_tile_assignment = output_sharding.tile_assignment();
if (new_tile_assignment.num_elements() !=
Product(index_tile_assignment_dims)) {
const int64 index_tile_elements =
Product(index_tile_assignment_dims) * partial_replication_size;
if (new_tile_assignment.num_elements() != index_tile_elements) {
if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
partial_replication_size *=
(new_tile_assignment.num_elements() / index_tile_elements);
} else {
return HloSharding::Replicate(output_sharding.metadata());
}
}
if (partial_replication_size > 1) {
index_tile_assignment_dims.push_back(partial_replication_size);
}
new_tile_assignment.Reshape(index_tile_assignment_dims);
return output_sharding.ReplicateOnLastTileDim()
return partial_replication_size > 1
? HloSharding::PartialTile(new_tile_assignment,
output_sharding.metadata())
: HloSharding::Tile(new_tile_assignment,
@ -722,17 +892,14 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
// Collect data operand sharding for a gather with parallel dimensions from
// the output.
absl::optional<HloSharding> GatherParallelDataOperandSharding(
const HloSharding& output_sharding, const HloInstruction& gather) {
const HloSharding& output_sharding, const HloInstruction& gather,
const GatherParallelDims& parallel_dims) {
if (output_sharding.IsTileMaximal()) {
return output_sharding;
}
auto parallel_dims = GetGatherBatchParallelDims(gather);
if (!parallel_dims) {
return absl::nullopt;
}
auto output_parallel_dims = GatherParallelOutputDims(gather, *parallel_dims);
auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
auto output_aligned_operand_parallel_dims =
GatherOutputAlignedOperandParallelDims(gather, *parallel_dims);
GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
const Shape gather_shape = gather.shape();
CHECK_EQ(output_parallel_dims.size(),
output_aligned_operand_parallel_dims.size());
@ -741,11 +908,6 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
if (parallel_idx >= output_parallel_dims.size() ||
output_parallel_dims[parallel_idx] != i) {
// Support only the case where the output dimensions are sharded only
// across parallel dimensions.
if (output_sharding.tile_assignment().dim(i) != 1) {
return absl::nullopt;
}
continue;
}
const int64 operand_dim =
@ -753,16 +915,27 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
operand_tile_assignment[operand_dim] =
output_sharding.tile_assignment().dim(i);
}
int64 partially_replicated_size = 1;
if (output_sharding.ReplicateOnLastTileDim()) {
operand_tile_assignment.push_back(
output_sharding.tile_assignment().dimensions().back());
partially_replicated_size *=
output_sharding.tile_assignment().dimensions().back();
}
Array<int64> tile_assignment = output_sharding.tile_assignment();
if (tile_assignment.num_elements() != Product(operand_tile_assignment)) {
const int64 operand_tile_elements =
Product(operand_tile_assignment) * partially_replicated_size;
if (tile_assignment.num_elements() != operand_tile_elements) {
if (tile_assignment.num_elements() % operand_tile_elements == 0) {
partially_replicated_size *=
(tile_assignment.num_elements() / operand_tile_elements);
} else {
return absl::nullopt;
}
}
if (partially_replicated_size > 1) {
operand_tile_assignment.push_back(partially_replicated_size);
}
tile_assignment.Reshape(operand_tile_assignment);
return output_sharding.ReplicateOnLastTileDim()
return partially_replicated_size > 1
? HloSharding::PartialTile(tile_assignment,
output_sharding.metadata())
: HloSharding::Tile(tile_assignment, output_sharding.metadata());
@ -771,7 +944,8 @@ absl::optional<HloSharding> GatherParallelDataOperandSharding(
} // namespace
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
const HloSharding& data_operand_sharding, const HloInstruction& hlo) {
const HloSharding& data_operand_sharding, const HloInstruction& hlo,
const Shape& output_shape, const Shape& operand_shape) {
const auto& dnums = hlo.gather_dimension_numbers();
std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
dnums.collapsed_slice_dims().end());
@ -780,9 +954,8 @@ absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
return PassthroughOperandToGatherOutputOrScatterUpdate(
hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(),
collapsed_slice_dims, start_index_map, offset_dims,
hlo.gather_slice_sizes());
operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
start_index_map, offset_dims, hlo.gather_slice_sizes());
}
absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
@ -794,15 +967,41 @@ absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
dnums.start_index_map().end());
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
absl::optional<HloSharding> parallel_sharding;
auto parallel_dims = GetGatherBatchParallelDims(hlo);
absl::Span<const int64> operand_parallel_dims;
if (parallel_dims) {
// Prioritize parallel sharding first as this is how it is in
// spmd_partitioner.
if (auto parallel_sharding =
GatherParallelDataOperandSharding(hlo.sharding(), hlo)) {
parallel_sharding =
GatherParallelDataOperandSharding(hlo.sharding(), hlo, *parallel_dims);
operand_parallel_dims = parallel_dims->operand_parallel_dims;
}
HloSharding filtered_output_sharding = PartiallyReplicateTiledShardingOnDims(
output_sharding, operand_parallel_dims);
absl::optional<HloSharding> passthrough_sharding =
PassthroughGatherOutputOrScatterUpdateToOperand(
hlo.operand(0)->shape(), filtered_output_sharding,
collapsed_slice_dims, start_index_map, offset_dims,
hlo.gather_slice_sizes());
// Try to merge the two shardings or return the one that is present if only
// one of the two is.
if (!passthrough_sharding) {
return parallel_sharding;
}
return PassthroughGatherOutputOrScatterUpdateToOperand(
hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
start_index_map, offset_dims, hlo.gather_slice_sizes());
if (!parallel_sharding) {
return passthrough_sharding;
}
if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
/*may_combine_partial_sharding=*/true)) {
return passthrough_sharding;
}
if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
/*may_combine_partial_sharding=*/true)) {
return parallel_sharding;
}
return absl::nullopt;
}
absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
@ -941,7 +1140,7 @@ std::vector<int64> DevicesForSharding(
}
HloSharding PartiallyReplicateTiledShardingOnDims(
const HloSharding& sharding, const std::vector<int64>& dims_to_replicate) {
const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
if (sharding.IsTileMaximal()) {
return sharding;
}
@ -1119,8 +1318,13 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
return absl::nullopt;
}
// Considered parallel only if the slice is of size 1 over the operand.
if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
indices_parallel_dims.push_back(index_parallel_dim);
operand_parallel_dims.push_back(dnums.start_index_map(i));
} else {
index_parallel_in_dim[i] = -1;
}
}
absl::c_sort(indices_parallel_dims);
if (!indices_parallel_dims.empty()) {

View File

@ -36,6 +36,19 @@ struct GatherParallelDims {
std::vector<int64> index_parallel_in_dim;
};
// Returns true if the lhs sharding is preferable over the rhs sharding.
// The most specific sharding is tile maximal followed by single device tile
// maximal and finally replicated. This order aims to primarily reduce memory
// usage and secondly reduce total compute.
// Note: This does NOT provide a total ordering as we can have 2 different
// sharding with same preference level.
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs);
// Tries to refine `to_merge` by combining with `old`. Returns if the final
// `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding);
// Given a map<device, occurrence_count>, selects the device with higher
// occurrence count (if any). If top_count in not nullptr, it will receive the
// count of the dominant device returned.
@ -137,7 +150,8 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
// Returns an output sharding of gather by passing through the data operand's
// sharding.
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
const HloSharding& data_operand_sharding, const HloInstruction& hlo);
const HloSharding& data_operand_sharding, const HloInstruction& hlo,
const Shape& output_shape, const Shape& operand_shape);
// Returns a data operand sharding of gather by passing through the output's
// sharding.
@ -173,7 +187,7 @@ std::vector<int64> DevicesForSharding(
// Returns a sharding that replicates data across devices along the given
// dimensions in the original sharding.
HloSharding PartiallyReplicateTiledShardingOnDims(
const HloSharding& sharding, const std::vector<int64>& dims_to_replicate);
const HloSharding& sharding, absl::Span<const int64> dims_to_replicate);
// Returns a sharding the removes given tile dimensions.
//

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_split.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -65,174 +66,6 @@ bool IsSpatiallyPartitioned(const HloInstruction* hlo) {
return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding());
}
// Returns true if the lhs sharding is preferable over the rhs sharding.
// The most specific sharding is tile maximal followed by single device tile
// maximal and finally replicated. This order aims to primarily reduce memory
// usage and secondly reduce total compute.
// Note: This does NOT provide a total ordering as we can have 2 different
// sharding with same preference level.
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
CHECK_EQ(lhs.IsTuple(), rhs.IsTuple());
if (lhs.IsTuple()) {
// For tuples we consider lhs to have a better sharding if none of the
// elements are worse and at least one element is better then in rhs
// sharding.
const auto& lhs_shardings = lhs.tuple_elements();
const auto& rhs_shardings = rhs.tuple_elements();
CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
bool is_better = false;
for (int64 i = 0; i < lhs_shardings.size(); ++i) {
if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
return false;
}
if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
is_better = true;
}
}
return is_better;
}
if (!rhs.IsTileMaximal()) {
return lhs.NumTiles() > rhs.NumTiles();
} else if (!rhs.IsReplicated()) {
// If we are not replicated then only tiled (not tile maximal) shardings
// can improve us.
return !lhs.IsTileMaximal();
} else {
// If we are replicated then any non-replicated sharding can improve us.
return !lhs.IsReplicated();
}
}
// Tries to refine `to_merge` by combining with `old`. Returns if the final
// `to_merge` is more specific than `old`.
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
bool may_combine_partial_sharding) {
if (old.IsTuple()) {
CHECK(to_merge->IsTuple());
bool changed = false;
for (int64 i = 0; i < old.tuple_elements().size(); ++i) {
changed |=
MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
may_combine_partial_sharding);
}
return changed;
}
if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
!to_merge->ReplicateOnLastTileDim() ||
old.tile_assignment().num_elements() !=
to_merge->tile_assignment().num_elements()) {
return IsShardingMoreSpecific(*to_merge, old);
}
// Combine the tile dimension sizes from new and old.
int64 num_devices = old.tile_assignment().num_elements();
std::vector<int64> new_tile_dims;
bool compatible = true;
new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions());
for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) {
int64 new_dim = to_merge->tile_assignment().dim(i);
int64 old_dim = old.tile_assignment().dim(i);
if (new_dim == 1) {
new_tile_dims.push_back(old_dim);
} else if (old_dim == 1) {
new_tile_dims.push_back(new_dim);
} else if (new_dim == old_dim) {
new_tile_dims.push_back(new_dim);
} else {
compatible = false;
break;
}
}
int64 replication = num_devices / Product(new_tile_dims);
if (!compatible || num_devices % Product(new_tile_dims) != 0 ||
replication >= old.tile_assignment().dimensions().back()) {
return IsShardingMoreSpecific(*to_merge, old);
}
new_tile_dims.push_back(replication);
Array<int64> new_tile(new_tile_dims);
// Maps from replication group ID to sorted members.
absl::flat_hash_map<int64, std::set<int64>> old_group_members;
absl::flat_hash_map<int64, std::set<int64>> new_group_members;
auto get_group_index = [&](absl::Span<const int64> tile_indices,
const HloSharding& sharding) {
int64 group_id = 0;
for (int64 i = 0; i < tile_indices.size() - 1; ++i) {
group_id *= to_merge->tile_assignment().dim(i);
group_id += tile_indices[i];
}
return group_id;
};
old.tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
old_group_members[get_group_index(indices, old)].insert(device);
});
to_merge->tile_assignment().Each(
[&](absl::Span<const int64> indices, int64 device) {
new_group_members[get_group_index(indices, *to_merge)].insert(device);
});
// Try to find the intersection of old and new replication groups, in
// order to determine the merged tile assignment.
new_tile.Each([&](absl::Span<const int64> indices, int64* device) {
if (!compatible) {
return;
}
std::vector<int64> old_index(indices.begin(), indices.end());
std::vector<int64> new_index = old_index;
for (int64 i = 0; i < indices.size() - 1; ++i) {
if (old.tile_assignment().dim(i) == 1) {
old_index[i] = 0;
}
if (to_merge->tile_assignment().dim(i) == 1) {
new_index[i] = 0;
}
}
int64 old_group_id = get_group_index(old_index, old);
int64 new_group_id = get_group_index(new_index, *to_merge);
if (old_group_members[old_group_id].empty() ||
new_group_members[new_group_id].empty()) {
compatible = false;
return;
}
int64 smallest_old = *old_group_members[old_group_id].begin();
int64 smallest_new = *new_group_members[new_group_id].begin();
if (smallest_old < smallest_new) {
if (old_group_members[old_group_id].count(smallest_new) == 0) {
compatible = false;
return;
}
*device = smallest_new;
} else {
if (new_group_members[new_group_id].count(smallest_old) == 0) {
compatible = false;
return;
}
*device = smallest_old;
}
old_group_members[old_group_id].erase(*device);
new_group_members[new_group_id].erase(*device);
});
if (compatible) {
std::vector<OpMetadata> merged_metadata(std::move(to_merge->metadata()));
merged_metadata.reserve(merged_metadata.size() + old.metadata().size());
const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
protobuf_util::ProtobufEqualsWrapper>
metadata_set(merged_metadata.begin(), merged_metadata.end());
absl::c_copy_if(old.metadata(), std::back_inserter(merged_metadata),
[&metadata_set](const OpMetadata& data) {
return !ContainsKey(metadata_set, data);
});
if (replication == 1) {
new_tile_dims.pop_back();
new_tile.Reshape(new_tile_dims);
*to_merge = HloSharding::Tile(new_tile, merged_metadata);
} else {
*to_merge = HloSharding::PartialTile(new_tile, merged_metadata);
}
return true;
}
return IsShardingMoreSpecific(*to_merge, old);
}
// Updates the sharding of the specified instruction with the specified sharding
// if it is better than the current one and returns true if a new sharding have
// been applied. If may_combine_partial_sharding is true, this may combine the
@ -251,7 +84,7 @@ bool MaybeImproveInstructionSharding(HloSharding sharding,
return true;
}
int64 sharding_tiles = sharding.NumTiles();
if (MergeSharding(instruction->sharding(), &sharding,
if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding)) {
// Override existing tiled sharding only when the new sharding is compatible
// with the existing one. This avoids unexpected resharding when `sharding`
@ -385,8 +218,8 @@ const HloInstruction* PickRepresentativeOperand(
for (const HloInstruction* operand : instruction->operands()) {
if (operand->has_sharding() &&
(best_operand == nullptr ||
IsShardingMoreSpecific(operand->sharding(),
best_operand->sharding()))) {
hlo_sharding_util::IsShardingMoreSpecific(
operand->sharding(), best_operand->sharding()))) {
best_operand = operand;
}
}
@ -823,7 +656,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
if (operand->shape().IsTuple()) {
for (int64 i = 0, e = ShapeUtil::GetLeafCount(operand->shape());
i < e; ++i) {
if (IsShardingMoreSpecific(
if (hlo_sharding_util::IsShardingMoreSpecific(
operand->sharding().tuple_elements()[i],
sub_shardings[sub_sharding_index + i])) {
sub_shardings[sub_sharding_index + i] =
@ -831,8 +664,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
}
} else {
if (IsShardingMoreSpecific(operand->sharding(),
sub_shardings[sub_sharding_index])) {
if (hlo_sharding_util::IsShardingMoreSpecific(
operand->sharding(), sub_shardings[sub_sharding_index])) {
sub_shardings[sub_sharding_index] = operand->sharding();
}
}
@ -1120,6 +953,12 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
case HloOpcode::kGather: {
bool changed = false;
if (IsSpatiallyPartitioned(instruction->operand(1))) {
HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
instruction->operand(1)->sharding(), instruction);
changed |= MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
if (is_spmd) {
auto gather_parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
@ -1127,23 +966,26 @@ bool InferShardingFromOperands(HloInstruction* instruction,
changed |= InferGatherParallelShardingFromOperands(
instruction, *gather_parallel_dims, may_combine_partial_sharding);
}
if (IsSpatiallyPartitioned(instruction->operand(0))) {
absl::Span<const int64> operand_parallel_dims;
if (gather_parallel_dims) {
operand_parallel_dims = absl::MakeConstSpan(
gather_parallel_dims->operand_parallel_dims);
}
if (IsSpatiallyPartitioned(instruction->operand(1))) {
HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
instruction->operand(1)->sharding(), instruction);
changed |= MaybeImproveInstructionSharding(
std::move(new_sharding), instruction, may_combine_partial_sharding);
}
if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) {
HloSharding filtered_operand_sharding =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
instruction->operand(0)->sharding(), operand_parallel_dims);
auto maybe_from_data =
hlo_sharding_util::GatherOutputShardingFromDataOperand(
instruction->operand(0)->sharding(), *instruction);
filtered_operand_sharding, *instruction, instruction->shape(),
instruction->operand(0)->shape());
if (maybe_from_data) {
changed |= MaybeImproveInstructionSharding(
std::move(*maybe_from_data), instruction,
may_combine_partial_sharding);
}
}
}
return changed;
}
case HloOpcode::kScatter: {
@ -1177,7 +1019,7 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
auto sharding = instruction->operand(0)->sharding();
if (instruction->has_sharding()) {
MergeSharding(instruction->sharding(), &sharding,
hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
may_combine_partial_sharding);
}
return MaybeImproveInstructionSharding(std::move(sharding), instruction,
@ -1274,7 +1116,7 @@ HloSharding InferDotOperandSharding(
*hlo_sharding_util::TransposeShardingWithCollapsedDims(
other_operand_dims_replicated, other_to_operand_dims,
operand_to_other_dims);
if (MergeSharding(sharding, &sharding_from_other,
if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other,
may_combine_partial_sharding)) {
sharding = std::move(sharding_from_other);
}

View File

@ -4573,5 +4573,164 @@ ENTRY %module {
}
}
TEST_P(ParameterizedMetadataTest, GatherParallelAndPassthroughMerged) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY %module {
%arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
%arg1 = s32[4]{0} parameter(1)
%input = s32[4,8,2,2]{3,2,1,0} copy(%arg0),
sharding={devices=[2,1,2,1]0,1,4,5 metadata={op_name="a"}}
%seq_size = s32[4]{0} copy(s32[4]{0} %arg1)
%seq_b = s32[1,4,8]{2,1,0} broadcast(s32[4]{0} %seq_size
), dimensions={1}
%iota.11 = s32[1,4,8]{2,1,0} iota(), iota_dimension=1
%concatenate = s32[2,4,8]{2,1,0} concatenate(s32[1,4,8]{2,1,0} %iota.11,
s32[1,4,8]{2,1,0} %seq_b), dimensions={0}
%gather = s32[4,8,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input,
s32[2,4,8]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[4,8,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
.Run(module.get()));
EXPECT_TRUE(changed);
const HloInstruction* input = FindInstruction(module.get(), "input");
ASSERT_NE(input, nullptr);
EXPECT_THAT(input, op::Sharding("{devices=[2,1,2,1]0,1,4,5 }"));
const HloInstruction* concatenate =
FindInstruction(module.get(), "concatenate");
ASSERT_NE(concatenate, nullptr);
EXPECT_THAT(
concatenate,
op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}"));
const HloInstruction* gather = FindInstruction(module.get(), "gather");
ASSERT_NE(gather, nullptr);
EXPECT_THAT(gather, op::Sharding("{devices=[2,1,2,1]0,1,4,5}"));
for (const HloInstruction* instruction : {input, gather}) {
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(instruction->sharding(),
ShardingMetadata({CreateMetadata("a")}));
} else {
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
}
}
}
TEST_P(ParameterizedMetadataTest, GatherParallelAndTrivialMerged) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY %module {
%arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
%arg1 = s32[4]{0} parameter(1)
%input = s32[4,8,2,2]{3,2,1,0} copy(%arg0),
sharding={devices=[2,2,1,1]0,1,4,5 metadata={op_name="a"}}
%seq_size = s32[4]{0} copy(s32[4]{0} %arg1)
%seq_b = s32[1,4,1]{2,1,0} broadcast(s32[4]{0} %seq_size), dimensions={1}
%iota.11 = s32[1,4,1]{2,1,0} iota(), iota_dimension=1
%concatenate = s32[2,4,1]{2,1,0} concatenate(s32[1,4,1]{2,1,0} %iota.11,
s32[1,4,1]{2,1,0} %seq_b), dimensions={0}
%gather = s32[4,1,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input,
s32[2,4,1]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[4,1,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
.Run(module.get()));
EXPECT_TRUE(changed);
const HloInstruction* input = FindInstruction(module.get(), "input");
ASSERT_NE(input, nullptr);
EXPECT_THAT(input, op::Sharding("{devices=[2,2,1,1]0,1,4,5}"));
const HloInstruction* concatenate =
FindInstruction(module.get(), "concatenate");
ASSERT_NE(concatenate, nullptr);
EXPECT_THAT(
concatenate,
op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}"));
const HloInstruction* gather = FindInstruction(module.get(), "gather");
ASSERT_NE(gather, nullptr);
EXPECT_THAT(
gather,
op::Sharding("{devices=[2,1,1,1,2]0,1,4,5 last_tile_dim_replicate}"));
for (const HloInstruction* instruction : {input, gather}) {
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(instruction->sharding(),
ShardingMetadata({CreateMetadata("a")}));
} else {
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
}
}
}
TEST_P(ParameterizedMetadataTest,
GatherParallelAndPassthroughMergedBackwardPass) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY %module {
%arg0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
%arg1 = s32[4]{0} parameter(1)
%input = s32[4,8,2,2]{3,2,1,0} copy(%arg0)
%seq_size = s32[4]{0} copy(s32[4]{0} %arg1)
%seq_b = s32[1,4,8]{2,1,0} broadcast(s32[4]{0} %seq_size
), dimensions={1}
%iota.11 = s32[1,4,8]{2,1,0} iota(), iota_dimension=1
%concatenate = s32[2,4,8]{2,1,0} concatenate(s32[1,4,8]{2,1,0} %iota.11,
s32[1,4,8]{2,1,0} %seq_b), dimensions={0}
%gather = s32[4,8,2,2]{3,2,1,0} gather(s32[4,8,2,2]{3,2,1,0} %input,
s32[2,4,8]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2},
sharding={devices=[2,1,2,1]0,1,4,5 metadata={op_name="a"}}
ROOT %copy = s32[4,8,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/true, GetParam().propagate_metadata)
.Run(module.get()));
EXPECT_TRUE(changed);
const HloInstruction* input = FindInstruction(module.get(), "input");
ASSERT_NE(input, nullptr);
EXPECT_THAT(input, op::Sharding("{devices=[2,1,2,1]0,1,4,5 }"));
const HloInstruction* concatenate =
FindInstruction(module.get(), "concatenate");
ASSERT_NE(concatenate, nullptr);
EXPECT_THAT(
concatenate,
op::Sharding("{devices=[1,2,1,2]0,1,4,5 last_tile_dim_replicate}"));
const HloInstruction* gather = FindInstruction(module.get(), "gather");
ASSERT_NE(gather, nullptr);
EXPECT_THAT(gather, op::Sharding("{devices=[2,1,2,1]0,1,4,5}"));
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(gather->sharding(), ShardingMetadata({CreateMetadata("a")}));
} else {
EXPECT_THAT(gather->sharding(), ShardingMetadata({}));
}
}
} // namespace
} // namespace xla

View File

@ -55,6 +55,7 @@ cc_library(
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/core:lib",
"//tensorflow/core/platform:numbers",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace spmd {
@ -118,6 +121,386 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
return {broadcast_min, broadcast_max};
}
// Function that tries to perform recursive partitioning of Gather.
StatusOr<HloInstruction*> PartitionGather(const HloGatherInstruction* gather,
PartitionedHlo& operand,
PartitionedHlo& indices,
const Shape& output_shape,
const HloSharding& output_sharding,
absl::Span<const int64> batch_dims,
SpmdPartitioningVisitor* visitor);
// Perform partitioning of Gather when the indices are partitioned and
// the operand is replicated.
StatusOr<HloInstruction*> PartitionIndexOnlyPartition(
const HloGatherInstruction* gather, absl::Span<const int64> batch_dims,
PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) {
GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
if (operand.sharding().IsTileMaximal()) {
if (!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
1)) {
auto replicated_operand = operand.Replicate();
TF_ASSIGN_OR_RETURN(
Shape partitioned_output_shape,
ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(),
indices.hlo()->shape(), dnums,
gather->gather_slice_sizes()));
auto pgather = b->AddInstruction(gather->CloneWithNewOperands(
partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()}));
std::vector<int64> output_dim_to_index_dim(pgather->shape().rank(), -1);
std::vector<int64> index_dim_to_output_dim(indices.base_shape().rank(),
-1);
for (int64 i = 0; i < batch_dims.size(); ++i) {
int64 indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1;
output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
}
auto pgather_sharding =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
indices.sharding(), index_dim_to_output_dim,
output_dim_to_index_dim);
CHECK(pgather_sharding.has_value());
pgather->set_sharding(*pgather_sharding);
VLOG(5) << "[Gather partitioning]: Partitioned as index only";
return pgather;
}
}
return nullptr;
}
// Perform partitioning of Gather when the operand is split in a offset
// dimension that is passed through (slice size is the same size of the operand
// dimension).
StatusOr<HloInstruction*> ParititonPassthroughOperand(
const HloGatherInstruction* gather, Shape output_shape,
const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
PartitionedHlo& operand, PartitionedHlo& indices,
SpmdPartitioningVisitor* visitor) {
SpmdBuilder* b = visitor->builder();
GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
if (auto maybe_passthrough =
hlo_sharding_util::GatherOutputShardingFromDataOperand(
operand.sharding(), *gather, output_shape,
operand.base_shape())) {
indices = indices.Reshard(HloSharding::Replicate());
auto pshape = MakePartitionedShape(output_shape, *maybe_passthrough);
std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
gather->gather_slice_sizes().end());
for (int64 i = 0; i < pslice_sizes.size(); ++i) {
if (operand.sharding().tile_assignment().dim(i) > 1) {
pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
}
}
auto pgather = b->AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
gather->indices_are_sorted()));
pgather->set_sharding(*maybe_passthrough);
VLOG(5) << "[Gather partitioning]: Partitioned as operand passthrough "
"offset_dim";
return PartitionedHlo(pgather, output_shape, operand.state())
.Reshard(output_sharding)
.hlo();
}
return nullptr;
}
// Partition a Gather when its sliced in a dimension in the operand that is
// trivially sliced (sliced with slice size of 1).
StatusOr<HloInstruction*> ParititonTrivialIndexedOperandDimension(
const HloGatherInstruction* gather, Shape output_shape,
const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
PartitionedHlo& operand, PartitionedHlo& indices,
SpmdPartitioningVisitor* visitor) {
SpmdBuilder* b = visitor->builder();
GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, start_index_map, gather->gather_slice_sizes()) &&
ShapeSizeInBytes(output_shape) < ShapeSizeInBytes(operand.base_shape())) {
indices = indices.Reshard(HloSharding::Replicate());
// Now the operand is partitioned in trivial slice dimensions, and the
// indices are replicated. We execute a gather on partitioned operand,
// with full number of indices, where out-of-bounds indices are clamped,
// and masked out with 0 in the result; then we use all-reduce to combine
// results. Although gather will not get faster, we avoided the need to
// replicate the operand.
HloInstruction* indices_min;
HloInstruction* indices_max;
std::tie(indices_min, indices_max) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, operand.state().partition_id, start_index_map,
dnums.index_vector_dim(), b);
// Clamp the indices.
auto adjusted_indices = b->AddInstruction(
HloInstruction::CreateTernary(indices.base_shape(), HloOpcode::kClamp,
indices_min, indices.hlo(), indices_max));
// Adjust the indices by subtracting the offset.
adjusted_indices = b->AddInstruction(
HloInstruction::CreateBinary(indices.base_shape(), HloOpcode::kSubtract,
adjusted_indices, indices_min));
// Gather on adjusted indices.
auto pgather = b->AddInstruction(HloInstruction::CreateGather(
output_shape, operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
// Mask out invalid results.
auto filter = b->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED), indices.hlo(),
indices_min, ComparisonDirection::kLt));
filter = b->AddInstruction(HloInstruction::CreateBinary(
filter->shape(), HloOpcode::kOr, filter,
b->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_max, ComparisonDirection::kGt))));
if (dnums.index_vector_dim() < indices.base_shape().rank()) {
std::vector<int64> reduced_filter_dims;
for (int64 i = 0; i < filter->shape().rank(); ++i) {
if (i != dnums.index_vector_dim()) {
reduced_filter_dims.push_back(filter->shape().dimensions(i));
}
}
filter = b->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
CreateR0WithType(PRED, false, b), {dnums.index_vector_dim()},
MakeBinaryAdd(PRED, indices.state().module)));
}
std::vector<int64> batch_dims;
for (int64 i = 0; i < pgather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
auto broadcast_filter = b->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
batch_dims));
auto filtered = b->AddInstruction(HloInstruction::CreateTernary(
pgather->shape(), HloOpcode::kSelect, broadcast_filter,
CreateZero(pgather->shape(), b), pgather));
// Combine from different partitions.
absl::InlinedVector<int64, 1> replicated_dim;
if (operand.sharding().ReplicateOnLastTileDim()) {
replicated_dim.push_back(
operand.sharding().tile_assignment().num_dimensions() - 1);
}
auto sharding_grouped =
GroupShardingOnDims(operand.sharding(), replicated_dim);
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
operand.state(), sharding_grouped.device_groups, b);
auto collective_ops_creator =
per_group_partitioner_state.collective_ops_creator;
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
b, filtered,
MakeBinaryAdd(filtered->shape().element_type(),
per_group_partitioner_state.module),
{}, visitor->NewChannel());
VLOG(5) << "[Gather partitioning]: Partitioned as trivial operand "
"batch_dim slice";
ar->set_sharding(HloSharding::Replicate());
return PartitionedHlo(ar, output_shape, operand.state())
.Reshard(output_sharding)
.hlo();
}
return nullptr;
}
// Partition a gather over a indices dimensions that are cosidered parallel
// (which means that the indices access the operand in a monotonically
// increasing way across the respective operand dimension referenced by the
// index).
StatusOr<HloInstruction*> PartitionIndexParallelDimensions(
const HloGatherInstruction* gather, Shape output_shape,
const HloSharding& output_sharding, absl::Span<const int64> batch_dims,
PartitionedHlo& operand, PartitionedHlo& indices,
SpmdPartitioningVisitor* visitor) {
absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2>
top_level_sharding_to_reset;
auto cleaner = MakeCleanup([&top_level_sharding_to_reset] {
for (auto& to_reset : top_level_sharding_to_reset) {
to_reset.first->set_sharding(to_reset.second);
}
});
SpmdBuilder* b = visitor->builder();
GatherDimensionNumbers dnums = gather->gather_dimension_numbers();
// Handle the case where operand is tile maximal. In this case we check if
// the index is not TileMaximal and in this case we use the index sharding
// to drive the output sharding.
if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(*gather)) {
if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
*operand.hlo(), *indices.hlo(), *parallel_dims)) {
auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
auto output_parallel_dims =
hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims);
HloSharding indices_sharding = gather_sharding->indices_sharding;
HloSharding operand_sharding = gather_sharding->operand_sharding;
GroupedSharding grouped_indices =
GroupShardingOnDims(indices_sharding, indices_parallel_dims);
GroupedSharding grouped_operand =
GroupShardingOnDims(operand_sharding, operand_parallel_dims);
int index_dim = dnums.index_vector_dim();
// Construct the required sharding for the new gather we are gonna form.
absl::InlinedVector<int64, 4> output_tiling(
output_shape.dimensions_size(), 1);
for (int i = 0, num_output_parallel_dims = output_parallel_dims.size();
i < num_output_parallel_dims; ++i) {
int output_idx = output_parallel_dims[i];
int indices_idx = indices_parallel_dims[i];
output_tiling[output_idx] =
indices_sharding.tile_assignment().dim(indices_idx);
}
operand = operand.Reshard(operand_sharding);
indices = indices.Reshard(indices_sharding);
if (indices_sharding.ReplicateOnLastTileDim()) {
output_tiling.push_back(
indices_sharding.tile_assignment().dimensions().back());
}
Array<int64> output_tile_assignment = indices_sharding.tile_assignment();
output_tile_assignment.Reshape(output_tiling);
// New gather tiling.
HloSharding gather_output_sharding =
indices_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment)
: HloSharding::Tile(output_tile_assignment);
// Shape of the partitioned gather
Shape pshape = MakePartitionedShape(output_shape, gather_output_sharding);
// Construct the offsets for the operand sharding to be used to adjust
// the indices. Because we know the only dimensions partitioned are the
// parallel ones and because the partitioning is the same across indices
// and operands we can apply the offsets on the operands on the indices.
std::vector<HloInstruction*> operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand_sharding, operand.state().partition_id,
b, operand_parallel_dims);
absl::InlinedVector<HloInstruction*, 4> index_offsets;
for (int start_idx = 0; start_idx < dnums.start_index_map_size();
++start_idx) {
HloInstruction* index_offset =
indices.base_shape().dimensions_size() > index_dim
? b->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1}),
operand_offsets[dnums.start_index_map(start_idx)]))
: operand_offsets[dnums.start_index_map(start_idx)];
index_offsets.push_back(index_offset);
}
HloInstruction* adjusted_indices = nullptr;
if (indices.base_shape().dimensions_size() > index_dim) {
// Concatenate the offsets for the parallel dimensions to subtract.
adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(S32,
{indices.base_shape().dimensions(index_dim)}),
index_offsets, 0));
} else {
CHECK_EQ(index_offsets.size(), 1);
adjusted_indices = index_offsets[0];
}
if (indices.hlo()->shape().element_type() != PrimitiveType::S32) {
adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(adjusted_indices->shape(),
indices.hlo()->shape().element_type()),
adjusted_indices));
}
if (adjusted_indices->shape().rank() == 0) {
adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {}));
} else {
adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {index_dim}));
}
// Adjust indices by subtracting the offsets based on the partition id.
adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary(
indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
adjusted_indices));
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
operand.state(), grouped_operand.device_groups, b);
top_level_sharding_to_reset.emplace_back(operand.hlo(),
operand.sharding());
adjusted_indices->set_sharding(grouped_indices.sharding);
operand.hlo()->set_sharding(grouped_operand.sharding);
VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim";
HloInstruction* pgather;
if (operand_sharding.NumTiles() ==
operand_sharding.NumTiles(operand_parallel_dims) &&
indices_sharding.NumTiles() ==
indices_sharding.NumTiles(indices_parallel_dims)) {
pgather = b->AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
} else {
PartitionedHlo per_group_operand(
operand.hlo(),
GetPerGroupBaseShape(grouped_operand, operand.base_shape()),
per_group_partitioner_state);
PartitionedHlo per_group_indices(
adjusted_indices,
GetPerGroupBaseShape(grouped_indices, indices.base_shape()),
per_group_partitioner_state);
GroupedSharding grouped_output =
GroupShardingOnDims(gather_output_sharding, output_parallel_dims);
TF_ASSIGN_OR_RETURN(pgather, PartitionGather(gather, per_group_operand,
per_group_indices, pshape,
grouped_output.sharding,
batch_dims, visitor));
}
if (pgather) {
pgather->set_sharding(gather_output_sharding);
return PartitionedHlo(pgather, output_shape, operand.state())
.Reshard(output_sharding)
.hlo();
}
}
}
return nullptr;
}
StatusOr<HloInstruction*> PartitionGather(const HloGatherInstruction* gather,
PartitionedHlo& operand,
PartitionedHlo& indices,
const Shape& output_shape,
const HloSharding& output_sharding,
absl::Span<const int64> batch_dims,
SpmdPartitioningVisitor* visitor) {
absl::InlinedVector<std::pair<HloInstruction*, HloSharding>, 2>
top_level_sharding_to_reset;
auto cleaner = MakeCleanup([&top_level_sharding_to_reset] {
for (auto& to_reset : top_level_sharding_to_reset) {
to_reset.first->set_sharding(to_reset.second);
}
});
HloInstruction* partitioned_gather;
// Check if we identify some of the dimensions of the gather as parallel and
// if we have sharded the operand and indices across those dimensions.
// If that's the case then we can partition the gather across such dimensions
// by adjusting the offsets.
TF_ASSIGN_OR_RETURN(
partitioned_gather,
PartitionIndexParallelDimensions(gather, output_shape, output_sharding,
batch_dims, operand, indices, visitor));
if (partitioned_gather) {
return partitioned_gather;
}
// Pefrorm passthrough and trivial slice partitioning of the Gather.
if (!operand.sharding().IsTileMaximal()) {
TF_ASSIGN_OR_RETURN(
partitioned_gather,
ParititonPassthroughOperand(gather, output_shape, output_sharding,
batch_dims, operand, indices, visitor));
if (partitioned_gather) {
return partitioned_gather;
}
TF_ASSIGN_OR_RETURN(partitioned_gather,
ParititonTrivialIndexedOperandDimension(
gather, output_shape, output_sharding, batch_dims,
operand, indices, visitor));
if (partitioned_gather) {
return partitioned_gather;
}
}
return nullptr;
}
} // namespace
Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
@ -288,266 +671,33 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
const auto& dnums = gather->gather_dimension_numbers();
auto operand = GetPartitionedHlo(gather->operand(0));
auto indices = GetPartitionedHlo(gather->operand(1));
std::vector<int64> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
std::vector<int64> batch_dims;
for (int64 i = 0; i < gather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
// Check if we identify some of the dimensions of the gather as parallel and
// if we have sharded the operand and indices across those dimensions.
// If that's the case then we can partition the gather across such dimensions
// by adjusting the offsets.
if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(*hlo)) {
if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
*operand.hlo(), *indices.hlo(), *parallel_dims)) {
auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
auto output_parallel_dims =
hlo_sharding_util::GatherParallelOutputDims(*hlo, *parallel_dims);
HloSharding indices_sharding = gather_sharding->indices_sharding;
HloSharding operand_sharding = gather_sharding->operand_sharding;
if (operand_sharding.NumTiles() ==
operand_sharding.NumTiles(operand_parallel_dims) &&
indices_sharding.NumTiles() ==
indices_sharding.NumTiles(indices_parallel_dims)) {
int index_dim = dnums.index_vector_dim();
// Construct the required sharding for the new gather we are gonna form.
absl::InlinedVector<int64, 4> output_tiling(
hlo->shape().dimensions_size(), 1);
for (int i = 0, num_output_parallel_dims = output_parallel_dims.size();
i < num_output_parallel_dims; ++i) {
int output_idx = output_parallel_dims[i];
int indices_idx = indices_parallel_dims[i];
output_tiling[output_idx] =
indices_sharding.tile_assignment().dim(indices_idx);
}
operand = operand.Reshard(operand_sharding);
indices = indices.Reshard(indices_sharding);
if (indices_sharding.ReplicateOnLastTileDim()) {
output_tiling.push_back(
indices_sharding.tile_assignment().dimensions().back());
}
Array<int64> output_tile_assignment =
indices_sharding.tile_assignment();
output_tile_assignment.Reshape(output_tiling);
// New gather tiling.
HloSharding output_sharding =
indices_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment)
: HloSharding::Tile(output_tile_assignment);
// Shape of the partitioned gather
Shape pshape = MakePartitionedShape(gather->shape(), output_sharding);
// Construct the offsets for the operand sharding to be used to adjust
// the indices. Because we know the only dimensions partitioned are the
// parallel ones and because the partitioning is the same across indices
// and operands we can apply the offsets on the operands on the indices.
std::vector<HloInstruction*> operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand_sharding, partition_id_, &b_);
absl::InlinedVector<HloInstruction*, 4> index_offsets;
for (int start_idx = 0; start_idx < dnums.start_index_map_size();
++start_idx) {
HloInstruction* index_offset =
indices.base_shape().dimensions_size() > index_dim
? b_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1}),
operand_offsets[dnums.start_index_map(start_idx)]))
: operand_offsets[dnums.start_index_map(start_idx)];
index_offsets.push_back(index_offset);
}
HloInstruction* adjusted_indices = nullptr;
if (indices.base_shape().dimensions_size() > index_dim) {
// Concatenate the offsets for the parallel dimensions to subtract.
adjusted_indices =
b_.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(
S32, {indices.base_shape().dimensions(index_dim)}),
index_offsets, 0));
} else {
CHECK_EQ(index_offsets.size(), 1);
adjusted_indices = index_offsets[0];
}
if (indices.hlo()->shape().element_type() != PrimitiveType::S32) {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
adjusted_indices->shape(),
indices.hlo()->shape().element_type()),
adjusted_indices));
}
if (adjusted_indices->shape().rank() == 0) {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {}));
} else {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {index_dim}));
}
// Adjust indices by subtracting the offsets based on the partition id.
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
adjusted_indices));
HloInstruction* pgather =
b_.AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
pgather->set_sharding(output_sharding);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
HloInstruction* pgather;
TF_ASSIGN_OR_RETURN(pgather,
PartitionGather(gather, operand, indices, gather->shape(),
gather->sharding(),
absl::MakeConstSpan(batch_dims), this));
if (pgather) {
SetPartitionedHlo(gather, [pgather] { return pgather; });
return Status::OK();
}
}
}
if (operand.sharding().IsTileMaximal()) {
if (!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) ==
1)) {
auto replicated_operand = operand.Replicate();
TF_ASSIGN_OR_RETURN(
Shape partitioned_output_shape,
ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(),
indices.hlo()->shape(), dnums,
gather->gather_slice_sizes()));
auto pgather = b_.AddInstruction(gather->CloneWithNewOperands(
partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()}));
std::vector<int64> output_dim_to_index_dim(pgather->shape().rank(), -1);
std::vector<int64> index_dim_to_output_dim(indices.base_shape().rank(),
-1);
for (int64 i = 0; i < batch_dims.size(); ++i) {
int64 indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1;
output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim;
index_dim_to_output_dim[indices_batch_dim] = batch_dims[i];
}
auto pgather_sharding =
hlo_sharding_util::TransposeShardingWithCollapsedDims(
indices.sharding(), index_dim_to_output_dim,
output_dim_to_index_dim);
CHECK(pgather_sharding.has_value());
pgather->set_sharding(*pgather_sharding);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
// Handle the case where operand is tile maximal. In this case we check if
// the index is not TileMaximal and in this case we use the index sharding
// to drive the output sharding.
TF_ASSIGN_OR_RETURN(pgather, PartitionIndexOnlyPartition(
gather, absl::MakeConstSpan(batch_dims),
operand, indices, &b_));
if (pgather) {
SetPartitionedHlo(gather, [pgather] { return pgather; });
return Status::OK();
}
} else {
auto maybe_passthrough =
hlo_sharding_util::GatherOutputShardingFromDataOperand(
operand.sharding(), *hlo);
if (maybe_passthrough.has_value()) {
indices = indices.Reshard(HloSharding::Replicate());
auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough);
std::vector<int64> pslice_sizes(gather->gather_slice_sizes().begin(),
gather->gather_slice_sizes().end());
for (int64 i = 0; i < pslice_sizes.size(); ++i) {
if (operand.sharding().tile_assignment().dim(i) > 1) {
pslice_sizes[i] = operand.hlo()->shape().dimensions(i);
}
}
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes,
gather->indices_are_sorted()));
pgather->set_sharding(*maybe_passthrough);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
operand, start_index_map, gather->gather_slice_sizes()) &&
ShapeSizeInBytes(gather->shape()) <
ShapeSizeInBytes(gather->operand(0)->shape())) {
indices = indices.Reshard(HloSharding::Replicate());
// Now the operand is partitioned in trivial slice dimensions, and the
// indices are replicated. We execute a gather on partitioned operand,
// with full number of indices, where out-of-bounds indices are clamped,
// and masked out with 0 in the result; then we use all-reduce to combine
// results. Although gather will not get faster, we avoided the need to
// replicate the operand.
HloInstruction* indices_min;
HloInstruction* indices_max;
std::tie(indices_min, indices_max) =
IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
operand, indices, partition_id_, start_index_map,
dnums.index_vector_dim(), &b_);
// Clamp the indices.
auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary(
indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(),
indices_max));
// Adjust the indices by subtracting the offset.
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.base_shape(), HloOpcode::kSubtract, adjusted_indices,
indices_min));
// Gather on adjusted indices.
auto pgather = b_.AddInstruction(HloInstruction::CreateGather(
gather->shape(), operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
// Mask out invalid results.
auto filter = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_min, ComparisonDirection::kLt));
filter = b_.AddInstruction(HloInstruction::CreateBinary(
filter->shape(), HloOpcode::kOr, filter,
b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(indices.base_shape(), PRED),
indices.hlo(), indices_max, ComparisonDirection::kGt))));
if (dnums.index_vector_dim() < indices.base_shape().rank()) {
std::vector<int64> reduced_filter_dims;
for (int64 i = 0; i < filter->shape().rank(); ++i) {
if (i != dnums.index_vector_dim()) {
reduced_filter_dims.push_back(filter->shape().dimensions(i));
}
}
filter = b_.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter,
CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()},
MakeBinaryAdd(PRED, module_)));
}
std::vector<int64> batch_dims;
for (int64 i = 0; i < pgather->shape().rank(); ++i) {
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
batch_dims.push_back(i);
}
}
auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter,
batch_dims));
auto filtered = b_.AddInstruction(HloInstruction::CreateTernary(
pgather->shape(), HloOpcode::kSelect, broadcast_filter,
CreateZero(pgather->shape(), &b_), pgather));
// Combine from different partitions.
auto collective_ops_creator = collective_ops_creator_;
if (operand.sharding().ReplicateOnLastTileDim()) {
auto sharding_grouped = GroupShardingOnDims(
operand.sharding(),
{operand.sharding().tile_assignment().num_dimensions() - 1});
auto per_group_partitioner_state = CreatePerGroupPartitioningState(
operand.state(), sharding_grouped.device_groups, &b_);
collective_ops_creator =
per_group_partitioner_state.collective_ops_creator;
}
auto ar = collective_ops_creator.create_cross_partition_all_reduce(
&b_, filtered,
MakeBinaryAdd(filtered->shape().element_type(), module_), {},
NewChannel());
ar->set_sharding(HloSharding::Replicate());
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(ar, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
return DefaultAction(hlo);
return DefaultAction(gather);
}
} // namespace spmd

View File

@ -6913,6 +6913,69 @@ ENTRY %module {
op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)));
}
TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0},
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
slice_sizes={1,1,2,2}, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
auto root = module->entry_computation()->root_instruction();
auto operand = AllOf(op::Shape("s32[2,4,1,2]"), op::DynamicSlice());
auto indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
auto gather = AllOf(op::Shape("s32[2,4,1,2]"), op::Gather(operand, indices));
EXPECT_THAT(
root, op::AllReduce(op::DynamicUpdateSlice(
_, op::AllReduce(op::DynamicUpdateSlice(_, gather, _, _, _, _)),
_, _, _, _)));
}
TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
absl::string_view hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
%parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
slice_sizes={1,1,2,2}, sharding={replicated}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
PartitionComputation(hlo_string, /*num_devices=*/8));
auto root = module->entry_computation()->root_instruction();
auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter());
auto indices = AllOf(op::Shape("s32[2,2,1]"), op::Subtract());
auto gather = AllOf(op::Shape("s32[2,1,2,2]"), op::Gather(operand, indices));
EXPECT_THAT(root,
op::AllReduce(op::DynamicUpdateSlice(
_, op::AllReduce(op::Select(_, _, gather)), _, _, _, _)));
}
} // namespace
} // namespace spmd
} // namespace xla