[XLA:SPMD] Support convolution with a non-contracting spatial dim partitioned at the batch dim.

PiperOrigin-RevId: 321579087
Change-Id: I42f86b9c281e02f8157287653aa30a54c14a0e72

parent 279ef8e3e6
commit 960358aaa2
tensorflow/compiler/xla/service
@@ -24,6 +24,31 @@ limitations under the License.
 namespace xla {
 namespace dot_as_convolution_util {
 
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) {
+  // A parallel batch dimension in DotGeneral is represented as a
+  // spatial dimension with window size B (batch dimension size),
+  // stride B - 1, and base dilation B.
+  if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
+      ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
+        wd.window_dilation() == 1) ||
+       (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
+        wd.stride() == 1)) &&
+      wd.padding_high() == 0 && wd.padding_low() == 0 &&
+      !wd.window_reversal()) {
+    return true;
+  }
+
+  // Alternative representation of a batch dimension.
+  if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
+      wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
+      wd.window_dilation() == 1 && wd.stride() == lhs_size &&
+      wd.base_dilation() == lhs_size - 1) {
+    return true;
+  }
+
+  return false;
+}
+
 /* static */ absl::optional<DotGeneralAsConvolutionDimsInfo>
 ParseDotGeneralFromConvolution(const HloInstruction* conv) {
   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -49,22 +74,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
     int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
     int64 output = conv_dims.output_spatial_dimensions(i);
     const auto& wd = conv->window().dimensions(i);
-    if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
-        ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
-          wd.window_dilation() == 1) ||
-         (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
-          wd.stride() == 1)) &&
-        wd.padding_high() == 0 && wd.padding_low() == 0 &&
-        !wd.window_reversal()) {
-      // A batch dimension in DotGeneral is represented as a spatial dimension
-      // with window size B (batch dimension size), stride B - 1, and base
-      // dilation B.
-      dims.batch_dims.push_back({lhs, rhs, output, i});
-    } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
-               wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
-               wd.window_dilation() == 1 && wd.stride() == lhs_size &&
-               wd.base_dilation() == lhs_size - 1) {
-      // Aternative representation of a batch dimension.
+    if (ConvSpatialDimensionIsParallel(wd, lhs_size)) {
       dims.batch_dims.push_back({lhs, rhs, output, i});
     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
@@ -62,6 +62,12 @@ CreateShardedConvForDotGeneralConvolution(
     const DotGeneralAsConvolutionDimsInfo& dot_dnums,
     HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo);
 
+// Checks if a spatial dim is a parallel batch dimension.
+// A parallel batch dimension in DotGeneral is represented as a spatial
+// dimension with window size B (batch dimension size), stride B - 1, and base
+// dilation B.
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size);
+
 }  // namespace dot_as_convolution_util
 }  // namespace xla
 
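For illustration, a hedged sketch (not part of the commit) of the window pattern this predicate matches for a batch-like dimension of size B = 8; it only assumes the xla::WindowDimension proto setters corresponding to the getters used above:

  WindowDimension wd;
  wd.set_size(8);             // window size B
  wd.set_stride(7);           // stride B - 1
  wd.set_base_dilation(8);    // base dilation B
  wd.set_window_dilation(1);  // no window dilation
  wd.set_padding_low(0);      // no padding
  wd.set_padding_high(0);
  wd.set_window_reversal(false);
  // Expected to match the first (non-reversed) branch of the predicate:
  // dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, /*lhs_size=*/8) == true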
@@ -3149,6 +3149,72 @@ Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) {
   auto aligned_lhs_sharding =
       hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
 
+  // Handling cases where all the partitioned dimensions are parallel
+  // dimensions.
+  int64 lhs_parallel_dim_partitions = 1;
+  int64 rhs_parallel_dim_partitions = 1;
+  std::vector<int64> parallel_spatial_dims;
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    int64 lhs_dim = dnums.input_spatial_dimensions(i);
+    int64 lhs_size = lhs.base_shape().dimensions(lhs_dim);
+    const auto& wd = hlo->window().dimensions(i);
+    int64 rhs_dim = dnums.kernel_spatial_dimensions(i);
+    // Only non-reversed windows are supported right now.
+    if (!wd.window_reversal() &&
+        dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) {
+      parallel_spatial_dims.emplace_back(i);
+      lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim);
+      rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim);
+    }
+  }
+  bool lhs_partition_dims_are_parallel =
+      (lhs_parallel_dim_partitions == num_partitions_);
+  bool rhs_partition_dims_are_parallel =
+      (rhs_parallel_dim_partitions == num_partitions_);
+
+  // If there is a parallel dim and all the partitioned dimensions are parallel
+  // dimensions in either LHS or RHS, simply create partitioned convolutions.
+  if (!parallel_spatial_dims.empty() &&
+      (lhs_partition_dims_are_parallel || rhs_partition_dims_are_parallel)) {
+    // Reshard LHS or RHS so that it is partitioned at the parallel dimensions
+    // in the same way as the other operand.
+    if (lhs_partition_dims_are_parallel) {
+      rhs = rhs.Reshard(aligned_rhs_sharding);
+    } else {
+      lhs = lhs.Reshard(aligned_lhs_sharding);
+    }
+    auto lhs_shard_shape =
+        MakePartitionedShape(lhs.base_shape(), lhs.sharding());
+    auto rhs_shard_shape =
+        MakePartitionedShape(rhs.base_shape(), rhs.sharding());
+    // Update the convolution window.
+    auto new_window = hlo->window();
+    for (const auto& spatial_dim : parallel_spatial_dims) {
+      auto wd = new_window.mutable_dimensions(spatial_dim);
+      wd->set_size(lhs_shard_shape.dimensions(
+          dnums.input_spatial_dimensions(spatial_dim)));
+      wd->set_stride(std::max<int64>(1, wd->size() - 1));
+      wd->set_base_dilation(wd->size());
+    }
+    TF_ASSIGN_OR_RETURN(
+        Shape sharded_conv_shape,
+        ShapeInference::InferConvolveShape(
+            lhs_shard_shape, rhs_shard_shape, hlo->feature_group_count(),
+            hlo->batch_group_count(), new_window, dnums));
+    *sharded_conv_shape.mutable_layout() = hlo->shape().layout();
+    SetPartitionedHlo(hlo, [&]() {
+      auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve(
+          sharded_conv_shape, lhs.hlo(), rhs.hlo(), hlo->feature_group_count(),
+          hlo->batch_group_count(), new_window, dnums,
+          hlo->precision_config()));
+      sharded_conv->set_sharding(hlo->sharding());
+      return PartitionedHlo(sharded_conv, hlo->shape(), MakePartitioningState())
+          .Reshard(hlo->sharding())
+          .hlo();
+    });
+    return Status::OK();
+  }
+
   // Handling cases where both operands' shardings are aligned. We check that
   // the LHS batch dimension is not partitioned because it is mapped to the
   // output feature dimension in aligned_rhs_sharding, which are not the same
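A hedged numeric sketch of the window update above (assumed values, not from the commit): with a batch-like spatial dim of size B = 8 split across 4 partitions, each shard holds 2 elements, and the loop rewrites the window so the shard still matches the parallel-dim pattern that ConvSpatialDimensionIsParallel recognizes:

  // Shard-local window for one parallel spatial dim (B = 8, 4 partitions).
  auto wd = new_window.mutable_dimensions(spatial_dim);
  wd->set_size(2);                            // shard size B / P
  wd->set_stride(std::max<int64>(1, 2 - 1));  // stride = shard size - 1 = 1
  wd->set_base_dilation(2);                   // base dilation = shard size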
@@ -877,5 +877,13 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
       output_shape, hlo, start_indices, limit_indices, strides));
 }
 
+// Returns the shard count of a dimension (1 if not sharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) {
+  if (sharding.IsTileMaximal()) {
+    return 1;
+  }
+  return sharding.tile_assignment().dim(dim);
+}
+
 }  // namespace spmd
 }  // namespace xla
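A usage sketch for the new helper (not part of the commit; it assumes the existing xla::Array container and the HloSharding::Tile / HloSharding::Replicate factories):

  // Eight devices laid out as a 2x4 tile assignment.
  Array<int64> devices({2, 4});
  devices.FillIota(0);
  HloSharding tiled = HloSharding::Tile(devices);
  CHECK_EQ(ShardCountAtDim(tiled, 0), 2);  // dim 0 is split 2 ways
  CHECK_EQ(ShardCountAtDim(tiled, 1), 4);  // dim 1 is split 4 ways
  // Tile-maximal shardings (e.g. replicated) report a single shard.
  CHECK_EQ(ShardCountAtDim(HloSharding::Replicate(), 0), 1);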
@@ -262,6 +262,9 @@ absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);
 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                             int64 slice_dim, int64 k);
 
+// Returns the shard count of a dimension (1 if not sharded).
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim);
+
 }  // namespace spmd
 }  // namespace xla