[XLA] Add sharding propagation support for Gather sharding on parallel dimensions.

Should enable sharding of tf.reverse_sequence() and its gradient.
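
The targeted pattern is a gather whose start indices are iotas, or a concatenation of iotas, along dimensions other than the index vector dimension, so each indexed operand dimension moves in lockstep with a batch dimension of the indices (the pattern tf.reverse_sequence commonly produces). A minimal illustrative sketch of that pattern in HLO, with shapes mirroring the tests added below:

  ENTRY %module {
    %operand = s32[8,4,2,2] parameter(0)
    %iota = s32[1,8,4] iota(), iota_dimension=1
    %iota2 = s32[1,8,4] iota(), iota_dimension=2
    %indices = s32[2,8,4] concatenate(%iota, %iota2), dimensions={0}
    ROOT %gather = s32[8,4,2,2] gather(%operand, %indices),
        offset_dims={2,3}, collapsed_slice_dims={0,1}, start_index_map={0,1},
        index_vector_dim=0, slice_sizes={1,1,2,2}
  }

Here operand dimensions 0 and 1 are addressed by iotas along indices dimensions 1 and 2, so a sharding that tiles those dimensions (for example devices=[8,1,1,1] over eight devices on the operand, or devices=[1,8,1] on the indices) can now be propagated to the gather output, and output shardings can be propagated back to the operand and indices.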

PiperOrigin-RevId: 352154964
Change-Id: Ifa953ae5068ebda61ef0d1d1b0438362afcb7e67
Marcello Maggioni 2021-01-16 01:50:09 -08:00 committed by TensorFlower Gardener
parent 1d0e1df48a
commit 85bf96f508
8 changed files with 500 additions and 59 deletions

@@ -489,6 +489,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:optional",
],
)

@@ -15,11 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include <algorithm>
#include <map>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -700,6 +702,54 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
: HloSharding::Tile(tile_assignment);
}
// Collect data operand sharding for a gather with parallel dimensions from
// the output.
absl::optional<HloSharding> GatherParallelDataOperandSharding(
const HloSharding& output_sharding, const HloInstruction& gather) {
if (output_sharding.IsTileMaximal()) {
return output_sharding;
}
auto parallel_dims = GetGatherBatchParallelDims(gather);
if (!parallel_dims) {
return absl::nullopt;
}
auto output_parallel_dims = GatherParallelOutputDims(gather, *parallel_dims);
auto output_aligned_operand_parallel_dims =
GatherOutputAlignedOperandParallelDims(gather, *parallel_dims);
const Shape gather_shape = gather.shape();
CHECK_EQ(output_parallel_dims.size(),
output_aligned_operand_parallel_dims.size());
std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
1);
for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
if (parallel_idx >= output_parallel_dims.size() ||
output_parallel_dims[parallel_idx] != i) {
// Only support the case where the output dimensions are sharded
// exclusively across parallel dimensions.
if (output_sharding.tile_assignment().dim(i) != 1) {
return absl::nullopt;
}
continue;
}
const int64 operand_dim =
output_aligned_operand_parallel_dims[parallel_idx++];
operand_tile_assignment[operand_dim] =
output_sharding.tile_assignment().dim(i);
}
if (output_sharding.ReplicateOnLastTileDim()) {
operand_tile_assignment.push_back(
output_sharding.tile_assignment().dimensions().back());
}
Array<int64> tile_assignment = output_sharding.tile_assignment();
if (tile_assignment.num_elements() != Product(operand_tile_assignment)) {
return absl::nullopt;
}
tile_assignment.Reshape(operand_tile_assignment);
return output_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment)
: HloSharding::Tile(tile_assignment);
}
} // namespace
absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
@@ -726,6 +776,12 @@ absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
dnums.start_index_map().end());
std::vector<int64> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
// Prioritize parallel sharding first, matching the behavior of the
// spmd_partitioner.
if (auto parallel_sharding =
GatherParallelDataOperandSharding(hlo.sharding(), hlo)) {
return parallel_sharding;
}
return PassthroughGatherOutputOrScatterUpdateToOperand(
hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
start_index_map, offset_dims, hlo.gather_slice_sizes());
@@ -981,7 +1037,6 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
}
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
const HloSharding& operand_sharding, const HloSharding& indices_sharding,
const HloInstruction& hlo) {
const auto& dnums = hlo.gather_dimension_numbers();
int64 index_dim = dnums.index_vector_dim();
@@ -997,20 +1052,37 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
// is common for tf.reverse_sequence and would match this case.
absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
const HloInstruction* indices = hlo.operand(1);
const int num_indices = dnums.start_index_map_size();
std::vector<int64> index_parallel_in_dim(num_indices, -1);
// Handle cases where we concatenate pieces of the indices one at a time.
if (indices->opcode() == HloOpcode::kConcatenate &&
indices->concatenate_dimension() == index_dim) {
for (auto* op : indices->operands()) {
int concatenated_dims = 0;
for (int i = 0; i < indices->operand_count(); ++i) {
const HloInstruction* op = indices->operand(i);
const int64 num_indices_from_element =
op->shape().dimensions_size() > index_dim
? op->shape().dimensions(index_dim)
: 1;
if (auto* iota = DynCast<HloIotaInstruction>(op)) {
if (iota->iota_dimension() != index_dim) {
iotas.push_back(iota);
for (int j = 0; j < num_indices_from_element; ++j) {
index_parallel_in_dim[concatenated_dims + j] =
iota->iota_dimension();
}
}
}
concatenated_dims += num_indices_from_element;
}
} else if (auto* iota = DynCast<HloIotaInstruction>(indices)) {
if (iota->iota_dimension() != index_dim) {
// This is a case of a single iota with index_dim being out of bounds.
iotas.push_back(iota);
const int64 num_indices_from_element =
iota->shape().dimensions_size() > index_dim
? iota->shape().dimensions(index_dim)
: 1;
index_parallel_in_dim.assign(num_indices_from_element,
iota->iota_dimension());
}
}
absl::InlinedVector<int64, 1> indices_parallel_dims;
@@ -1020,29 +1092,75 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
// operands and index they could have different spots in the shape because the
// position of the index dimension in the operand is determined by
// start_index_map.
int index_num = 0;
for (auto* iota : iotas) {
int64 num_indices_from_iota = iota->shape().dimensions_size() > index_dim
? iota->shape().dimensions(index_dim)
: 1;
for (int i = 0; i < num_indices_from_iota; ++i) {
int64 index_dim = iota->iota_dimension();
int64 operand_dim = dnums.start_index_map(index_num + i);
// Uniquify multiple iotas concatenated on index_dim with the same
// iota_dimension
if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
return absl::nullopt;
}
indices_parallel_dims.push_back(index_dim);
operand_parallel_dims.push_back(operand_dim);
for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
int index_parallel_dim = index_parallel_in_dim[i];
if (index_parallel_dim == -1) {
continue;
}
index_num += num_indices_from_iota;
if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
return absl::nullopt;
}
indices_parallel_dims.push_back(index_parallel_dim);
operand_parallel_dims.push_back(dnums.start_index_map(i));
}
absl::c_sort(indices_parallel_dims);
if (!indices_parallel_dims.empty()) {
return GatherParallelDims{indices_parallel_dims, operand_parallel_dims};
return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
index_parallel_in_dim};
}
return absl::nullopt;
}
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
absl::InlinedVector<int64, 1> output_parallel_dims;
auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
const Shape gather_shape = gather.shape();
auto dnums = gather.gather_dimension_numbers();
for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
if (absl::c_linear_search(dnums.offset_dims(), i)) {
continue;
}
const int index_dim =
idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
output_parallel_dims.push_back(i);
}
++idx_dim;
}
return output_parallel_dims;
}
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
parallel_dims.operand_parallel_dims.size(), -1);
auto dnums = gather.gather_dimension_numbers();
CHECK_LE(parallel_dims.indices_parallel_dims.size(),
parallel_dims.operand_parallel_dims.size());
for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
// This is the equivalent batch dimension of the indices that corresponds
// to this index dimension.
const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
// If it's not an index that is parallel, skip it.
if (index_parallel_dim == -1) {
continue;
}
// This is small so just look linearly. Populate the operand parallel
// dimensions based on the order of the index batch dims (which is the same
// order as the output).
for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
const int64 operand_parallel_dim = dnums.start_index_map(i);
if (operand_parallel_dim_to_output[j] == -1) {
operand_parallel_dim_to_output[j] = operand_parallel_dim;
}
break;
}
}
}
return operand_parallel_dim_to_output;
}
} // namespace hlo_sharding_util
} // namespace xla

@@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -32,6 +33,7 @@ namespace hlo_sharding_util {
struct GatherParallelDims {
absl::InlinedVector<int64, 1> indices_parallel_dims;
absl::InlinedVector<int64, 1> operand_parallel_dims;
std::vector<int64> index_parallel_in_dim;
};
// Given a map<device, occurrence_count>, selects the device with higher
@@ -188,9 +190,19 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
// Returns identified parallel dimensions for Gather.
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
const HloSharding& operand_sharding, const HloSharding& indices_sharding,
const HloInstruction& hlo);
// Returns the parallel dimensions of the output of a gather based on the
// parallel dimensions of the input.
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
const HloInstruction& gather, const GatherParallelDims& parallel_dim);
// Returns the parallel dimensions of the data operand of a gather with the
// order of the parallel dimensions matching that of the parallel dimensions
// of the output.
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
const HloInstruction& gather, const GatherParallelDims& parallel_dims);
} // namespace hlo_sharding_util
} // namespace xla

@@ -555,6 +555,75 @@ bool InferDotShardingFromOperands(
return changed;
}
bool InferGatherParallelShardingFromOperands(
HloInstruction* instruction,
const hlo_sharding_util::GatherParallelDims& parallel_dims,
bool may_combine_partial_sharding) {
auto from_operand = [instruction](
int64 operand_index,
absl::Span<const int64> output_aligned_parallel_dims,
absl::Span<const int64> output_parallel_dims) {
const HloInstruction* operand = instruction->operand(operand_index);
const HloSharding& operand_sharding = operand->sharding();
if (operand_sharding.IsTileMaximal()) {
return operand_sharding;
}
auto dnums = instruction->gather_dimension_numbers();
std::vector<int64> output_tile_dims(instruction->shape().rank(), 1);
std::vector<int64> index_non_parallel_dims;
index_non_parallel_dims.reserve(operand->shape().rank());
// Detect non-parallel dimensions in the index.
for (int i = 0; i < operand->shape().rank(); ++i) {
if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
index_non_parallel_dims.push_back(i);
}
}
// Collect tile dimensions in the operand. The order of the parallel
// dimensions in output_aligned_parallel_dims is the same as that of the
// output
for (int i = 0; i < output_aligned_parallel_dims.size(); ++i) {
const int64 indices_idx = output_aligned_parallel_dims[i];
const int64 output_idx = output_parallel_dims[i];
output_tile_dims[output_idx] =
operand_sharding.tile_assignment().dim(indices_idx);
}
HloSharding replicate_non_parallel_dims =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
operand_sharding, index_non_parallel_dims);
if (replicate_non_parallel_dims.ReplicateOnLastTileDim()) {
output_tile_dims.push_back(
replicate_non_parallel_dims.tile_assignment().dimensions().back());
}
auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
output_tile_assignment.Reshape(output_tile_dims);
return replicate_non_parallel_dims.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment)
: HloSharding::Tile(output_tile_assignment);
};
bool changed = false;
auto output_parallel_dims =
hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
if (IsSpatiallyPartitioned(instruction->operand(0))) {
changed |= MaybeImproveInstructionSharding(
from_operand(
0,
absl::MakeConstSpan(
hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
*instruction, parallel_dims)),
absl::MakeConstSpan(output_parallel_dims)),
instruction, may_combine_partial_sharding);
}
if (IsSpatiallyPartitioned(instruction->operand(1))) {
changed |= MaybeImproveInstructionSharding(
from_operand(1,
absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)),
instruction, may_combine_partial_sharding);
}
return changed;
}
// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
int64 aggressiveness,
@@ -1030,6 +1099,14 @@ bool InferShardingFromOperands(HloInstruction* instruction,
}
case HloOpcode::kGather: {
bool changed = false;
if (is_spmd) {
auto gather_parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
if (gather_parallel_dims) {
changed |= InferGatherParallelShardingFromOperands(
instruction, *gather_parallel_dims, may_combine_partial_sharding);
}
}
if (IsSpatiallyPartitioned(instruction->operand(1))) {
HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
instruction->operand(1)->sharding(), instruction);

@@ -2579,5 +2579,243 @@ ENTRY %transpose {
"{devices=[2,1,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, ParallelGatherFromOperandForwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "gather"),
op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}
TEST_F(ShardingPropagationTest, ParallelGatherFromIndexForwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
sharding={devices=[1,8,1]0,1,4,5,2,3,6,7}
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "gather"),
op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}
TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
%copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %copy.p,
s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}, sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
op::Sharding("{devices=[1,8,1]0,1,4,5,2,3,6,7}"));
EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}
TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass2) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
%copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[4,8,2,2]{3,2,1,0} %copy.p,
s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
slice_sizes={1,1,2,2}, sharding={devices=[1,4,1,1]0,1,4,5}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
op::Sharding("{devices=[1,1,4]0,1,4,5}"));
EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
op::Sharding("{devices=[4,1,1,1]0,1,4,5}"));
}
TEST_F(ShardingPropagationTest,
PartialShardingParallelGatherFromOperandForwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
FindInstruction(module.get(), "gather"),
op::Sharding(
"{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest,
PartialShardingParallelGatherFromIndexForwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
sharding={devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %parameter.0,
s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
FindInstruction(module.get(), "gather"),
op::Sharding(
"{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
%copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[8,4,2,2]{3,2,1,0} %copy.p,
s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
slice_sizes={1,1,2,2},
sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
FindInstruction(module.get(), "concatenate"),
op::Sharding(
"{devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
EXPECT_THAT(
FindInstruction(module.get(), "copy.p"),
op::Sharding(
"{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}
TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass2) {
const char* const hlo_string = R"(
HloModule module
ENTRY %module {
%parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
%copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
%iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
%iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
%concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
s32[1,8,4]{2,1,0} %iota2), dimensions={0}
%gather = s32[8,4,2,2]{3,2,1,0} gather(
s32[4,8,2,2]{3,2,1,0} %copy.p,
s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
slice_sizes={1,1,2,2},
sharding={devices=[1,2,1,1,2]0,1,4,5 last_tile_dim_replicate}
ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
EXPECT_TRUE(changed);
EXPECT_THAT(
FindInstruction(module.get(), "concatenate"),
op::Sharding("{devices=[1,1,2,2]0,1,4,5 last_tile_dim_replicate}"));
EXPECT_THAT(
FindInstruction(module.get(), "copy.p"),
op::Sharding("{devices=[2,1,1,1,2]0,1,4,5 last_tile_dim_replicate}"));
}
} // namespace
} // namespace xla

@@ -2753,28 +2753,13 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
// If that's the case then we can partition the gather across such dimensions
// by adjusting the offsets.
if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
hlo_sharding_util::GetGatherBatchParallelDims(
operand.sharding(), indices.sharding(), *hlo)) {
auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
hlo_sharding_util::GetGatherBatchParallelDims(*hlo)) {
if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
*operand.hlo(), *indices.hlo(),
absl::MakeConstSpan(indices_parallel_dims),
absl::MakeConstSpan(operand_parallel_dims))) {
absl::InlinedVector<int64, 4> output_parallel_dims;
Shape gather_shape = gather->shape();
absl::c_sort(indices_parallel_dims);
for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
if (absl::c_linear_search(dnums.offset_dims(), i)) {
continue;
}
int index_dim =
idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
output_parallel_dims.push_back(i);
}
++idx_dim;
}
*operand.hlo(), *indices.hlo(), *parallel_dims)) {
auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
auto output_parallel_dims =
hlo_sharding_util::GatherParallelOutputDims(*hlo, *parallel_dims);
HloSharding indices_sharding = gather_sharding->indices_sharding;
HloSharding operand_sharding = gather_sharding->operand_sharding;
if (operand_sharding.NumTiles() ==

@@ -1830,8 +1830,9 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
const HloInstruction& operand, const HloInstruction& indices,
absl::Span<const int64> indices_parallel_dims,
absl::Span<const int64> operand_parallel_dims) {
const hlo_sharding_util::GatherParallelDims& parallel_dims) {
auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
return absl::nullopt;
}
@@ -1842,25 +1843,32 @@ GatherOperandsShardedAcrossParallelDims(
if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
return absl::nullopt;
}
absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
for (int idx : parallel_dims.index_parallel_in_dim) {
if (idx != -1) {
indices_parallel_dims_ordered_as_operand.push_back(idx);
}
}
if (new_index_shard.IsReplicated()) {
return GatherParallelDimSharding{
CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
indices_parallel_dims,
indices_parallel_dims_ordered_as_operand,
operand_parallel_dims),
new_operand_shard};
}
if (new_operand_shard.IsReplicated()) {
return GatherParallelDimSharding{
new_index_shard, CreateMatchingShardingOnDims(
operand.shape(), new_index_shard,
operand_parallel_dims, indices_parallel_dims)};
new_index_shard,
CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
operand_parallel_dims,
indices_parallel_dims_ordered_as_operand)};
}
// Parallel dimension distribution needs to be the same, so try to steal
// sharding from partial replication to compensate.
if (idx_parallel_tiles_num != op_parallel_tiles_num) {
auto to_adjust_dims = operand_parallel_dims;
auto target_dims = indices_parallel_dims;
auto target_dims = indices_parallel_dims_ordered_as_operand;
HloSharding* target = &new_index_shard;
HloSharding* to_adjust = &new_operand_shard;
if (idx_parallel_tiles_num < op_parallel_tiles_num) {
@@ -1908,17 +1916,19 @@ GatherOperandsShardedAcrossParallelDims(
// Make sure that the parallel dimensions are aligned.
auto operand_shard_tile_dims =
new_operand_shard.tile_assignment().dimensions();
for (int i = 0; i < indices_parallel_dims.size(); ++i) {
for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
operand_shard_tile_dims[operand_parallel_dims[i]] =
new_index_shard.tile_assignment().dim(indices_parallel_dims[i]);
new_index_shard.tile_assignment().dim(
indices_parallel_dims_ordered_as_operand[i]);
}
auto operand_shard_tiles = new_operand_shard.tile_assignment();
operand_shard_tiles.Reshape(operand_shard_tile_dims);
new_operand_shard = AlignShardingOnDims(
new_operand_shard.ReplicateOnLastTileDim()
? HloSharding::PartialTile(operand_shard_tiles)
: HloSharding::Tile(operand_shard_tiles),
operand_parallel_dims, new_index_shard, indices_parallel_dims);
new_operand_shard =
AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
? HloSharding::PartialTile(operand_shard_tiles)
: HloSharding::Tile(operand_shard_tiles),
operand_parallel_dims, new_index_shard,
indices_parallel_dims_ordered_as_operand);
return GatherParallelDimSharding{new_index_shard, new_operand_shard};
}

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
namespace xla {
@@ -414,8 +415,7 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
const HloInstruction& operand, const HloInstruction& indices,
absl::Span<const int64> indices_parallel_dims,
absl::Span<const int64> operand_parallel_dims);
const hlo_sharding_util::GatherParallelDims& parallel_dims);
} // namespace spmd
} // namespace xla