[XLA] Extend dot_as_convolution to detect the rhs-transpose rule of the forward convolution as a batch dot dimension. Also, do not decide to swap convolution operands based purely on spatial inputs of size 1.
PiperOrigin-RevId: 321492318
Change-Id: Ieb7f24c1347f9d2cbcfaba5e3bf3a62d8898b01c
commit 59215389ce
parent 540f8dbdd8
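For context, dot_as_convolution models a DotGeneral batch dimension of size B as a convolution spatial dimension with window size B, stride max(1, B - 1), and base dilation B. To see why this yields exactly one output per batch element: base dilation B expands an input of size B to 1 + (B - 1) * B = B^2 - B + 1 positions, and sliding a window of size B with stride B - 1 over that produces (B^2 - 2B + 1) / (B - 1) + 1 = B outputs. The rhs transpose of the forward convolution encodes the same batch dimension with a different window configuration (reversed window, padding B - 1 on both sides, stride B, base dilation B - 1), which this change teaches the parser to accept as well.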
@@ -4697,15 +4697,17 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
   for (int64 spatial_dim = 0;
        spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
     const int64 kernel_size = window_dims[spatial_dim].size();
-    kernel_product *= kernel_size;
     const int64 dilated_kernel_size =
         1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
 
     const int64 input_size =
         input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
-    swapped_kernel_product *= input_size;
     const int64 dilated_input_size =
         1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
+    // Don't decide to swap if the input size is one, since many convolution
+    // implementations can easily handle that special case efficiently.
+    kernel_product *= kernel_size;
+    swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;
 
     auto new_dim = swapped_window.add_dimensions();
     new_dim->set_size(input_size);
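To illustrate the new guard in SwapConvOperands, here is a minimal standalone sketch (C++17) of the swap decision. It uses plain vectors instead of the Window proto and ignores the dilation and padding terms that the real heuristic also weighs; the function name and its shape are hypothetical:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical, simplified model of the swap decision: swap the conv
    // operands only if doing so shrinks the effective kernel volume.
    bool ShouldSwapConvOperands(const std::vector<int64_t>& kernel_sizes,
                                const std::vector<int64_t>& input_sizes) {
      int64_t kernel_product = 1;
      int64_t swapped_kernel_product = 1;
      for (size_t i = 0; i < kernel_sizes.size(); ++i) {
        kernel_product *= kernel_sizes[i];
        // The size-1 guard from the patch: a degenerate spatial input
        // contributes the kernel size, not 1, to the swapped product, so it
        // can never tip the decision toward swapping on its own.
        swapped_kernel_product *=
            input_sizes[i] == 1 ? kernel_sizes[i] : input_sizes[i];
      }
      return swapped_kernel_product < kernel_product;
    }

    int main() {
      // A 3x3 kernel over a 1x5 input: without the guard the swapped product
      // would be 1 * 5 = 5 < 9 and force a swap; with it, 3 * 5 = 15 >= 9.
      std::cout << ShouldSwapConvOperands({3, 3}, {1, 5}) << "\n";  // prints 0
    }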
@@ -49,15 +49,23 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
     int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
     int64 output = conv_dims.output_spatial_dimensions(i);
     const auto& wd = conv->window().dimensions(i);
-    if (lhs_size == wd.size() &&
-        std::max<int64>(1, lhs_size - 1) == wd.stride() &&
-        lhs_size == wd.base_dilation() && wd.window_dilation() == 1 &&
+    if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
+        ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
+          wd.window_dilation() == 1) ||
+         (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
+          wd.stride() == 1)) &&
         wd.padding_high() == 0 && wd.padding_low() == 0 &&
         !wd.window_reversal()) {
       // A batch dimension in DotGeneral is represented as a spatial dimension
       // with window size B (batch dimension size), stride B - 1, and base
       // dilation B.
       dims.batch_dims.push_back({lhs, rhs, output, i});
+    } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
+               wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
+               wd.window_dilation() == 1 && wd.stride() == lhs_size &&
+               wd.base_dilation() == lhs_size - 1) {
+      // Alternative representation of a batch dimension.
+      dims.batch_dims.push_back({lhs, rhs, output, i});
     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
                wd.padding_low() == 0 && !wd.window_reversal()) {
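And a self-contained sketch (C++17) of the two batch-dimension encodings the parser now accepts. The WindowDim struct and IsBatchDotDimension helper are hypothetical stand-ins for xla::WindowDimension and the check inside ParseDotGeneralFromConvolution:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in holding only the window fields inspected here.
    struct WindowDim {
      int64_t size = 1;
      int64_t stride = 1;
      int64_t padding_low = 0;
      int64_t padding_high = 0;
      int64_t base_dilation = 1;
      int64_t window_dilation = 1;
      bool window_reversal = false;
    };

    // A spatial dimension of size B is a batch dot dimension if it matches
    // either the original encoding or the rhs-transpose encoding added here.
    bool IsBatchDotDimension(int64_t lhs_size, const WindowDim& wd) {
      const int64_t b1 = std::max<int64_t>(1, lhs_size - 1);
      const bool forward =
          lhs_size == wd.size && lhs_size == wd.base_dilation &&
          ((b1 == wd.stride && wd.window_dilation == 1) ||
           (b1 == wd.window_dilation && wd.stride == 1)) &&
          wd.padding_high == 0 && wd.padding_low == 0 && !wd.window_reversal;
      const bool rhs_transpose =
          wd.size == lhs_size && wd.padding_high == lhs_size - 1 &&
          wd.padding_low == lhs_size - 1 && wd.window_reversal &&
          wd.window_dilation == 1 && wd.stride == lhs_size &&
          wd.base_dilation == lhs_size - 1;
      return forward || rhs_transpose;
    }

    int main() {
      // Batch size B = 4 in the original encoding (stride 3, base dilation 4)
      // and in the rhs-transpose encoding (reversed, padded 3, stride 4).
      WindowDim original{4, 3, 0, 0, 4, 1, false};
      WindowDim transposed{4, 4, 3, 3, 3, 1, true};
      std::cout << IsBatchDotDimension(4, original) << " "
                << IsBatchDotDimension(4, transposed) << "\n";  // prints 1 1
    }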