[XLA] Grouped dims do not need to be modified when being swapped as they will
correspond one-to-one on both operands. PiperOrigin-RevId: 339796841 Change-Id: Ia71bb1bf74cb728a036b393c6ed16b2721137c7b
This commit is contained in:
parent
550434b73c
commit
ec23155aff
@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
for (int64 spatial_dim = 0;
|
||||
spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
|
||||
const int64 kernel_size = window_dims[spatial_dim].size();
|
||||
const int64 dilated_kernel_size =
|
||||
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
|
||||
|
||||
const bool can_be_group_or_contraction =
|
||||
!window_dims[spatial_dim].window_reversal() &&
|
||||
window_dims[spatial_dim].padding_low() == 0 &&
|
||||
window_dims[spatial_dim].padding_high() == 0 &&
|
||||
window_dims[spatial_dim].window_dilation() == 1;
|
||||
const bool is_group_dim =
|
||||
can_be_group_or_contraction &&
|
||||
window_dims[spatial_dim].base_dilation() == kernel_size &&
|
||||
window_dims[spatial_dim].stride() == kernel_size - 1;
|
||||
const int64 input_size =
|
||||
input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
|
||||
const bool is_pure_contraction_dim =
|
||||
kernel_size == input_size && can_be_group_or_contraction &&
|
||||
window_dims[spatial_dim].base_dilation() == 1 &&
|
||||
window_dims[spatial_dim].stride() == 1;
|
||||
if (is_group_dim || is_pure_contraction_dim) {
|
||||
*(swapped_window.add_dimensions()) = window_dims[spatial_dim];
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64 dilated_kernel_size =
|
||||
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
|
||||
const int64 dilated_input_size =
|
||||
1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
|
||||
|
||||
// Don't decide to swap if the input size is one, since many convolution
|
||||
// implementations can easily hand that special case efficiently.
|
||||
kernel_product *= kernel_size;
|
||||
|
Loading…
Reference in New Issue
Block a user