[XLA] Avoid sharding propagation on trivial dims for dot-general convolution
A dimension with size 1 can be treated as a batch dimension, but it doesn't make much sense to shard it.

PiperOrigin-RevId: 315840599
Change-Id: I34946a94f6a921468a1a3c1f2fbb4fcbcc506c97
This commit is contained in:
parent
262ac8ec97
commit
1c064ab760
@ -359,6 +359,140 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction,
|
||||
}
|
||||
}
|
||||
|
||||
// Convolution handling for InferShardingFromOperands().
|
||||
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
|
||||
bool aggressive_prop) {
|
||||
const auto& dnums = instruction->convolution_dimension_numbers();
|
||||
const HloInstruction* lhs = instruction->operand(0);
|
||||
const HloInstruction* rhs = instruction->operand(1);
|
||||
auto get_tiled_sharding_based_on_lhs = [&] {
|
||||
CHECK(!lhs->sharding().IsTileMaximal());
|
||||
std::vector<int64> output_to_lhs_indices(instruction->shape().rank());
|
||||
output_to_lhs_indices[dnums.output_batch_dimension()] =
|
||||
dnums.input_batch_dimension();
|
||||
output_to_lhs_indices[dnums.output_feature_dimension()] =
|
||||
dnums.input_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
|
||||
dnums.input_spatial_dimensions(i);
|
||||
}
|
||||
return hlo_sharding_util::TransposeSharding(lhs->sharding(),
|
||||
output_to_lhs_indices);
|
||||
};
|
||||
auto get_tiled_sharding_based_on_rhs = [&] {
|
||||
CHECK(!rhs->sharding().IsTileMaximal());
|
||||
std::vector<int64> output_to_rhs_indices(instruction->shape().rank());
|
||||
output_to_rhs_indices[dnums.output_batch_dimension()] =
|
||||
dnums.kernel_input_feature_dimension();
|
||||
output_to_rhs_indices[dnums.output_feature_dimension()] =
|
||||
dnums.kernel_output_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
output_to_rhs_indices[dnums.output_spatial_dimensions(i)] =
|
||||
dnums.kernel_spatial_dimensions(i);
|
||||
}
|
||||
return hlo_sharding_util::TransposeSharding(rhs->sharding(),
|
||||
output_to_rhs_indices);
|
||||
};
|
||||
if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution(
|
||||
instruction)) {
|
||||
// lhs_or_rhs: lhs is 0 and rhs is 1. Skips dimensions with size 1.
|
||||
auto partitioned_only_along_non_trivial_dims =
|
||||
[&](const HloSharding& sharding,
|
||||
std::vector<dot_as_convolution_util::
|
||||
DotGeneralAsConvolutionDimsInfo::DimNums>& dims,
|
||||
int64 lhs_or_rhs) {
|
||||
if (sharding.IsTileMaximal()) {
|
||||
return false;
|
||||
}
|
||||
int64 partition_count = 1;
|
||||
for (const auto& dim : dims) {
|
||||
if (lhs_or_rhs == 0) {
|
||||
if (lhs->shape().dimensions(dim.lhs) == 1) {
|
||||
continue;
|
||||
}
|
||||
partition_count *= sharding.tile_assignment().dim(dim.lhs);
|
||||
} else {
|
||||
if (rhs->shape().dimensions(dim.rhs) == 1) {
|
||||
continue;
|
||||
}
|
||||
CHECK_EQ(lhs_or_rhs, 1);
|
||||
partition_count *= sharding.tile_assignment().dim(dim.rhs);
|
||||
}
|
||||
}
|
||||
return partition_count == sharding.tile_assignment().num_elements();
|
||||
};
|
||||
// If LHS/RHS is partitioned only along the batch dimensions, propagate
|
||||
// the sharding to the output, since batch dimensions are the easiest to
|
||||
// partition.
|
||||
if (IsSpatiallyPartitioned(lhs) &&
|
||||
partitioned_only_along_non_trivial_dims(lhs->sharding(),
|
||||
dot_dims->batch_dims, 0)) {
|
||||
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
|
||||
instruction);
|
||||
}
|
||||
if (IsSpatiallyPartitioned(rhs) &&
|
||||
partitioned_only_along_non_trivial_dims(rhs->sharding(),
|
||||
dot_dims->batch_dims, 1)) {
|
||||
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_rhs(),
|
||||
instruction);
|
||||
}
|
||||
if (aggressive_prop) {
|
||||
// If LHS/RHS is partitioned only along the non-contracting
|
||||
// dimensions, propagate the sharding to the output.
|
||||
const bool can_propagate_from_lhs =
|
||||
IsSpatiallyPartitioned(lhs) &&
|
||||
partitioned_only_along_non_trivial_dims(
|
||||
lhs->sharding(), dot_dims->lhs_non_contracting_dims, 0);
|
||||
const bool can_propagate_from_rhs =
|
||||
IsSpatiallyPartitioned(rhs) &&
|
||||
partitioned_only_along_non_trivial_dims(
|
||||
rhs->sharding(), dot_dims->rhs_non_contracting_dims, 1);
|
||||
// If we can propagate from both operands, choose the larger one which
|
||||
// should help us reduce communications.
|
||||
if (can_propagate_from_lhs && can_propagate_from_rhs) {
|
||||
if (Product(lhs->shape().dimensions()) >=
|
||||
Product(rhs->shape().dimensions())) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
} else {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_rhs(), instruction);
|
||||
}
|
||||
}
|
||||
if (can_propagate_from_lhs) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
}
|
||||
if (can_propagate_from_rhs) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_rhs(), instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsSpatiallyPartitioned(lhs)) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->sharding().IsReplicated()) {
|
||||
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
|
||||
instruction);
|
||||
}
|
||||
|
||||
if (IsConvolutionKernelSmall(instruction)) {
|
||||
// If the kernel is small compared to the input then we can generate an
|
||||
// output what is sharded the same way as the input.
|
||||
const auto& tile_assignment = lhs->sharding().tile_assignment();
|
||||
if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
|
||||
return false;
|
||||
}
|
||||
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
|
||||
instruction);
|
||||
}
|
||||
// If the kernel is large (e.g backward convolution) then we only support
|
||||
// replicated output.
|
||||
return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction);
|
||||
}
|
||||
|
||||
// Tries to update the sharding of the specified instruction based on its
|
||||
// operands and returns true if the sharding of the instruction have been
|
||||
// changed and false otherwise.
|
||||
@ -531,132 +665,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
|
||||
HloSharding new_sharding = HloSharding::Tile(new_tile_assignment);
|
||||
return MaybeImproveInstructionSharding(new_sharding, instruction);
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
const auto& dnums = instruction->convolution_dimension_numbers();
|
||||
const HloInstruction* lhs = instruction->operand(0);
|
||||
const HloInstruction* rhs = instruction->operand(1);
|
||||
auto get_tiled_sharding_based_on_lhs = [&] {
|
||||
CHECK(!lhs->sharding().IsTileMaximal());
|
||||
std::vector<int64> output_to_lhs_indices(instruction->shape().rank());
|
||||
output_to_lhs_indices[dnums.output_batch_dimension()] =
|
||||
dnums.input_batch_dimension();
|
||||
output_to_lhs_indices[dnums.output_feature_dimension()] =
|
||||
dnums.input_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
|
||||
dnums.input_spatial_dimensions(i);
|
||||
}
|
||||
return hlo_sharding_util::TransposeSharding(lhs->sharding(),
|
||||
output_to_lhs_indices);
|
||||
};
|
||||
auto get_tiled_sharding_based_on_rhs = [&] {
|
||||
CHECK(!rhs->sharding().IsTileMaximal());
|
||||
std::vector<int64> output_to_rhs_indices(instruction->shape().rank());
|
||||
output_to_rhs_indices[dnums.output_batch_dimension()] =
|
||||
dnums.kernel_input_feature_dimension();
|
||||
output_to_rhs_indices[dnums.output_feature_dimension()] =
|
||||
dnums.kernel_output_feature_dimension();
|
||||
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
|
||||
output_to_rhs_indices[dnums.output_spatial_dimensions(i)] =
|
||||
dnums.kernel_spatial_dimensions(i);
|
||||
}
|
||||
return hlo_sharding_util::TransposeSharding(rhs->sharding(),
|
||||
output_to_rhs_indices);
|
||||
};
|
||||
if (auto dot_dims =
|
||||
dot_as_convolution_util::ParseDotGeneralFromConvolution(
|
||||
instruction)) {
|
||||
// lhs_or_rhs: lhs is 0 and rhs is 1.
|
||||
auto partitioned_only_along =
|
||||
[&](const HloSharding& sharding,
|
||||
std::vector<dot_as_convolution_util::
|
||||
DotGeneralAsConvolutionDimsInfo::DimNums>& dims,
|
||||
int64 lhs_or_rhs) {
|
||||
if (sharding.IsTileMaximal()) {
|
||||
return false;
|
||||
}
|
||||
int64 partition_count = 1;
|
||||
for (const auto& dim : dims) {
|
||||
if (lhs_or_rhs == 0) {
|
||||
partition_count *= sharding.tile_assignment().dim(dim.lhs);
|
||||
} else {
|
||||
CHECK_EQ(lhs_or_rhs, 1);
|
||||
partition_count *= sharding.tile_assignment().dim(dim.rhs);
|
||||
}
|
||||
}
|
||||
return partition_count ==
|
||||
sharding.tile_assignment().num_elements();
|
||||
};
|
||||
// If LHS/RHS is partitioned only along the batch dimensions, propagate
|
||||
// the sharding to the output, since batch dimensions are the easiest to
|
||||
// partition.
|
||||
if (IsSpatiallyPartitioned(lhs) &&
|
||||
partitioned_only_along(lhs->sharding(), dot_dims->batch_dims, 0)) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
}
|
||||
if (IsSpatiallyPartitioned(rhs) &&
|
||||
partitioned_only_along(rhs->sharding(), dot_dims->batch_dims, 1)) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_rhs(), instruction);
|
||||
}
|
||||
if (aggressive_prop) {
|
||||
// If LHS/RHS is partitioned only along the non-contracting
|
||||
// dimensions, propagate the sharding to the output.
|
||||
const bool can_propagate_from_lhs =
|
||||
IsSpatiallyPartitioned(lhs) &&
|
||||
partitioned_only_along(lhs->sharding(),
|
||||
dot_dims->lhs_non_contracting_dims, 0);
|
||||
const bool can_propagate_from_rhs =
|
||||
IsSpatiallyPartitioned(rhs) &&
|
||||
partitioned_only_along(rhs->sharding(),
|
||||
dot_dims->rhs_non_contracting_dims, 1);
|
||||
// If we can propagate from both operands, choose the larger one which
|
||||
// should help us reduce communications.
|
||||
if (can_propagate_from_lhs && can_propagate_from_rhs) {
|
||||
if (Product(lhs->shape().dimensions()) >=
|
||||
Product(rhs->shape().dimensions())) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
} else {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_rhs(), instruction);
|
||||
}
|
||||
}
|
||||
if (can_propagate_from_lhs) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
}
|
||||
if (can_propagate_from_rhs) {
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_rhs(), instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!IsSpatiallyPartitioned(lhs)) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->sharding().IsReplicated()) {
|
||||
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
|
||||
instruction);
|
||||
}
|
||||
|
||||
if (IsConvolutionKernelSmall(instruction)) {
|
||||
// If the kernel is small compared to the input then we can generate an
|
||||
// output what is sharded the same way as the input.
|
||||
const auto& tile_assignment = lhs->sharding().tile_assignment();
|
||||
if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
|
||||
return false;
|
||||
}
|
||||
return MaybeImproveInstructionSharding(
|
||||
get_tiled_sharding_based_on_lhs(), instruction);
|
||||
}
|
||||
// If the kernel is large (e.g backward convolution) then we only support
|
||||
// replicated output.
|
||||
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
|
||||
instruction);
|
||||
}
|
||||
case HloOpcode::kConvolution:
|
||||
return InferConvolutionShardingFromOperands(instruction, aggressive_prop);
|
||||
case HloOpcode::kTranspose: {
|
||||
const HloInstruction* input = instruction->operand(0);
|
||||
if (!IsSpatiallyPartitioned(input)) {
|
||||
@ -1002,7 +1012,7 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
if (auto dot_dims =
|
||||
dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) {
|
||||
const auto& dnums = user.convolution_dimension_numbers();
|
||||
auto partitioned_only_along =
|
||||
auto partitioned_only_along_non_trivial_dims =
|
||||
[&](const HloSharding& sharding,
|
||||
std::vector<dot_as_convolution_util::
|
||||
DotGeneralAsConvolutionDimsInfo::DimNums>&
|
||||
@ -1012,6 +1022,9 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
}
|
||||
int64 partition_count = 1;
|
||||
for (const auto& dim : dims) {
|
||||
if (user.shape().dimensions(dim.output) == 1) {
|
||||
continue;
|
||||
}
|
||||
partition_count *= sharding.tile_assignment().dim(dim.output);
|
||||
}
|
||||
return partition_count ==
|
||||
@ -1021,9 +1034,10 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
// along the non-contracting dimensions, propagate the sharding to the
|
||||
// operand.
|
||||
if (&instruction == user.operand(0) &&
|
||||
(partitioned_only_along(user.sharding(), dot_dims->batch_dims) ||
|
||||
partitioned_only_along(user.sharding(),
|
||||
dot_dims->lhs_non_contracting_dims))) {
|
||||
(partitioned_only_along_non_trivial_dims(user.sharding(),
|
||||
dot_dims->batch_dims) ||
|
||||
partitioned_only_along_non_trivial_dims(
|
||||
user.sharding(), dot_dims->lhs_non_contracting_dims))) {
|
||||
std::vector<int64> lhs_to_output_indices(user.shape().rank());
|
||||
lhs_to_output_indices[dnums.input_batch_dimension()] =
|
||||
dnums.output_batch_dimension();
|
||||
@ -1037,9 +1051,10 @@ absl::optional<HloSharding> GetShardingFromUser(
|
||||
lhs_to_output_indices);
|
||||
}
|
||||
if (&instruction == user.operand(1) &&
|
||||
(partitioned_only_along(user.sharding(), dot_dims->batch_dims) ||
|
||||
partitioned_only_along(user.sharding(),
|
||||
dot_dims->rhs_non_contracting_dims))) {
|
||||
(partitioned_only_along_non_trivial_dims(user.sharding(),
|
||||
dot_dims->batch_dims) ||
|
||||
partitioned_only_along_non_trivial_dims(
|
||||
user.sharding(), dot_dims->rhs_non_contracting_dims))) {
|
||||
std::vector<int64> rhs_to_output_indices(user.shape().rank());
|
||||
rhs_to_output_indices[dnums.kernel_input_feature_dimension()] =
|
||||
dnums.output_batch_dimension();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user