[XLA] Add sharding propagation support for Gather sharding on parallel dimensions.

Should enable sharding of tf.reverse_sequence() and its gradient.

PiperOrigin-RevId: 352154964
Change-Id: Ifa953ae5068ebda61ef0d1d1b0438362afcb7e67
Commit 85bf96f508 (parent 1d0e1df48a)
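
Background for the change below: a gather dimension is batch-parallel when the corresponding index component is an iota over one of the index batch dimensions, so element b of the output only reads element b of the operand along that dimension. This is the index pattern tf.reverse_sequence produces (see the comment in GetGatherBatchParallelDims below), and it is what the new tests exercise. A minimal sketch of such a gather, mirroring the test cases (shapes are illustrative):

    %iota  = s32[1,8,4] iota(), iota_dimension=1
    %iota2 = s32[1,8,4] iota(), iota_dimension=2
    %indices = s32[2,8,4] concatenate(%iota, %iota2), dimensions={0}
    %gather = gather(%operand, %indices), index_vector_dim=0,
        offset_dims={2,3}, collapsed_slice_dims={0,1}, start_index_map={0,1}, ...

Both index components are iotas over indices dimensions 1 and 2, so those dimensions are batch-parallel: a sharding of the operand, the indices, or the gather output along them can be propagated to the other two, which is what the code below implements.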
@@ -489,6 +489,7 @@ cc_library(
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -15,11 +15,13 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include <algorithm>
#include <map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -700,6 +702,54 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
             : HloSharding::Tile(tile_assignment);
}

// Collect data operand sharding for a gather with parallel dimensions from
// the output.
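// For example (values are illustrative, matching the tests below): a gather
// output of shape s32[8,4,2,2] tiled as [8,1,1,1], with output parallel
// dimension 0 aligned to operand dimension 0, yields a data operand sharding
// tiled [8,1,1,1].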
absl::optional<HloSharding> GatherParallelDataOperandSharding(
    const HloSharding& output_sharding, const HloInstruction& gather) {
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }
  auto parallel_dims = GetGatherBatchParallelDims(gather);
  if (!parallel_dims) {
    return absl::nullopt;
  }
  auto output_parallel_dims = GatherParallelOutputDims(gather, *parallel_dims);
  auto output_aligned_operand_parallel_dims =
      GatherOutputAlignedOperandParallelDims(gather, *parallel_dims);
  const Shape gather_shape = gather.shape();
  CHECK_EQ(output_parallel_dims.size(),
           output_aligned_operand_parallel_dims.size());
  std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
                                             1);
  for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
    if (parallel_idx >= output_parallel_dims.size() ||
        output_parallel_dims[parallel_idx] != i) {
      // Support only the case where the output dimensions are sharded only
      // across parallel dimensions.
      if (output_sharding.tile_assignment().dim(i) != 1) {
        return absl::nullopt;
      }
      continue;
    }
    const int64 operand_dim =
        output_aligned_operand_parallel_dims[parallel_idx++];
    operand_tile_assignment[operand_dim] =
        output_sharding.tile_assignment().dim(i);
  }
  if (output_sharding.ReplicateOnLastTileDim()) {
    operand_tile_assignment.push_back(
        output_sharding.tile_assignment().dimensions().back());
  }
  Array<int64> tile_assignment = output_sharding.tile_assignment();
  if (tile_assignment.num_elements() != Product(operand_tile_assignment)) {
    return absl::nullopt;
  }
  tile_assignment.Reshape(operand_tile_assignment);
  return output_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(tile_assignment)
             : HloSharding::Tile(tile_assignment);
}

}  // namespace

absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
@@ -726,6 +776,12 @@ absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
                                     dnums.start_index_map().end());
  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                 dnums.offset_dims().end());
  // Prioritize parallel sharding first as this is how it is in
  // spmd_partitioner.
  if (auto parallel_sharding =
          GatherParallelDataOperandSharding(hlo.sharding(), hlo)) {
    return parallel_sharding;
  }
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
      start_index_map, offset_dims, hlo.gather_slice_sizes());
@@ -981,7 +1037,6 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
}

absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloSharding& operand_sharding, const HloSharding& indices_sharding,
    const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  int64 index_dim = dnums.index_vector_dim();
@@ -997,20 +1052,37 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
  // is common for tf.reverse_sequence and would match this case.
  absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
  const HloInstruction* indices = hlo.operand(1);
  const int num_indices = dnums.start_index_map_size();
  std::vector<int64> index_parallel_in_dim(num_indices, -1);
  // Handle cases where we concatenate pieces of the indices one at a time.
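  // Each concatenated operand contributes num_indices_from_element index
  // components; only components coming from an iota whose iota_dimension
  // differs from index_dim are recorded as parallel, the rest stay -1.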
  if (indices->opcode() == HloOpcode::kConcatenate &&
      indices->concatenate_dimension() == index_dim) {
    for (auto* op : indices->operands()) {
    int concatenated_dims = 0;
    for (int i = 0; i < indices->operand_count(); ++i) {
      const HloInstruction* op = indices->operand(i);
      const int64 num_indices_from_element =
          op->shape().dimensions_size() > index_dim
              ? op->shape().dimensions(index_dim)
              : 1;
      if (auto* iota = DynCast<HloIotaInstruction>(op)) {
        if (iota->iota_dimension() != index_dim) {
          iotas.push_back(iota);
          for (int j = 0; j < num_indices_from_element; ++j) {
            index_parallel_in_dim[concatenated_dims + j] =
                iota->iota_dimension();
          }
        }
      }
      concatenated_dims += num_indices_from_element;
    }
  } else if (auto* iota = DynCast<HloIotaInstruction>(indices)) {
    if (iota->iota_dimension() != index_dim) {
      // This is a case of a single iota with index_dim being out of bounds.
      iotas.push_back(iota);
      const int64 num_indices_from_element =
          iota->shape().dimensions_size() > index_dim
              ? iota->shape().dimensions(index_dim)
              : 1;
      index_parallel_in_dim.assign(num_indices_from_element,
                                   iota->iota_dimension());
    }
  }
  absl::InlinedVector<int64, 1> indices_parallel_dims;
@@ -1020,29 +1092,75 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
  // operands and index they could have different spots in the shape because the
  // position of the index dimension in the operand is determined by
  // start_index_map.
  int index_num = 0;
  for (auto* iota : iotas) {
    int64 num_indices_from_iota = iota->shape().dimensions_size() > index_dim
                                      ? iota->shape().dimensions(index_dim)
                                      : 1;
    for (int i = 0; i < num_indices_from_iota; ++i) {
      int64 index_dim = iota->iota_dimension();
      int64 operand_dim = dnums.start_index_map(index_num + i);
      // Uniquify multiple iotas concatenated on index_dim with the same
      // iota_dimension
      if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
        return absl::nullopt;
      }
      indices_parallel_dims.push_back(index_dim);
      operand_parallel_dims.push_back(operand_dim);
  for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
    int index_parallel_dim = index_parallel_in_dim[i];
    if (index_parallel_dim == -1) {
      continue;
    }
      index_num += num_indices_from_iota;
    if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
      return absl::nullopt;
    }
    indices_parallel_dims.push_back(index_parallel_dim);
    operand_parallel_dims.push_back(dnums.start_index_map(i));
  }
  absl::c_sort(indices_parallel_dims);
  if (!indices_parallel_dims.empty()) {
    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims};
    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
                              index_parallel_in_dim};
  }
  return absl::nullopt;
}

absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
  absl::InlinedVector<int64, 1> output_parallel_dims;
  auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
  const Shape gather_shape = gather.shape();
  auto dnums = gather.gather_dimension_numbers();
  for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
    if (absl::c_linear_search(dnums.offset_dims(), i)) {
      continue;
    }
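    // Map this non-offset output dimension back to the corresponding
    // dimension of the indices operand, skipping over index_vector_dim.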
    const int index_dim =
        idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
    if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
      output_parallel_dims.push_back(i);
    }
    ++idx_dim;
  }
  return output_parallel_dims;
}

absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
  absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
      parallel_dims.operand_parallel_dims.size(), -1);
  auto dnums = gather.gather_dimension_numbers();
  CHECK_LE(parallel_dims.indices_parallel_dims.size(),
           parallel_dims.operand_parallel_dims.size());
  for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
    // This is the equivalent batch dimension of the indices that corresponds
    // to this index dimension.
    const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
    // If it's not an index that is parallel skip.
    if (index_parallel_dim == -1) {
      continue;
    }
    // This is small so just look linearly. Populate the operand parallel
    // dimensions based on the order of the index batch dims (which is the same
    // order as the output).
    for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
      if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
        const int64 operand_parallel_dim = dnums.start_index_map(i);
        if (operand_parallel_dim_to_output[j] == -1) {
          operand_parallel_dim_to_output[j] = operand_parallel_dim;
        }
        break;
      }
    }
  }
  return operand_parallel_dim_to_output;
}

}  // namespace hlo_sharding_util
}  // namespace xla
@@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -32,6 +33,7 @@ namespace hlo_sharding_util {
struct GatherParallelDims {
  absl::InlinedVector<int64, 1> indices_parallel_dims;
  absl::InlinedVector<int64, 1> operand_parallel_dims;
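  // One entry per index vector component (i.e. per start_index_map entry):
  // the dimension of the indices operand along which that component is an
  // iota (making it a parallel batch dimension), or -1 if it is not parallel.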
  std::vector<int64> index_parallel_in_dim;
};

// Given a map<device, occurrence_count>, selects the device with higher
@@ -188,9 +190,19 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(

// Returns identified parallel dimensions for Gather.
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloSharding& operand_sharding, const HloSharding& indices_sharding,
    const HloInstruction& hlo);

// Returns the parallel dimensions of the output of a gather based on the
// parallel dimensions of the input.
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim);

// Returns the parallel dimensions of the data operand of a gather with the
// order of the parallel dimensions matching that of the parallel dimensions
// of the output.
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims);

}  // namespace hlo_sharding_util
}  // namespace xla
@@ -555,6 +555,75 @@ bool InferDotShardingFromOperands(
  return changed;
}

bool InferGatherParallelShardingFromOperands(
    HloInstruction* instruction,
    const hlo_sharding_util::GatherParallelDims& parallel_dims,
    bool may_combine_partial_sharding) {
  auto from_operand = [instruction](
                          int64 operand_index,
                          absl::Span<const int64> output_aligned_parallel_dims,
                          absl::Span<const int64> output_parallel_dims) {
    const HloInstruction* operand = instruction->operand(operand_index);
    const HloSharding& operand_sharding = operand->sharding();
    if (operand_sharding.IsTileMaximal()) {
      return operand_sharding;
    }
    auto dnums = instruction->gather_dimension_numbers();
    std::vector<int64> output_tile_dims(instruction->shape().rank(), 1);
    std::vector<int64> index_non_parallel_dims;
    index_non_parallel_dims.reserve(operand->shape().rank());
    // Detect non parallel dimensions in the index.
    for (int i = 0; i < operand->shape().rank(); ++i) {
      if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
        index_non_parallel_dims.push_back(i);
      }
    }
    // Collect tile dimensions in the operand. The order of the parallel
    // dimensions in output_aligned_parallel_dims is the same as that of the
    // output
    for (int i = 0; i < output_aligned_parallel_dims.size(); ++i) {
      const int64 indices_idx = output_aligned_parallel_dims[i];
      const int64 output_idx = output_parallel_dims[i];
      output_tile_dims[output_idx] =
          operand_sharding.tile_assignment().dim(indices_idx);
    }
    HloSharding replicate_non_parallel_dims =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            operand_sharding, index_non_parallel_dims);
    if (replicate_non_parallel_dims.ReplicateOnLastTileDim()) {
      output_tile_dims.push_back(
          replicate_non_parallel_dims.tile_assignment().dimensions().back());
    }
    auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
    output_tile_assignment.Reshape(output_tile_dims);
    return replicate_non_parallel_dims.ReplicateOnLastTileDim()
               ? HloSharding::PartialTile(output_tile_assignment)
               : HloSharding::Tile(output_tile_assignment);
  };

  bool changed = false;
  auto output_parallel_dims =
      hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
  if (IsSpatiallyPartitioned(instruction->operand(0))) {
    changed |= MaybeImproveInstructionSharding(
        from_operand(
            0,
            absl::MakeConstSpan(
                hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
                    *instruction, parallel_dims)),
            absl::MakeConstSpan(output_parallel_dims)),
        instruction, may_combine_partial_sharding);
  }
  if (IsSpatiallyPartitioned(instruction->operand(1))) {
    changed |= MaybeImproveInstructionSharding(
        from_operand(1,
                     absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
                     absl::MakeConstSpan(output_parallel_dims)),
        instruction, may_combine_partial_sharding);
  }
  return changed;
}

// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
                                          int64 aggressiveness,
@@ -1030,6 +1099,14 @@ bool InferShardingFromOperands(HloInstruction* instruction,
    }
    case HloOpcode::kGather: {
      bool changed = false;
      if (is_spmd) {
        auto gather_parallel_dims =
            hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
        if (gather_parallel_dims) {
          changed |= InferGatherParallelShardingFromOperands(
              instruction, *gather_parallel_dims, may_combine_partial_sharding);
        }
      }
      if (IsSpatiallyPartitioned(instruction->operand(1))) {
        HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
            instruction->operand(1)->sharding(), instruction);
@@ -2579,5 +2579,243 @@ ENTRY %transpose {
          "{devices=[2,1,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}
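
// The tests below cover forward propagation of a parallel gather sharding
// from the data operand and from the indices, backward propagation from the
// gather output, and the partially replicated (last_tile_dim_replicate)
// variants of each.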
TEST_F(ShardingPropagationTest, ParallelGatherFromOperandForwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "gather"),
              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}

TEST_F(ShardingPropagationTest, ParallelGatherFromIndexForwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,4,5,2,3,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "gather"),
              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}

TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
  %copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %copy.p,
    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
              op::Sharding("{devices=[1,8,1]0,1,4,5,2,3,6,7}"));
  EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
}

TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass2) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
  %copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[4,8,2,2]{3,2,1,0} %copy.p,
    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[1,4,1,1]0,1,4,5}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
              op::Sharding("{devices=[1,1,4]0,1,4,5}"));
  EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
              op::Sharding("{devices=[4,1,1,1]0,1,4,5}"));
}

TEST_F(ShardingPropagationTest,
       PartialShardingParallelGatherFromOperandForwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      FindInstruction(module.get(), "gather"),
      op::Sharding(
          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}

TEST_F(ShardingPropagationTest,
       PartialShardingParallelGatherFromIndexForwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      FindInstruction(module.get(), "gather"),
      op::Sharding(
          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}

TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
  %copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %copy.p,
    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2},
    sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      FindInstruction(module.get(), "concatenate"),
      op::Sharding(
          "{devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
  EXPECT_THAT(
      FindInstruction(module.get(), "copy.p"),
      op::Sharding(
          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
}

TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass2) {
  const char* const hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
  %copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
  %gather = s32[8,4,2,2]{3,2,1,0} gather(
    s32[4,8,2,2]{3,2,1,0} %copy.p,
    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2},
    sharding={devices=[1,2,1,1,2]0,1,4,5 last_tile_dim_replicate}
  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(
      FindInstruction(module.get(), "concatenate"),
      op::Sharding("{devices=[1,1,2,2]0,1,4,5 last_tile_dim_replicate}"));
  EXPECT_THAT(
      FindInstruction(module.get(), "copy.p"),
      op::Sharding("{devices=[2,1,1,1,2]0,1,4,5 last_tile_dim_replicate}"));
}

}  // namespace
}  // namespace xla
@@ -2753,28 +2753,13 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
  // If that's the case then we can partition the gather across such dimensions
  // by adjusting the offsets.
  if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
          hlo_sharding_util::GetGatherBatchParallelDims(
              operand.sharding(), indices.sharding(), *hlo)) {
    auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
    auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
          hlo_sharding_util::GetGatherBatchParallelDims(*hlo)) {
    if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
            *operand.hlo(), *indices.hlo(),
            absl::MakeConstSpan(indices_parallel_dims),
            absl::MakeConstSpan(operand_parallel_dims))) {
      absl::InlinedVector<int64, 4> output_parallel_dims;
      Shape gather_shape = gather->shape();
      absl::c_sort(indices_parallel_dims);
      for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
        if (absl::c_linear_search(dnums.offset_dims(), i)) {
          continue;
        }
        int index_dim =
            idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
        if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
          output_parallel_dims.push_back(i);
        }
        ++idx_dim;
      }
            *operand.hlo(), *indices.hlo(), *parallel_dims)) {
      auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
      auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
      auto output_parallel_dims =
          hlo_sharding_util::GatherParallelOutputDims(*hlo, *parallel_dims);
      HloSharding indices_sharding = gather_sharding->indices_sharding;
      HloSharding operand_sharding = gather_sharding->operand_sharding;
      if (operand_sharding.NumTiles() ==
@@ -1830,8 +1830,9 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    absl::Span<const int64> indices_parallel_dims,
    absl::Span<const int64> operand_parallel_dims) {
    const hlo_sharding_util::GatherParallelDims& parallel_dims) {
  auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
  auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
  if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
    return absl::nullopt;
  }
@@ -1842,25 +1843,32 @@ GatherOperandsShardedAcrossParallelDims(
  if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
    return absl::nullopt;
  }
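  // indices_parallel_dims is sorted, while operand_parallel_dims follows
  // start_index_map order; rebuild the indices parallel dims in the
  // operand-aligned order (from index_parallel_in_dim) so the two lists can
  // be matched positionally below.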
  absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
  for (int idx : parallel_dims.index_parallel_in_dim) {
    if (idx != -1) {
      indices_parallel_dims_ordered_as_operand.push_back(idx);
    }
  }
  if (new_index_shard.IsReplicated()) {
    return GatherParallelDimSharding{
        CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
                                     indices_parallel_dims,
                                     indices_parallel_dims_ordered_as_operand,
                                     operand_parallel_dims),
        new_operand_shard};
  }
  if (new_operand_shard.IsReplicated()) {
    return GatherParallelDimSharding{
        new_index_shard, CreateMatchingShardingOnDims(
                             operand.shape(), new_index_shard,
                             operand_parallel_dims, indices_parallel_dims)};
        new_index_shard,
        CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
                                     operand_parallel_dims,
                                     indices_parallel_dims_ordered_as_operand)};
  }

  // Parallel dimension distribution needs to be the same, so try to steal
  // sharding from partial replication to compensate.
  if (idx_parallel_tiles_num != op_parallel_tiles_num) {
    auto to_adjust_dims = operand_parallel_dims;
    auto target_dims = indices_parallel_dims;
    auto target_dims = indices_parallel_dims_ordered_as_operand;
    HloSharding* target = &new_index_shard;
    HloSharding* to_adjust = &new_operand_shard;
    if (idx_parallel_tiles_num < op_parallel_tiles_num) {
@@ -1908,17 +1916,19 @@ GatherOperandsShardedAcrossParallelDims(
  // Make sure that the parallel dimensions are aligned.
  auto operand_shard_tile_dims =
      new_operand_shard.tile_assignment().dimensions();
  for (int i = 0; i < indices_parallel_dims.size(); ++i) {
  for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
    operand_shard_tile_dims[operand_parallel_dims[i]] =
        new_index_shard.tile_assignment().dim(indices_parallel_dims[i]);
        new_index_shard.tile_assignment().dim(
            indices_parallel_dims_ordered_as_operand[i]);
  }
  auto operand_shard_tiles = new_operand_shard.tile_assignment();
  operand_shard_tiles.Reshape(operand_shard_tile_dims);
  new_operand_shard = AlignShardingOnDims(
      new_operand_shard.ReplicateOnLastTileDim()
          ? HloSharding::PartialTile(operand_shard_tiles)
          : HloSharding::Tile(operand_shard_tiles),
      operand_parallel_dims, new_index_shard, indices_parallel_dims);
  new_operand_shard =
      AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
                              ? HloSharding::PartialTile(operand_shard_tiles)
                              : HloSharding::Tile(operand_shard_tiles),
                          operand_parallel_dims, new_index_shard,
                          indices_parallel_dims_ordered_as_operand);
  return GatherParallelDimSharding{new_index_shard, new_operand_shard};
}
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
@@ -414,8 +415,7 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
absl::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    absl::Span<const int64> indices_parallel_dims,
    absl::Span<const int64> operand_parallel_dims);
    const hlo_sharding_util::GatherParallelDims& parallel_dims);

}  // namespace spmd
}  // namespace xla