[XLA:SPMD] Adding support for detecting potential parallel dimensions in gather.

Some gathers have indices driven by iotas, which makes the access
partitionable along the dimension indexed by the iota's dimension.
An example is:

  %iota.1 = iota()
  %indices = concatenate(..., %iota.1, ...)
  ... = gather(..., %indices)

which is a pattern produced by tf.reverse_sequence, for example.

This patch adds support for detecting such situations and allows the gather
to be sharded across devices.

PiperOrigin-RevId: 351205227
Change-Id: I79b7b6bf9392a81df0f5810c41d4b72c6f2678a6
Marcello Maggioni 2021-01-11 12:00:58 -08:00 committed by TensorFlower Gardener
parent 96fdbb1a8d
commit 7618f357a8
10 changed files with 806 additions and 198 deletions


@@ -479,6 +479,7 @@ cc_library(
],
deps = [
":hlo",
":hlo_casting_utils",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",


@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -567,6 +568,21 @@ int64 HloSharding::NumTiles() const {
return tile_assignment().num_elements();
}
int64 HloSharding::NumTiles(absl::Span<const int64> dims) const {
if (IsTileMaximal()) {
return 1;
}
CHECK(!IsManual());
CHECK(!ReplicateOnLastTileDim() ||
!absl::c_linear_search(dims, tile_assignment().num_dimensions() - 1));
int64 num_tiles = 1;
for (auto d : dims) {
CHECK(d < tile_assignment().num_dimensions());
num_tiles *= tile_assignment().dim(d);
}
return num_tiles;
}
HloSharding HloSharding::GetSubSharding(const Shape& shape,
const ShapeIndex& index) const {
CHECK(IsTuple());
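
A hypothetical usage sketch of the new NumTiles(dims) overload (the tile
assignment and the checked values are illustrative, not part of this change):

  // A 2x4 tile assignment over 8 devices.
  xla::Array<xla::int64> tiles({{0, 1, 2, 3}, {4, 5, 6, 7}});
  xla::HloSharding sharding = xla::HloSharding::Tile(tiles);
  CHECK_EQ(sharding.NumTiles(), 8);     // product over all dimensions
  CHECK_EQ(sharding.NumTiles({0}), 2);  // only dimension 0
  CHECK_EQ(sharding.NumTiles({1}), 4);  // only dimension 1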


@@ -263,6 +263,9 @@ class HloSharding {
// Gets the number of tiles. If it has partial replication, this will not
// equal the device count.
int64 NumTiles() const;
// Like NumTiles() but considers only the specific dimensions passed as an
// argument.
int64 NumTiles(absl::Span<const int64> dims) const;
private:
explicit HloSharding(bool manual, bool replicated)


@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -979,5 +980,68 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
: HloSharding::Tile(reshape_tiles);
}
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
const HloSharding& operand_sharding, const HloSharding& indices_sharding,
const HloInstruction& hlo) {
const auto& dnums = hlo.gather_dimension_numbers();
int64 index_dim = dnums.index_vector_dim();
// Try to identify whether there's a dimension in the indices that is fed by
// an iota increasing monotonically across a certain dimension. That would
// mean that accesses along the operand dimension selected by this index are
// parallelizable and that we can shard the operand (and the indices/output)
// across such a dimension.
// For example the pattern:
// %iota.1 = iota()
// %indices = concatenate(..., %iota.1, ...)
// ... = gather(..., %indices)
// is common for tf.reverse_sequence and would match this case.
absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
const HloInstruction* indices = hlo.operand(1);
// Handle cases where we concatenate pieces of the indices one at a time.
if (indices->opcode() == HloOpcode::kConcatenate &&
indices->concatenate_dimension() == index_dim) {
for (auto* op : indices->operands()) {
if (auto* iota = DynCast<HloIotaInstruction>(op)) {
if (iota->iota_dimension() != index_dim) {
iotas.push_back(iota);
}
}
}
} else if (auto* iota = DynCast<HloIotaInstruction>(indices);
iota != nullptr && iota->iota_dimension() != index_dim) {
// This is a case of a single iota with index_dim being out of bounds.
iotas.push_back(iota);
}
absl::InlinedVector<int64, 1> indices_parallel_dims;
absl::InlinedVector<int64, 1> operand_parallel_dims;
// Map the parallelizable dimension from the iota to the dimensions of the
// output and the operand. These dimensions correspond to each other, but they
// can sit at different positions in the operand and indices shapes, because
// the operand dimension targeted by each index is determined by
// start_index_map.
int index_num = 0;
for (auto* iota : iotas) {
int64 num_indices_from_iota = iota->shape().dimensions_size() > index_dim
? iota->shape().dimensions(index_dim)
: 1;
for (int i = 0; i < num_indices_from_iota; ++i) {
int64 index_dim = iota->iota_dimension();
int64 operand_dim = dnums.start_index_map(index_num + i);
// Bail out if multiple iotas concatenated on index_dim share the same
// iota_dimension.
if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
return absl::nullopt;
}
indices_parallel_dims.push_back(index_dim);
operand_parallel_dims.push_back(operand_dim);
}
index_num += num_indices_from_iota;
}
if (!indices_parallel_dims.empty()) {
return GatherParallelDims{indices_parallel_dims, operand_parallel_dims};
}
return absl::nullopt;
}
} // namespace hlo_sharding_util
} // namespace xla
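
A hypothetical test-style sketch of the pattern GetGatherBatchParallelDims
detects (the HLO text, variable names and expected dimensions are illustrative,
not taken from this change); the indices concatenate an iota over the batch
dimension with per-row indices, and the shardings passed in are replicated
because the detection only inspects the indices operand:

  constexpr char kHlo[] = R"(
  HloModule module

  ENTRY entry {
    %operand = f32[8,16] parameter(0)
    %per_row = s32[8,1] parameter(1)
    %iota = s32[8,1] iota(), iota_dimension=0
    %indices = s32[8,2] concatenate(%iota, %per_row), dimensions={1}
    ROOT %gather = f32[8] gather(%operand, %indices), offset_dims={},
      collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
      slice_sizes={1,1}
  }
  )";
  auto module = xla::ParseAndReturnUnverifiedModule(kHlo).ValueOrDie();
  const xla::HloInstruction* gather =
      module->entry_computation()->root_instruction();
  auto parallel_dims = xla::hlo_sharding_util::GetGatherBatchParallelDims(
      xla::HloSharding::Replicate(), xla::HloSharding::Replicate(), *gather);
  // Expected: parallel_dims->indices_parallel_dims == {0} and
  // parallel_dims->operand_parallel_dims == {0}, i.e. the batch dimension fed
  // by %iota can be sharded on both the operand and the indices.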


@@ -29,6 +29,11 @@ limitations under the License.
namespace xla {
namespace hlo_sharding_util {
struct GatherParallelDims {
absl::InlinedVector<int64, 1> indices_parallel_dims;
absl::InlinedVector<int64, 1> operand_parallel_dims;
};
// Given a map<device, occurrence_count>, selects the device with the highest
// occurrence count (if any). If top_count is not nullptr, it will receive the
// count of the dominant device returned.
@@ -181,6 +186,11 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
const HloSharding& source, absl::Span<int64 const> src_to_tgt,
absl::Span<int64 const> tgt_to_src);
// Returns identified parallel dimensions for Gather.
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
const HloSharding& operand_sharding, const HloSharding& indices_sharding,
const HloInstruction& hlo);
} // namespace hlo_sharding_util
} // namespace xla


@@ -56,9 +56,11 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)


@@ -24,9 +24,11 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -44,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
@@ -2745,6 +2748,127 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
batch_dims.push_back(i);
}
}
// Check if some of the dimensions of the gather have been identified as
// parallel and if the operand and indices are sharded across those
// dimensions. If that's the case then we can partition the gather across such
// dimensions by adjusting the indices with the partition offsets.
if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(
operand.sharding(), indices.sharding(), *hlo);
parallel_dims.has_value()) {
auto& [indices_parallel_dims, operand_parallel_dims] = *parallel_dims;
if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
*operand.hlo(), *indices.hlo(),
absl::MakeConstSpan(indices_parallel_dims),
absl::MakeConstSpan(operand_parallel_dims));
gather_sharding.has_value()) {
absl::InlinedVector<int64, 4> output_parallel_dims;
Shape gather_shape = gather->shape();
absl::c_sort(indices_parallel_dims);
for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
if (absl::c_linear_search(dnums.offset_dims(), i)) {
continue;
}
int index_dim =
idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
output_parallel_dims.push_back(i);
}
++idx_dim;
}
auto [indices_sharding, operand_sharding] = *gather_sharding;
if (operand_sharding.NumTiles() ==
operand_sharding.NumTiles(operand_parallel_dims) &&
indices_sharding.NumTiles() ==
indices_sharding.NumTiles(indices_parallel_dims)) {
int index_dim = dnums.index_vector_dim();
// Construct the required sharding for the new gather we are going to form.
absl::InlinedVector<int64, 4> output_tiling(
hlo->shape().dimensions_size(), 1);
for (int i = 0, num_output_parallel_dims = output_parallel_dims.size();
i < num_output_parallel_dims; ++i) {
int output_idx = output_parallel_dims[i];
int indices_idx = indices_parallel_dims[i];
output_tiling[output_idx] =
indices_sharding.tile_assignment().dim(indices_idx);
}
operand = operand.Reshard(operand_sharding);
indices = indices.Reshard(indices_sharding);
if (indices_sharding.ReplicateOnLastTileDim()) {
output_tiling.push_back(
indices_sharding.tile_assignment().dimensions().back());
}
Array<int64> output_tile_assignment =
indices_sharding.tile_assignment();
output_tile_assignment.Reshape(output_tiling);
// New gather tiling.
HloSharding output_sharding =
indices_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment)
: HloSharding::Tile(output_tile_assignment);
// Shape of the partitioned gather
Shape pshape = MakePartitionedShape(gather->shape(), output_sharding);
// Construct the offsets for the operand sharding to be used to adjust
// the indices. Because we know the only partitioned dimensions are the
// parallel ones and the partitioning is the same across indices and
// operands, we can apply the operand offsets to the indices.
std::vector<HloInstruction*> operand_offsets = MakePartitionOffsets(
operand.base_shape(), operand_sharding, partition_id_, &b_);
absl::InlinedVector<HloInstruction*, 4> index_offsets;
for (int start_idx = 0; start_idx < dnums.start_index_map_size();
++start_idx) {
HloInstruction* index_offset =
indices.base_shape().dimensions_size() > index_dim
? b_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1}),
operand_offsets[dnums.start_index_map(start_idx)]))
: operand_offsets[dnums.start_index_map(start_idx)];
index_offsets.push_back(index_offset);
}
HloInstruction* adjusted_indices = nullptr;
if (indices.base_shape().dimensions_size() > index_dim) {
// Concatenate the offsets for the parallel dimensions to subtract.
adjusted_indices =
b_.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(
S32, {indices.base_shape().dimensions(index_dim)}),
index_offsets, 0));
} else {
CHECK_EQ(index_offsets.size(), 1);
adjusted_indices = index_offsets[0];
}
if (indices.hlo()->shape().element_type() != PrimitiveType::S32) {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
adjusted_indices->shape(),
indices.hlo()->shape().element_type()),
adjusted_indices));
}
if (adjusted_indices->shape().rank() == 0) {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {}));
} else {
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBroadcast(
indices.hlo()->shape(), adjusted_indices, {index_dim}));
}
// Adjust indices by subtracting the offsets based on the partition id.
adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary(
indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(),
adjusted_indices));
HloInstruction* pgather =
b_.AddInstruction(HloInstruction::CreateGather(
pshape, operand.hlo(), adjusted_indices, dnums,
gather->gather_slice_sizes(), gather->indices_are_sorted()));
pgather->set_sharding(output_sharding);
SetPartitionedHlo(hlo, [&]() {
return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
}
}
if (operand.sharding().IsTileMaximal()) {
if (!indices.sharding().IsTileMaximal() &&
(dnums.index_vector_dim() == indices.base_shape().rank() ||
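
The index adjustment above amounts to subtracting each partition's start
offset along the parallel operand dimension from the global index values. A
minimal standalone sketch of that arithmetic, assuming an even split of the
parallel dimension across partitions (names and values are hypothetical):

  #include <cstdint>

  // Converts a global gather index into a partition-local index, given how
  // many rows of the parallel dimension each partition owns.
  int64_t AdjustIndexForPartition(int64_t global_index, int64_t partition_id,
                                  int64_t per_partition_size) {
    const int64_t partition_offset = partition_id * per_partition_size;
    return global_index - partition_offset;  // what the kSubtract computes
  }

  // E.g. with a parallel dimension of size 8 split across 4 partitions,
  // partition 2 owns rows [4, 6), so AdjustIndexForPartition(5, 2, 2) == 1.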

File diff suppressed because it is too large.


@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -1542,6 +1543,15 @@ GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
return grouped_sharding;
}
HloSharding AlignShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> sharding_dims,
const HloSharding& reference,
absl::Span<const int64> reference_dims) {
auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims);
auto reference_grouped = GroupShardingOnDims(reference, reference_dims);
return UngroupSharding(AlignGroupsWith(sharding_grouped, reference_grouped));
}
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
const Shape& original_base_shape) {
auto result = original_base_shape;
@@ -1781,5 +1791,136 @@ absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
return dims;
}
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
const HloSharding& source_sharding,
absl::Span<const int64> target_dims,
absl::Span<const int64> source_dims) {
CHECK(target_dims.size() == source_dims.size())
<< "Expected 1:1 match between parallel dimensions";
if (source_sharding.IsReplicated()) {
return HloSharding::Replicate();
}
absl::InlinedVector<int64, 4> tile_dims(target_shape.dimensions_size(), 1);
int num_tiles = 1;
for (int i = 0, end = target_dims.size(); i < end; ++i) {
num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
tile_dims[target_dims[i]] =
source_sharding.tile_assignment().dim(source_dims[i]);
}
// If the source sharding also partitions dimensions other than the parallel
// ones, partially replicate the new sharding over those extra devices.
bool to_be_partially_replicated = false;
if (num_tiles != source_sharding.tile_assignment().num_elements()) {
CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
to_be_partially_replicated = true;
tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
num_tiles);
}
auto tgt_tile_assignment = source_sharding.tile_assignment();
tgt_tile_assignment.Reshape(tile_dims);
if (to_be_partially_replicated) {
return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
target_dims, source_sharding, source_dims);
} else {
return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
target_dims, source_sharding, source_dims);
}
}
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
const HloInstruction& operand, const HloInstruction& indices,
absl::Span<const int64> indices_parallel_dims,
absl::Span<const int64> operand_parallel_dims) {
if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
return absl::nullopt;
}
auto new_index_shard = indices.sharding();
auto new_operand_shard = operand.sharding();
int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
return absl::nullopt;
}
if (new_index_shard.IsReplicated()) {
return GatherParallelDimSharding{
CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
indices_parallel_dims,
operand_parallel_dims),
new_operand_shard};
}
if (new_operand_shard.IsReplicated()) {
return GatherParallelDimSharding{
new_index_shard, CreateMatchingShardingOnDims(
operand.shape(), new_index_shard,
operand_parallel_dims, indices_parallel_dims)};
}
// Parallel dimension distribution needs to be the same, so try to steal
// sharding from partial replication to compensate.
if (idx_parallel_tiles_num != op_parallel_tiles_num) {
auto to_adjust_dims = operand_parallel_dims;
auto target_dims = indices_parallel_dims;
HloSharding* target = &new_index_shard;
HloSharding* to_adjust = &new_operand_shard;
if (idx_parallel_tiles_num < op_parallel_tiles_num) {
std::swap(to_adjust_dims, target_dims);
std::swap(to_adjust, target);
}
if (!to_adjust->ReplicateOnLastTileDim()) {
return absl::nullopt;
}
auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
for (int i = 0; i < to_adjust_dims.size(); ++i) {
int64 target_dim = target->tile_assignment().dim(target_dims[i]);
int64 to_adjust_dim = to_adjust->tile_assignment().dim(to_adjust_dims[i]);
if (target_dim < to_adjust_dim) {
return absl::nullopt;
}
if (target_dim == to_adjust_dim) {
continue;
}
int64 ratio = target_dim / to_adjust_dim;
if (target_dim % to_adjust_dim != 0 ||
new_tile_assignment_dims.back() % ratio != 0) {
return absl::nullopt;
}
new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
new_tile_assignment_dims.back() /= ratio;
}
CHECK_GE(new_tile_assignment_dims.back(), 1);
bool to_partially_replicate = true;
if (new_tile_assignment_dims.back() == 1) {
new_tile_assignment_dims.pop_back();
to_partially_replicate = false;
}
auto new_tile_assignment = to_adjust->tile_assignment();
new_tile_assignment.Reshape(new_tile_assignment_dims);
if (to_partially_replicate) {
*to_adjust =
AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
to_adjust_dims, *target, target_dims);
} else {
*to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
to_adjust_dims, *target, target_dims);
}
}
// Make sure that the parallel dimensions are aligned.
auto operand_shard_tile_dims =
new_operand_shard.tile_assignment().dimensions();
for (int i = 0; i < indices_parallel_dims.size(); ++i) {
operand_shard_tile_dims[operand_parallel_dims[i]] =
new_index_shard.tile_assignment().dim(indices_parallel_dims[i]);
}
auto operand_shard_tiles = new_operand_shard.tile_assignment();
operand_shard_tiles.Reshape(operand_shard_tile_dims);
new_operand_shard = AlignShardingOnDims(
new_operand_shard.ReplicateOnLastTileDim()
? HloSharding::PartialTile(operand_shard_tiles)
: HloSharding::Tile(operand_shard_tiles),
operand_parallel_dims, new_index_shard, indices_parallel_dims);
return GatherParallelDimSharding{new_index_shard, new_operand_shard};
}
} // namespace spmd
} // namespace xla
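
A hypothetical usage sketch of CreateMatchingShardingOnDims, mirroring the
gather scenario above (shapes, device ids and variable names are illustrative):

  // Operand sharding: dimension 0 split 2 ways across devices {0, 1}.
  xla::Array<xla::int64> operand_tiles({{0}, {1}});
  xla::HloSharding operand_sharding = xla::HloSharding::Tile(operand_tiles);
  // Build the matching sharding for s32[8,2] indices whose dimension 0 is the
  // parallel dimension paired with operand dimension 0.
  xla::Shape indices_shape = xla::ShapeUtil::MakeShape(xla::S32, {8, 2});
  xla::HloSharding indices_sharding = xla::spmd::CreateMatchingShardingOnDims(
      indices_shape, operand_sharding, /*target_dims=*/{0},
      /*source_dims=*/{0});
  // indices_sharding now tiles dimension 0 of the indices 2 ways with the same
  // device order as the operand, so each partition gathers from the rows of
  // the operand it already owns.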


@@ -30,6 +30,11 @@ limitations under the License.
namespace xla {
namespace spmd {
struct GatherParallelDimSharding {
HloSharding indices_sharding;
HloSharding operand_sharding;
};
// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);
@@ -323,6 +328,14 @@ GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
const GroupedSharding& reference,
bool ignore_group_order = false);
// Aligns the device groups between the two shardings. Equivalent to calling
// GroupShardingOnDims on the two shardings, then AlignGroupsWith, and then
// UngroupSharding.
HloSharding AlignShardingOnDims(const HloSharding& sharding,
absl::Span<const int64> sharding_dims,
const HloSharding& reference,
absl::Span<const int64> reference_dims);
// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
@@ -385,6 +398,25 @@ absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
const HloSharding& sharding,
const std::vector<std::vector<int64>>& device_groups);
// Creates a sharding that matches the provided source sharding on the
// specified dimensions. 'target_dims' and 'source_dims' represent the
// dimensions for which the shardings should match in their respective shapes.
// If some devices from the source sharding are left over (because not all of
// the devices are allocated to 'source_dims' dimensions) then partial
// replication is employed to make sure the device counts of the two shardings
// match.
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
const HloSharding& source_sharding,
absl::Span<const int64> target_dims,
absl::Span<const int64> source_dims);
// Returns the operand and indices shardings if the gather's operand and
// indices are sharded across parallel dimensions in a way the SPMD
// partitioner supports; returns nullopt otherwise.
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
const HloInstruction& operand, const HloInstruction& indices,
absl::Span<const int64> indices_parallel_dims,
absl::Span<const int64> operand_parallel_dims);
} // namespace spmd
} // namespace xla