[XLA] Avoid sharding propagation on trivial dims for dot-general convolution

A dimension of size 1 can be treated as a batch dimension, but sharding it is pointless, so sharding propagation now skips such trivial dimensions.
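For illustration, here is a minimal, hypothetical C++ sketch (a standalone helper, not the XLA code in the diff below) of the check this change introduces: when testing whether a sharding is only along a given set of dimensions, size-1 dimensions are ignored because they carry no data to distribute.

#include <cstdint>
#include <vector>

// Hypothetical simplified inputs: dim_sizes are the operand's dimension sizes,
// tile_dims the per-dimension tile counts of the sharding's tile assignment,
// dims the dimensions of interest (e.g. batch dims), and total_tiles the
// total number of tiles in the assignment.
bool PartitionedOnlyAlongNonTrivialDims(const std::vector<int64_t>& dim_sizes,
                                        const std::vector<int64_t>& tile_dims,
                                        const std::vector<int64_t>& dims,
                                        int64_t total_tiles) {
  int64_t partition_count = 1;
  for (int64_t dim : dims) {
    // A size-1 dimension carries no data to distribute; ignore its tiling.
    if (dim_sizes[dim] == 1) continue;
    partition_count *= tile_dims[dim];
  }
  // True only if the non-trivial dims account for every tile in the assignment.
  return partition_count == total_tiles;
}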

PiperOrigin-RevId: 315840599
Change-Id: I34946a94f6a921468a1a3c1f2fbb4fcbcc506c97
Yuanzhong Xu 2020-06-10 23:00:11 -07:00 committed by TensorFlower Gardener
parent 262ac8ec97
commit 1c064ab760


@ -359,6 +359,140 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction,
}
}
// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
bool aggressive_prop) {
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
const HloInstruction* rhs = instruction->operand(1);
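// The two lambdas below re-express the lhs/rhs sharding in the output's
// dimension order, so it can be proposed as the convolution's output sharding.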
auto get_tiled_sharding_based_on_lhs = [&] {
CHECK(!lhs->sharding().IsTileMaximal());
std::vector<int64> output_to_lhs_indices(instruction->shape().rank());
output_to_lhs_indices[dnums.output_batch_dimension()] =
dnums.input_batch_dimension();
output_to_lhs_indices[dnums.output_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
return hlo_sharding_util::TransposeSharding(lhs->sharding(),
output_to_lhs_indices);
};
auto get_tiled_sharding_based_on_rhs = [&] {
CHECK(!rhs->sharding().IsTileMaximal());
std::vector<int64> output_to_rhs_indices(instruction->shape().rank());
output_to_rhs_indices[dnums.output_batch_dimension()] =
dnums.kernel_input_feature_dimension();
output_to_rhs_indices[dnums.output_feature_dimension()] =
dnums.kernel_output_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
output_to_rhs_indices[dnums.output_spatial_dimensions(i)] =
dnums.kernel_spatial_dimensions(i);
}
return hlo_sharding_util::TransposeSharding(rhs->sharding(),
output_to_rhs_indices);
};
if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution(
instruction)) {
// lhs_or_rhs: lhs is 0 and rhs is 1. Skips dimensions with size 1.
auto partitioned_only_along_non_trivial_dims =
[&](const HloSharding& sharding,
std::vector<dot_as_convolution_util::
DotGeneralAsConvolutionDimsInfo::DimNums>& dims,
int64 lhs_or_rhs) {
if (sharding.IsTileMaximal()) {
return false;
}
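// Accumulate the tile count over the requested dims; trivial (size-1) dims
// are skipped since sharding them is meaningless.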
int64 partition_count = 1;
for (const auto& dim : dims) {
if (lhs_or_rhs == 0) {
if (lhs->shape().dimensions(dim.lhs) == 1) {
continue;
}
partition_count *= sharding.tile_assignment().dim(dim.lhs);
} else {
if (rhs->shape().dimensions(dim.rhs) == 1) {
continue;
}
CHECK_EQ(lhs_or_rhs, 1);
partition_count *= sharding.tile_assignment().dim(dim.rhs);
}
}
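// The sharding is only along these (non-trivial) dims if they account for
// every tile in the assignment.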
return partition_count == sharding.tile_assignment().num_elements();
};
// If LHS/RHS is partitioned only along the batch dimensions, propagate
// the sharding to the output, since batch dimensions are the easiest to
// partition.
if (IsSpatiallyPartitioned(lhs) &&
partitioned_only_along_non_trivial_dims(lhs->sharding(),
dot_dims->batch_dims, 0)) {
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
instruction);
}
if (IsSpatiallyPartitioned(rhs) &&
partitioned_only_along_non_trivial_dims(rhs->sharding(),
dot_dims->batch_dims, 1)) {
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_rhs(),
instruction);
}
if (aggressive_prop) {
// If LHS/RHS is partitioned only along the non-contracting
// dimensions, propagate the sharding to the output.
const bool can_propagate_from_lhs =
IsSpatiallyPartitioned(lhs) &&
partitioned_only_along_non_trivial_dims(
lhs->sharding(), dot_dims->lhs_non_contracting_dims, 0);
const bool can_propagate_from_rhs =
IsSpatiallyPartitioned(rhs) &&
partitioned_only_along_non_trivial_dims(
rhs->sharding(), dot_dims->rhs_non_contracting_dims, 1);
// If we can propagate from both operands, choose the larger one which
// should help us reduce communications.
if (can_propagate_from_lhs && can_propagate_from_rhs) {
if (Product(lhs->shape().dimensions()) >=
Product(rhs->shape().dimensions())) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
} else {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_rhs(), instruction);
}
}
if (can_propagate_from_lhs) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
}
if (can_propagate_from_rhs) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_rhs(), instruction);
}
}
}
if (!IsSpatiallyPartitioned(lhs)) {
return false;
}
if (lhs->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
instruction);
}
if (IsConvolutionKernelSmall(instruction)) {
// If the kernel is small compared to the input then we can generate an
// output that is sharded the same way as the input.
const auto& tile_assignment = lhs->sharding().tile_assignment();
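// The input feature dimension is contracted by the convolution, so a
// sharding along it cannot be forwarded to the output directly.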
if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
return false;
}
return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
instruction);
}
// If the kernel is large (e.g. backward convolution) then we only support
// replicated output.
return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction);
}
// Tries to update the sharding of the specified instruction based on its
// operands and returns true if the sharding of the instruction has been
// changed and false otherwise.
@ -531,132 +665,8 @@ bool InferShardingFromOperands(HloInstruction* instruction,
HloSharding new_sharding = HloSharding::Tile(new_tile_assignment);
return MaybeImproveInstructionSharding(new_sharding, instruction);
}
case HloOpcode::kConvolution: {
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
const HloInstruction* rhs = instruction->operand(1);
auto get_tiled_sharding_based_on_lhs = [&] {
CHECK(!lhs->sharding().IsTileMaximal());
std::vector<int64> output_to_lhs_indices(instruction->shape().rank());
output_to_lhs_indices[dnums.output_batch_dimension()] =
dnums.input_batch_dimension();
output_to_lhs_indices[dnums.output_feature_dimension()] =
dnums.input_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
dnums.input_spatial_dimensions(i);
}
return hlo_sharding_util::TransposeSharding(lhs->sharding(),
output_to_lhs_indices);
};
auto get_tiled_sharding_based_on_rhs = [&] {
CHECK(!rhs->sharding().IsTileMaximal());
std::vector<int64> output_to_rhs_indices(instruction->shape().rank());
output_to_rhs_indices[dnums.output_batch_dimension()] =
dnums.kernel_input_feature_dimension();
output_to_rhs_indices[dnums.output_feature_dimension()] =
dnums.kernel_output_feature_dimension();
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
output_to_rhs_indices[dnums.output_spatial_dimensions(i)] =
dnums.kernel_spatial_dimensions(i);
}
return hlo_sharding_util::TransposeSharding(rhs->sharding(),
output_to_rhs_indices);
};
if (auto dot_dims =
dot_as_convolution_util::ParseDotGeneralFromConvolution(
instruction)) {
// lhs_or_rhs: lhs is 0 and rhs is 1.
auto partitioned_only_along =
[&](const HloSharding& sharding,
std::vector<dot_as_convolution_util::
DotGeneralAsConvolutionDimsInfo::DimNums>& dims,
int64 lhs_or_rhs) {
if (sharding.IsTileMaximal()) {
return false;
}
int64 partition_count = 1;
for (const auto& dim : dims) {
if (lhs_or_rhs == 0) {
partition_count *= sharding.tile_assignment().dim(dim.lhs);
} else {
CHECK_EQ(lhs_or_rhs, 1);
partition_count *= sharding.tile_assignment().dim(dim.rhs);
}
}
return partition_count ==
sharding.tile_assignment().num_elements();
};
// If LHS/RHS is partitioned only along the batch dimensions, propagate
// the sharding to the output, since batch dimensions are the easiest to
// partition.
if (IsSpatiallyPartitioned(lhs) &&
partitioned_only_along(lhs->sharding(), dot_dims->batch_dims, 0)) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
}
if (IsSpatiallyPartitioned(rhs) &&
partitioned_only_along(rhs->sharding(), dot_dims->batch_dims, 1)) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_rhs(), instruction);
}
if (aggressive_prop) {
// If LHS/RHS is partitioned only along the non-contracting
// dimensions, propagate the sharding to the output.
const bool can_propagate_from_lhs =
IsSpatiallyPartitioned(lhs) &&
partitioned_only_along(lhs->sharding(),
dot_dims->lhs_non_contracting_dims, 0);
const bool can_propagate_from_rhs =
IsSpatiallyPartitioned(rhs) &&
partitioned_only_along(rhs->sharding(),
dot_dims->rhs_non_contracting_dims, 1);
// If we can propagate from both operands, choose the larger one which
// should help us reduce communications.
if (can_propagate_from_lhs && can_propagate_from_rhs) {
if (Product(lhs->shape().dimensions()) >=
Product(rhs->shape().dimensions())) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
} else {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_rhs(), instruction);
}
}
if (can_propagate_from_lhs) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
}
if (can_propagate_from_rhs) {
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_rhs(), instruction);
}
}
}
if (!IsSpatiallyPartitioned(lhs)) {
return false;
}
if (lhs->sharding().IsReplicated()) {
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
instruction);
}
if (IsConvolutionKernelSmall(instruction)) {
// If the kernel is small compared to the input then we can generate an
// output that is sharded the same way as the input.
const auto& tile_assignment = lhs->sharding().tile_assignment();
if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
return false;
}
return MaybeImproveInstructionSharding(
get_tiled_sharding_based_on_lhs(), instruction);
}
// If the kernel is large (e.g. backward convolution) then we only support
// replicated output.
return MaybeImproveInstructionSharding(HloSharding::Replicate(),
instruction);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(instruction, aggressive_prop);
case HloOpcode::kTranspose: {
const HloInstruction* input = instruction->operand(0);
if (!IsSpatiallyPartitioned(input)) {
@ -1002,7 +1012,7 @@ absl::optional<HloSharding> GetShardingFromUser(
if (auto dot_dims =
dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) {
const auto& dnums = user.convolution_dimension_numbers();
auto partitioned_only_along =
auto partitioned_only_along_non_trivial_dims =
[&](const HloSharding& sharding,
std::vector<dot_as_convolution_util::
DotGeneralAsConvolutionDimsInfo::DimNums>&
@ -1012,6 +1022,9 @@ absl::optional<HloSharding> GetShardingFromUser(
}
int64 partition_count = 1;
for (const auto& dim : dims) {
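// As on the operand side, skip trivial (size-1) output dimensions.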
if (user.shape().dimensions(dim.output) == 1) {
continue;
}
partition_count *= sharding.tile_assignment().dim(dim.output);
}
return partition_count ==
@ -1021,9 +1034,10 @@ absl::optional<HloSharding> GetShardingFromUser(
// along the non-contracting dimensions, propagate the sharding to the
// operand.
if (&instruction == user.operand(0) &&
(partitioned_only_along(user.sharding(), dot_dims->batch_dims) ||
partitioned_only_along(user.sharding(),
dot_dims->lhs_non_contracting_dims))) {
(partitioned_only_along_non_trivial_dims(user.sharding(),
dot_dims->batch_dims) ||
partitioned_only_along_non_trivial_dims(
user.sharding(), dot_dims->lhs_non_contracting_dims))) {
std::vector<int64> lhs_to_output_indices(user.shape().rank());
lhs_to_output_indices[dnums.input_batch_dimension()] =
dnums.output_batch_dimension();
@ -1037,9 +1051,10 @@ absl::optional<HloSharding> GetShardingFromUser(
lhs_to_output_indices);
}
if (&instruction == user.operand(1) &&
(partitioned_only_along(user.sharding(), dot_dims->batch_dims) ||
partitioned_only_along(user.sharding(),
dot_dims->rhs_non_contracting_dims))) {
(partitioned_only_along_non_trivial_dims(user.sharding(),
dot_dims->batch_dims) ||
partitioned_only_along_non_trivial_dims(
user.sharding(), dot_dims->rhs_non_contracting_dims))) {
std::vector<int64> rhs_to_output_indices(user.shape().rank());
rhs_to_output_indices[dnums.kernel_input_feature_dimension()] =
dnums.output_batch_dimension();