[XLA] Add sharding propagation support for Gather sharding on parallel dimensions.

Should enable sharding of tf.reverse_sequence() and gradient.

PiperOrigin-RevId: 352154964
Change-Id: Ifa953ae5068ebda61ef0d1d1b0438362afcb7e67
parent: 1d0e1df48a
commit: 85bf96f508
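
For context, a minimal sketch of the gather pattern this change recognizes, adapted from the new sharding_propagation tests in this diff: the gather indices are a concatenation of iotas along the index vector dimension, and an iota whose iota_dimension is a batch dimension marks that dimension as parallel between the operand, the indices, and the output.

  %iota = s32[1,8,4] iota(), iota_dimension=1
  %iota2 = s32[1,8,4] iota(), iota_dimension=2
  %indices = s32[2,8,4] concatenate(%iota, %iota2), dimensions={0}
  %gather = s32[8,4,2,2] gather(s32[8,4,2,2] %operand, %indices),
      offset_dims={2,3}, collapsed_slice_dims={0,1}, start_index_map={0,1},
      index_vector_dim=0, slice_sizes={1,1,2,2}

Because %iota enumerates operand rows one-to-one, a tiling of operand dimension 0 (e.g. {devices=[8,1,1,1]...}) can be propagated to output dimension 0, and vice versa, without resharding the batch dimension.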
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -489,6 +489,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/types:optional",
     ],
 )
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc
@@ -15,11 +15,13 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
 
+#include <algorithm>
 #include <map>
 #include <vector>
 
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -700,6 +702,54 @@ absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
              : HloSharding::Tile(tile_assignment);
 }
 
+// Collect data operand sharding for a gather with parallel dimensions from
+// the output.
+absl::optional<HloSharding> GatherParallelDataOperandSharding(
+    const HloSharding& output_sharding, const HloInstruction& gather) {
+  if (output_sharding.IsTileMaximal()) {
+    return output_sharding;
+  }
+  auto parallel_dims = GetGatherBatchParallelDims(gather);
+  if (!parallel_dims) {
+    return absl::nullopt;
+  }
+  auto output_parallel_dims = GatherParallelOutputDims(gather, *parallel_dims);
+  auto output_aligned_operand_parallel_dims =
+      GatherOutputAlignedOperandParallelDims(gather, *parallel_dims);
+  const Shape gather_shape = gather.shape();
+  CHECK_EQ(output_parallel_dims.size(),
+           output_aligned_operand_parallel_dims.size());
+  std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
+                                             1);
+  for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
+    if (parallel_idx >= output_parallel_dims.size() ||
+        output_parallel_dims[parallel_idx] != i) {
+      // Support only the case where the output dimensions are sharded only
+      // across parallel dimensions.
+      if (output_sharding.tile_assignment().dim(i) != 1) {
+        return absl::nullopt;
+      }
+      continue;
+    }
+    const int64 operand_dim =
+        output_aligned_operand_parallel_dims[parallel_idx++];
+    operand_tile_assignment[operand_dim] =
+        output_sharding.tile_assignment().dim(i);
+  }
+  if (output_sharding.ReplicateOnLastTileDim()) {
+    operand_tile_assignment.push_back(
+        output_sharding.tile_assignment().dimensions().back());
+  }
+  Array<int64> tile_assignment = output_sharding.tile_assignment();
+  if (tile_assignment.num_elements() != Product(operand_tile_assignment)) {
+    return absl::nullopt;
+  }
+  tile_assignment.Reshape(operand_tile_assignment);
+  return output_sharding.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(tile_assignment)
+             : HloSharding::Tile(tile_assignment);
+}
+
 }  // namespace
 
 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
@@ -726,6 +776,12 @@ absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
                                       dnums.start_index_map().end());
   std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                  dnums.offset_dims().end());
+  // Prioritize parallel sharding first as this is how it is in
+  // spmd_partitioner.
+  if (auto parallel_sharding =
+          GatherParallelDataOperandSharding(hlo.sharding(), hlo)) {
+    return parallel_sharding;
+  }
   return PassthroughGatherOutputOrScatterUpdateToOperand(
       hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
       start_index_map, offset_dims, hlo.gather_slice_sizes());
@@ -981,7 +1037,6 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
 }
 
 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
-    const HloSharding& operand_sharding, const HloSharding& indices_sharding,
     const HloInstruction& hlo) {
   const auto& dnums = hlo.gather_dimension_numbers();
   int64 index_dim = dnums.index_vector_dim();
@@ -997,20 +1052,37 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
   // is common for tf.reverse_sequence and would match this case.
   absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
   const HloInstruction* indices = hlo.operand(1);
+  const int num_indices = dnums.start_index_map_size();
+  std::vector<int64> index_parallel_in_dim(num_indices, -1);
   // Handle cases where we concatenate pieces of the indices one at a time.
   if (indices->opcode() == HloOpcode::kConcatenate &&
       indices->concatenate_dimension() == index_dim) {
-    for (auto* op : indices->operands()) {
+    int concatenated_dims = 0;
+    for (int i = 0; i < indices->operand_count(); ++i) {
+      const HloInstruction* op = indices->operand(i);
+      const int64 num_indices_from_element =
+          op->shape().dimensions_size() > index_dim
+              ? op->shape().dimensions(index_dim)
+              : 1;
       if (auto* iota = DynCast<HloIotaInstruction>(op)) {
         if (iota->iota_dimension() != index_dim) {
-          iotas.push_back(iota);
+          for (int j = 0; j < num_indices_from_element; ++j) {
+            index_parallel_in_dim[concatenated_dims + j] =
+                iota->iota_dimension();
+          }
         }
       }
+      concatenated_dims += num_indices_from_element;
     }
   } else if (auto* iota = DynCast<HloIotaInstruction>(indices)) {
     if (iota->iota_dimension() != index_dim) {
       // This is a case of a single iota with index_dim being out of bounds.
-      iotas.push_back(iota);
+      const int64 num_indices_from_element =
+          iota->shape().dimensions_size() > index_dim
+              ? iota->shape().dimensions(index_dim)
+              : 1;
+      index_parallel_in_dim.assign(num_indices_from_element,
+                                   iota->iota_dimension());
     }
   }
   absl::InlinedVector<int64, 1> indices_parallel_dims;
@@ -1020,29 +1092,75 @@ absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
   // operands and index they could have different spots in the shape because the
   // position of the index dimension in the operand is determined by
   // start_index_map.
-  int index_num = 0;
-  for (auto* iota : iotas) {
-    int64 num_indices_from_iota = iota->shape().dimensions_size() > index_dim
-                                      ? iota->shape().dimensions(index_dim)
-                                      : 1;
-    for (int i = 0; i < num_indices_from_iota; ++i) {
-      int64 index_dim = iota->iota_dimension();
-      int64 operand_dim = dnums.start_index_map(index_num + i);
-      // Uniquify multiple iotas concatenated on index_dim with the same
-      // iota_dimension
-      if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
-        return absl::nullopt;
-      }
-      indices_parallel_dims.push_back(index_dim);
-      operand_parallel_dims.push_back(operand_dim);
-    }
-    index_num += num_indices_from_iota;
+  for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
+    int index_parallel_dim = index_parallel_in_dim[i];
+    if (index_parallel_dim == -1) {
+      continue;
+    }
+    if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
+      return absl::nullopt;
+    }
+    indices_parallel_dims.push_back(index_parallel_dim);
+    operand_parallel_dims.push_back(dnums.start_index_map(i));
   }
+  absl::c_sort(indices_parallel_dims);
   if (!indices_parallel_dims.empty()) {
-    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims};
+    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
+                              index_parallel_in_dim};
   }
   return absl::nullopt;
 }
 
+absl::InlinedVector<int64, 1> GatherParallelOutputDims(
+    const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
+  absl::InlinedVector<int64, 1> output_parallel_dims;
+  auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
+  const Shape gather_shape = gather.shape();
+  auto dnums = gather.gather_dimension_numbers();
+  for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
+    if (absl::c_linear_search(dnums.offset_dims(), i)) {
+      continue;
+    }
+    const int index_dim =
+        idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
+    if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
+      output_parallel_dims.push_back(i);
+    }
+    ++idx_dim;
+  }
+  return output_parallel_dims;
+}
+
+absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
+    const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
+  absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
+      parallel_dims.operand_parallel_dims.size(), -1);
+  auto dnums = gather.gather_dimension_numbers();
+  CHECK_LE(parallel_dims.indices_parallel_dims.size(),
+           parallel_dims.operand_parallel_dims.size());
+  for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
+    // This is the equivalent batch dimension of the indices that corresponds
+    // to this index dimension.
+    const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
+    // If it's not an index that is parallel skip.
+    if (index_parallel_dim == -1) {
+      continue;
+    }
+    // This is small so just look linearly. Populate the operand parallel
+    // dimensions based on the order of the index batch dims (which is the same
+    // order as the output).
+    for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
+      if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
+        const int64 operand_parallel_dim = dnums.start_index_map(i);
+        if (operand_parallel_dim_to_output[j] == -1) {
+          operand_parallel_dim_to_output[j] = operand_parallel_dim;
+        }
+        break;
+      }
+    }
+  }
+  return operand_parallel_dim_to_output;
+}
+
 }  // namespace hlo_sharding_util
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/hlo_sharding_util.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <map>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -32,6 +33,7 @@ namespace hlo_sharding_util {
 struct GatherParallelDims {
   absl::InlinedVector<int64, 1> indices_parallel_dims;
   absl::InlinedVector<int64, 1> operand_parallel_dims;
+  std::vector<int64> index_parallel_in_dim;
 };
 
 // Given a map<device, occurrence_count>, selects the device with higher
@@ -188,9 +190,19 @@ absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
 
 // Returns identified parallel dimensions for Gather.
 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
-    const HloSharding& operand_sharding, const HloSharding& indices_sharding,
     const HloInstruction& hlo);
 
+// Returns the parallel dimensions of the output of a gather based on the
+// parallel dimensions of the input.
+absl::InlinedVector<int64, 1> GatherParallelOutputDims(
+    const HloInstruction& gather, const GatherParallelDims& parallel_dim);
+
+// Returns the parallel dimensions of the data operand of a gather with the
+// order of the parallel dimensions matching that of the parallel dimensions
+// of the output.
+absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
+    const HloInstruction& gather, const GatherParallelDims& parallel_dims);
+
 }  // namespace hlo_sharding_util
 }  // namespace xla
 
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -555,6 +555,75 @@ bool InferDotShardingFromOperands(
   return changed;
 }
 
+bool InferGatherParallelShardingFromOperands(
+    HloInstruction* instruction,
+    const hlo_sharding_util::GatherParallelDims& parallel_dims,
+    bool may_combine_partial_sharding) {
+  auto from_operand = [instruction](
+                          int64 operand_index,
+                          absl::Span<const int64> output_aligned_parallel_dims,
+                          absl::Span<const int64> output_parallel_dims) {
+    const HloInstruction* operand = instruction->operand(operand_index);
+    const HloSharding& operand_sharding = operand->sharding();
+    if (operand_sharding.IsTileMaximal()) {
+      return operand_sharding;
+    }
+    auto dnums = instruction->gather_dimension_numbers();
+    std::vector<int64> output_tile_dims(instruction->shape().rank(), 1);
+    std::vector<int64> index_non_parallel_dims;
+    index_non_parallel_dims.reserve(operand->shape().rank());
+    // Detect non parallel dimensions in the index.
+    for (int i = 0; i < operand->shape().rank(); ++i) {
+      if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
+        index_non_parallel_dims.push_back(i);
+      }
+    }
+    // Collect tile dimensions in the operand. The order of the parallel
+    // dimensions in output_aligned_parallel_dims is the same as that of the
+    // output
+    for (int i = 0; i < output_aligned_parallel_dims.size(); ++i) {
+      const int64 indices_idx = output_aligned_parallel_dims[i];
+      const int64 output_idx = output_parallel_dims[i];
+      output_tile_dims[output_idx] =
+          operand_sharding.tile_assignment().dim(indices_idx);
+    }
+    HloSharding replicate_non_parallel_dims =
+        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
+            operand_sharding, index_non_parallel_dims);
+    if (replicate_non_parallel_dims.ReplicateOnLastTileDim()) {
+      output_tile_dims.push_back(
+          replicate_non_parallel_dims.tile_assignment().dimensions().back());
+    }
+    auto output_tile_assignment = replicate_non_parallel_dims.tile_assignment();
+    output_tile_assignment.Reshape(output_tile_dims);
+    return replicate_non_parallel_dims.ReplicateOnLastTileDim()
+               ? HloSharding::PartialTile(output_tile_assignment)
+               : HloSharding::Tile(output_tile_assignment);
+  };
+
+  bool changed = false;
+  auto output_parallel_dims =
+      hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
+  if (IsSpatiallyPartitioned(instruction->operand(0))) {
+    changed |= MaybeImproveInstructionSharding(
+        from_operand(
+            0,
+            absl::MakeConstSpan(
+                hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
+                    *instruction, parallel_dims)),
+            absl::MakeConstSpan(output_parallel_dims)),
+        instruction, may_combine_partial_sharding);
+  }
+  if (IsSpatiallyPartitioned(instruction->operand(1))) {
+    changed |= MaybeImproveInstructionSharding(
+        from_operand(1,
+                     absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
+                     absl::MakeConstSpan(output_parallel_dims)),
+        instruction, may_combine_partial_sharding);
+  }
+  return changed;
+}
+
 // Convolution handling for InferShardingFromOperands().
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
                                           int64 aggressiveness,
@@ -1030,6 +1099,14 @@ bool InferShardingFromOperands(HloInstruction* instruction,
     }
     case HloOpcode::kGather: {
      bool changed = false;
+      if (is_spmd) {
+        auto gather_parallel_dims =
+            hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
+        if (gather_parallel_dims) {
+          changed |= InferGatherParallelShardingFromOperands(
+              instruction, *gather_parallel_dims, may_combine_partial_sharding);
+        }
+      }
      if (IsSpatiallyPartitioned(instruction->operand(1))) {
        HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
            instruction->operand(1)->sharding(), instruction);
--- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -2579,5 +2579,243 @@ ENTRY %transpose {
           "{devices=[2,1,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
 }
 
+TEST_F(ShardingPropagationTest, ParallelGatherFromOperandForwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
+    sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "gather"),
+              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
+}
+
+TEST_F(ShardingPropagationTest, ParallelGatherFromIndexForwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
+    sharding={devices=[1,8,1]0,1,4,5,2,3,6,7}
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "gather"),
+              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
+}
+
+TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
+  %copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %copy.p,
+    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}, sharding={devices=[8,1,1,1]0,1,4,5,2,3,6,7}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
+              op::Sharding("{devices=[1,8,1]0,1,4,5,2,3,6,7}"));
+  EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
+              op::Sharding("{devices=[8,1,1,1]0,1,4,5,2,3,6,7}"));
+}
+
+TEST_F(ShardingPropagationTest, ParallelGatherBackwardPass2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
+  %copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[4,8,2,2]{3,2,1,0} %copy.p,
+    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}, sharding={devices=[1,4,1,1]0,1,4,5}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(FindInstruction(module.get(), "concatenate"),
+              op::Sharding("{devices=[1,1,4]0,1,4,5}"));
+  EXPECT_THAT(FindInstruction(module.get(), "copy.p"),
+              op::Sharding("{devices=[4,1,1,1]0,1,4,5}"));
+}
+
+TEST_F(ShardingPropagationTest,
+       PartialShardingParallelGatherFromOperandForwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
+    sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      FindInstruction(module.get(), "gather"),
+      op::Sharding(
+          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
+}
+
+TEST_F(ShardingPropagationTest,
+       PartialShardingParallelGatherFromIndexForwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
+    sharding={devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %parameter.0,
+    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      FindInstruction(module.get(), "gather"),
+      op::Sharding(
+          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
+}
+
+TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0)
+  %copy.p = s32[8,4,2,2]{3,2,1,0} copy(%parameter.0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[8,4,2,2]{3,2,1,0} %copy.p,
+    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
+    slice_sizes={1,1,2,2},
+    sharding={devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      FindInstruction(module.get(), "concatenate"),
+      op::Sharding(
+          "{devices=[1,4,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
+  EXPECT_THAT(
+      FindInstruction(module.get(), "copy.p"),
+      op::Sharding(
+          "{devices=[4,1,1,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}"));
+}
+
+TEST_F(ShardingPropagationTest, PartialShardingParallelGatherBackwardPass2) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %module {
+  %parameter.0 = s32[4,8,2,2]{3,2,1,0} parameter(0)
+  %copy.p = s32[4,8,2,2]{3,2,1,0} copy(%parameter.0)
+  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1
+  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2
+  %concatenate = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
+    s32[1,8,4]{2,1,0} %iota2), dimensions={0}
+  %gather = s32[8,4,2,2]{3,2,1,0} gather(
+    s32[4,8,2,2]{3,2,1,0} %copy.p,
+    s32[2,8,4]{2,1,0} %concatenate), offset_dims={2,3},
+    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
+    slice_sizes={1,1,2,2},
+    sharding={devices=[1,2,1,1,2]0,1,4,5 last_tile_dim_replicate}
+  ROOT %copy = s32[8,4,2,2]{3,2,1,0} copy(%gather)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_THAT(
+      FindInstruction(module.get(), "concatenate"),
+      op::Sharding("{devices=[1,1,2,2]0,1,4,5 last_tile_dim_replicate}"));
+  EXPECT_THAT(
+      FindInstruction(module.get(), "copy.p"),
+      op::Sharding("{devices=[2,1,1,1,2]0,1,4,5 last_tile_dim_replicate}"));
+}
+
 }  // namespace
 }  // namespace xla
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -2753,28 +2753,13 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
   // If that's the case then we can partition the gather across such dimensions
   // by adjusting the offsets.
   if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =
-          hlo_sharding_util::GetGatherBatchParallelDims(
-              operand.sharding(), indices.sharding(), *hlo)) {
-    auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
-    auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
+          hlo_sharding_util::GetGatherBatchParallelDims(*hlo)) {
     if (auto gather_sharding = GatherOperandsShardedAcrossParallelDims(
-            *operand.hlo(), *indices.hlo(),
-            absl::MakeConstSpan(indices_parallel_dims),
-            absl::MakeConstSpan(operand_parallel_dims))) {
-      absl::InlinedVector<int64, 4> output_parallel_dims;
-      Shape gather_shape = gather->shape();
-      absl::c_sort(indices_parallel_dims);
-      for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
-        if (absl::c_linear_search(dnums.offset_dims(), i)) {
-          continue;
-        }
-        int index_dim =
-            idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
-        if (absl::c_linear_search(indices_parallel_dims, index_dim)) {
-          output_parallel_dims.push_back(i);
-        }
-        ++idx_dim;
-      }
+            *operand.hlo(), *indices.hlo(), *parallel_dims)) {
+      auto indices_parallel_dims = parallel_dims->indices_parallel_dims;
+      auto operand_parallel_dims = parallel_dims->operand_parallel_dims;
+      auto output_parallel_dims =
+          hlo_sharding_util::GatherParallelOutputDims(*hlo, *parallel_dims);
       HloSharding indices_sharding = gather_sharding->indices_sharding;
       HloSharding operand_sharding = gather_sharding->operand_sharding;
       if (operand_sharding.NumTiles() ==
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
@@ -1830,8 +1830,9 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
 absl::optional<GatherParallelDimSharding>
 GatherOperandsShardedAcrossParallelDims(
     const HloInstruction& operand, const HloInstruction& indices,
-    absl::Span<const int64> indices_parallel_dims,
-    absl::Span<const int64> operand_parallel_dims) {
+    const hlo_sharding_util::GatherParallelDims& parallel_dims) {
+  auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
+  auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
   if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
     return absl::nullopt;
   }
|
|||||||
if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
|
if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
|
||||||
|
for (int idx : parallel_dims.index_parallel_in_dim) {
|
||||||
|
if (idx != -1) {
|
||||||
|
indices_parallel_dims_ordered_as_operand.push_back(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
if (new_index_shard.IsReplicated()) {
|
if (new_index_shard.IsReplicated()) {
|
||||||
return GatherParallelDimSharding{
|
return GatherParallelDimSharding{
|
||||||
CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
|
CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
|
||||||
indices_parallel_dims,
|
indices_parallel_dims_ordered_as_operand,
|
||||||
operand_parallel_dims),
|
operand_parallel_dims),
|
||||||
new_operand_shard};
|
new_operand_shard};
|
||||||
}
|
}
|
||||||
if (new_operand_shard.IsReplicated()) {
|
if (new_operand_shard.IsReplicated()) {
|
||||||
return GatherParallelDimSharding{
|
return GatherParallelDimSharding{
|
||||||
new_index_shard, CreateMatchingShardingOnDims(
|
new_index_shard,
|
||||||
operand.shape(), new_index_shard,
|
CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
|
||||||
operand_parallel_dims, indices_parallel_dims)};
|
operand_parallel_dims,
|
||||||
|
indices_parallel_dims_ordered_as_operand)};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parallel dimension distribution needs to be the same, so try to steal
|
// Parallel dimension distribution needs to be the same, so try to steal
|
||||||
// sharding from partial replication to compensate.
|
// sharding from partial replication to compensate.
|
||||||
if (idx_parallel_tiles_num != op_parallel_tiles_num) {
|
if (idx_parallel_tiles_num != op_parallel_tiles_num) {
|
||||||
auto to_adjust_dims = operand_parallel_dims;
|
auto to_adjust_dims = operand_parallel_dims;
|
||||||
auto target_dims = indices_parallel_dims;
|
auto target_dims = indices_parallel_dims_ordered_as_operand;
|
||||||
HloSharding* target = &new_index_shard;
|
HloSharding* target = &new_index_shard;
|
||||||
HloSharding* to_adjust = &new_operand_shard;
|
HloSharding* to_adjust = &new_operand_shard;
|
||||||
if (idx_parallel_tiles_num < op_parallel_tiles_num) {
|
if (idx_parallel_tiles_num < op_parallel_tiles_num) {
|
||||||
@ -1908,17 +1916,19 @@ GatherOperandsShardedAcrossParallelDims(
|
|||||||
// Make sure that the parallel dimensions are aligned.
|
// Make sure that the parallel dimensions are aligned.
|
||||||
auto operand_shard_tile_dims =
|
auto operand_shard_tile_dims =
|
||||||
new_operand_shard.tile_assignment().dimensions();
|
new_operand_shard.tile_assignment().dimensions();
|
||||||
for (int i = 0; i < indices_parallel_dims.size(); ++i) {
|
for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
|
||||||
operand_shard_tile_dims[operand_parallel_dims[i]] =
|
operand_shard_tile_dims[operand_parallel_dims[i]] =
|
||||||
new_index_shard.tile_assignment().dim(indices_parallel_dims[i]);
|
new_index_shard.tile_assignment().dim(
|
||||||
|
indices_parallel_dims_ordered_as_operand[i]);
|
||||||
}
|
}
|
||||||
auto operand_shard_tiles = new_operand_shard.tile_assignment();
|
auto operand_shard_tiles = new_operand_shard.tile_assignment();
|
||||||
operand_shard_tiles.Reshape(operand_shard_tile_dims);
|
operand_shard_tiles.Reshape(operand_shard_tile_dims);
|
||||||
new_operand_shard = AlignShardingOnDims(
|
new_operand_shard =
|
||||||
new_operand_shard.ReplicateOnLastTileDim()
|
AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
|
||||||
? HloSharding::PartialTile(operand_shard_tiles)
|
? HloSharding::PartialTile(operand_shard_tiles)
|
||||||
: HloSharding::Tile(operand_shard_tiles),
|
: HloSharding::Tile(operand_shard_tiles),
|
||||||
operand_parallel_dims, new_index_shard, indices_parallel_dims);
|
operand_parallel_dims, new_index_shard,
|
||||||
|
indices_parallel_dims_ordered_as_operand);
|
||||||
return GatherParallelDimSharding{new_index_shard, new_operand_shard};
|
return GatherParallelDimSharding{new_index_shard, new_operand_shard};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
 
 namespace xla {
@@ -414,8 +415,7 @@ HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
 absl::optional<GatherParallelDimSharding>
 GatherOperandsShardedAcrossParallelDims(
     const HloInstruction& operand, const HloInstruction& indices,
-    absl::Span<const int64> indices_parallel_dims,
-    absl::Span<const int64> operand_parallel_dims);
+    const hlo_sharding_util::GatherParallelDims& parallel_dims);
 
 }  // namespace spmd
 }  // namespace xla