[XLA:SPMD] Support convolution with non-contracting spatial dim partitioned at batch dim.

PiperOrigin-RevId: 321579087
Change-Id: I42f86b9c281e02f8157287653aa30a54c14a0e72
A. Unique TensorFlower 2020-07-16 09:12:50 -07:00 committed by TensorFlower Gardener
parent 279ef8e3e6
commit 960358aaa2
5 changed files with 109 additions and 16 deletions

@@ -24,6 +24,31 @@ limitations under the License.
namespace xla {
namespace dot_as_convolution_util {
bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
// A parallel batch dimension in DotGeneral is represented as a
// spatial dimension with window size B (batch dimension size),
// stride B - 1, and base dilation B.
if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
wd.window_dilation() == 1) ||
(std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
wd.stride() == 1)) &&
wd.padding_high() == 0 && wd.padding_low() == 0 &&
!wd.window_reversal()) {
return true;
}
// Alternative representation of a batch dimension.
if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
wd.window_dilation() == 1 && wd.stride() == lhs_size &&
wd.base_dilation() == lhs_size - 1) {
return true;
}
return false;
}
/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
ParseDotGeneralFromConvolution(const HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -49,22 +74,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
int64 output = conv_dims.output_spatial_dimensions(i);
const auto& wd = conv->window().dimensions(i);
if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
wd.window_dilation() == 1) ||
(std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
wd.stride() == 1)) &&
wd.padding_high() == 0 && wd.padding_low() == 0 &&
!wd.window_reversal()) {
// A batch dimension in DotGeneral is represented as a spatial dimension
// with window size B (batch dimension size), stride B - 1, and base
// dilation B.
dims.batch_dims.push_back({lhs, rhs, output, i});
} else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
wd.window_dilation() == 1 && wd.stride() == lhs_size &&
wd.base_dilation() == lhs_size - 1) {
// Alternative representation of a batch dimension.
if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
dims.batch_dims.push_back({lhs, rhs, output, i});
} else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
wd.window_dilation() == 1 && wd.padding_high() == 0 &&
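
A minimal, illustrative sketch (not part of this change) of the window configuration that ConvSpatialDimensionIsParallel treats as a parallel batch dimension; the include paths and the example function name are assumptions:

// Builds the window of a spatial dimension that encodes a DotGeneral batch
// dimension of size B = 8 (size B, stride B - 1, base dilation B) and asks
// the new helper to classify it.
#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

bool ExampleParallelBatchDim() {
  const int64_t batch_size = 8;  // B
  xla::WindowDimension wd;
  wd.set_size(batch_size);           // window size B
  wd.set_stride(batch_size - 1);     // stride B - 1
  wd.set_base_dilation(batch_size);  // base dilation B
  wd.set_window_dilation(1);
  wd.set_padding_low(0);
  wd.set_padding_high(0);
  wd.set_window_reversal(false);
  // Expected to return true for this representation.
  return xla::dot_as_convolution_util::ConvSpatialDimensionIsParallel(
      wd, /*lhs_size=*/batch_size);
}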

@@ -62,6 +62,12 @@ CreateShardedConvForDotGeneralConvolution(
const DotGeneralAsConvolutionDimsInfo& dot_dnums,
HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
// Checks whether a spatial dim is a parallel batch dimension.
// A parallel batch dimension in DotGeneral is represented as a spatial
// dimension with window size B (batch dimension size), stride B - 1, and base
// dilation B.
bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
} // namespace dot_as_convolution_util
} // namespace xla

@@ -3149,6 +3149,72 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
auto aligned_lhs_sharding =
hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
// Handling cases where all the partitioned dimensions are parallel
// dimensions.
int64 lhs_parallel_dim_partitions = 1;
int64 rhs_parallel_dim_partitions = 1;
std::vector<int64> parallel_spatial_dims;
for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
int64 lhs_dim = dnums.input_spatial_dimensions(i);
int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
const auto& wd = hlo->window().dimensions(i);
int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
// Only non-reversal windows are supported right now.
if (!wd.window_reversal() &&
dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
parallel_spatial_dims.emplace_back(i);
lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
}
}
bool lhs_partition_dims_are_parallel =
(lhs_parallel_dim_partitions == num_partitions_);
bool rhs_partition_dims_are_parallel =
(rhs_parallel_dim_partitions == num_partitions_);
// If there is a parallel dim and all the partitioned dimensions are parallel
// dimensions in either LHS or RHS, simply create partitioned convolutions.
if (!parallel_spatial_dims.empty() &&
(lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
// Reshard the LHS or RHS so that it is partitioned at the parallel
// dimensions in the same way as the other operand.
if (lhs_partition_dims_are_parallel) {
rhs = rhs.Reshard(aligned_rhs_sharding);
} else {
lhs = lhs.Reshard(aligned_lhs_sharding);
}
auto lhs_shard_shape =
MakePartitionedShape(lhs.base_shape(), lhs.sharding());
auto rhs_shard_shape =
MakePartitionedShape(rhs.base_shape(), rhs.sharding());
// Update convolution window.
auto new_window = hlo->window();
for (const auto& spatial_dim : parallel_spatial_dims) {
auto wd = new_window.mutable_dimensions(spatial_dim);
wd->set_size(lhs_shard_shape.dimensions(
dnums.input_spatial_dimensions(spatial_dim)));
wd->set_stride(std::max<int64>(1, wd->size() - 1));
wd->set_base_dilation(wd->size());
}
TF_ASSIGN_OR_RETURN(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums));
*sharded_conv_shape.mutable_layout() = hlo->shape().layout();
SetPartitionedHlo(hlo, [&]() {
auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
hlo->batch_group_count(), new_window, dnums,
hlo->precision_config()));
sharded_conv->set_sharding(hlo->sharding());
return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding())
.hlo();
});
return Status::OK();
}
// Handling cases where both operands' shardings are aligned. We check that
// the LHS batch dimension is not partitioned because it is mapped to the
// output feature dimension in aligned_rhs_sharding, which are not the same

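To make the window rewrite in the new branch concrete: each parallel spatial dimension keeps the "size B, stride B - 1, base dilation B" structure, but with B replaced by the per-shard size. A small illustrative sketch (not part of the change; the struct and function names are invented for the example):

// Per-shard window parameters for a parallel spatial dimension, assuming the
// batch dimension divides evenly across partitions (e.g. B = 8 over 4 shards).
#include <algorithm>
#include <cstdint>

struct ParallelWindowDim {
  int64_t size;
  int64_t stride;
  int64_t base_dilation;
};

ParallelWindowDim ShardedParallelWindowDim(int64_t batch_size,
                                           int64_t num_partitions) {
  const int64_t shard_size = batch_size / num_partitions;  // e.g. 8 / 4 = 2
  return ParallelWindowDim{
      /*size=*/shard_size,                              // window size = shard size
      /*stride=*/std::max<int64_t>(1, shard_size - 1),  // stride = size - 1
      /*base_dilation=*/shard_size};                    // base dilation = size
}
// ShardedParallelWindowDim(8, 4) yields {2, 1, 2}: the sharded convolution
// sees the same parallel-batch-dim pattern at the smaller, per-shard size.
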
@@ -877,5 +877,13 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
output_shape, hlo, start_indices, limit_indices, strides));
}
// Returns the shard count along the given dimension (1 if not sharded).
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
if (sharding.IsTileMaximal()) {
return 1;
}
return sharding.tile_assignment().dim(dim);
}
} // namespace spmd
} // namespace xla
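
A brief usage sketch for ShardCountAtDim (illustrative only; the include paths and the Array/HloSharding APIs are assumptions, and the example function name is invented):

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/core/platform/logging.h"

void ShardCountAtDimExample() {
  // A 2x4 tile assignment shards dim 0 across 2 devices and dim 1 across 4.
  xla::Array<xla::int64> tiles({2, 4});
  tiles.FillIota(0);  // devices 0..7
  const auto tiled = xla::HloSharding::Tile(tiles);
  CHECK_EQ(xla::spmd::ShardCountAtDim(tiled, 0), 2);
  CHECK_EQ(xla::spmd::ShardCountAtDim(tiled, 1), 4);
  // A replicated sharding is tile-maximal, so every dim reports one shard.
  CHECK_EQ(xla::spmd::ShardCountAtDim(xla::HloSharding::Replicate(), 0), 1);
}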

@@ -262,6 +262,9 @@ absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
int64 slice_dim, int64 k);
// Returns the shard count along the given dimension (1 if not sharded).
int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
} // namespace spmd
} // namespace xla