[XLA] Extend dot_as_convolution to detect the rhs transpose rule of the forward conv as a batch dot dimension. Also, do not swap conv operands based purely on inputs of size 1.

PiperOrigin-RevId: 321492318
Change-Id: Ieb7f24c1347f9d2cbcfaba5e3bf3a62d8898b01c
Blake Hechtman, 2020-07-15 20:37:36 -07:00, committed by TensorFlower Gardener
parent 540f8dbdd8
commit 59215389ce
2 changed files with 15 additions and 5 deletions


@@ -4697,15 +4697,17 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
   for (int64 spatial_dim = 0;
        spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
     const int64 kernel_size = window_dims[spatial_dim].size();
-    kernel_product *= kernel_size;
     const int64 dilated_kernel_size =
         1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
     const int64 input_size =
         input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
-    swapped_kernel_product *= input_size;
     const int64 dilated_input_size =
         1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
+    // Don't decide to swap if the input size is one, since many convolution
+    // implementations can easily handle that special case efficiently.
+    kernel_product *= kernel_size;
+    swapped_kernel_product *= input_size == 1 ? kernel_size : input_size;
     auto new_dim = swapped_window.add_dimensions();
     new_dim->set_size(input_size);
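
As a rough standalone sketch of the size heuristic above (the SpatialDim struct and SwapLooksProfitable function are illustrative only, not XLA's actual SwapConvOperands code), the change makes a size-1 input dimension contribute the original kernel size to the swapped product, so such dimensions can no longer make swapping look artificially cheap:

#include <cstdint>
#include <vector>

// Hypothetical per-spatial-dimension summary used only for this sketch.
struct SpatialDim {
  int64_t kernel_size;
  int64_t input_size;
};

// Returns true if swapping the convolution operands would shrink the
// effective kernel, mirroring the product comparison in the diff above.
bool SwapLooksProfitable(const std::vector<SpatialDim>& dims) {
  int64_t kernel_product = 1;
  int64_t swapped_kernel_product = 1;
  for (const SpatialDim& d : dims) {
    kernel_product *= d.kernel_size;
    // Size-1 input dimensions count as the original kernel size, so they
    // neither favor nor penalize the swap.
    swapped_kernel_product *=
        d.input_size == 1 ? d.kernel_size : d.input_size;
  }
  return swapped_kernel_product < kernel_product;
}

For example, a 3x3 kernel over a 1x5 input previously yielded a swapped product of 1 * 5 = 5 against a kernel product of 9 and therefore looked like a profitable swap; with the change the swapped product is 3 * 5 = 15, so the convolution is left alone.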


@@ -49,15 +49,23 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) {
     int64 rhs_size = conv->operand(1)->shape().dimensions(rhs);
     int64 output = conv_dims.output_spatial_dimensions(i);
     const auto& wd = conv->window().dimensions(i);
-    if (lhs_size == wd.size() &&
-        std::max<int64>(1, lhs_size - 1) == wd.stride() &&
-        lhs_size == wd.base_dilation() && wd.window_dilation() == 1 &&
+    if (lhs_size == wd.size() && lhs_size == wd.base_dilation() &&
+        ((std::max<int64>(1, lhs_size - 1) == wd.stride() &&
+          wd.window_dilation() == 1) ||
+         (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() &&
+          wd.stride() == 1)) &&
         wd.padding_high() == 0 && wd.padding_low() == 0 &&
         !wd.window_reversal()) {
       // A batch dimension in DotGeneral is represented as a spatial dimension
       // with window size B (batch dimension size), stride B - 1, and base
       // dilation B.
       dims.batch_dims.push_back({lhs, rhs, output, i});
+    } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 &&
+               wd.padding_low() == lhs_size - 1 && wd.window_reversal() &&
+               wd.window_dilation() == 1 && wd.stride() == lhs_size &&
+               wd.base_dilation() == lhs_size - 1) {
+      // Alternative representation of a batch dimension.
+      dims.batch_dims.push_back({lhs, rhs, output, i});
     } else if (lhs_size == wd.size() && wd.base_dilation() == 1 &&
                wd.window_dilation() == 1 && wd.padding_high() == 0 &&
                wd.padding_low() == 0 && !wd.window_reversal()) {
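
To make the window patterns explicit, here is a small illustrative check (the WindowDim struct and LooksLikeBatchDotDimension function are hypothetical shorthand, not XLA's WindowDimension proto) of the two forms the parser now maps to a batch dot dimension of size B = lhs_size: the canonical one with base dilation B and either stride B - 1 or window dilation B - 1, and the new reversed, padded form produced by the rhs transpose of a forward convolution:

#include <algorithm>
#include <cstdint>

// Hypothetical flattened view of one convolution window dimension.
struct WindowDim {
  int64_t size;
  int64_t stride;
  int64_t base_dilation;
  int64_t window_dilation;
  int64_t padding_low;
  int64_t padding_high;
  bool window_reversal;
};

bool LooksLikeBatchDotDimension(int64_t lhs_size, const WindowDim& wd) {
  const int64_t b = lhs_size;
  // Canonical form: window size B, base dilation B, no padding or reversal,
  // and either stride B - 1 (window dilation 1) or window dilation B - 1
  // (stride 1).
  const bool canonical =
      wd.size == b && wd.base_dilation == b &&
      ((wd.stride == std::max<int64_t>(1, b - 1) && wd.window_dilation == 1) ||
       (wd.window_dilation == std::max<int64_t>(1, b - 1) && wd.stride == 1)) &&
      wd.padding_low == 0 && wd.padding_high == 0 && !wd.window_reversal;
  // Alternative form from the rhs transpose of a forward convolution:
  // size B, stride B, base dilation B - 1, symmetric padding of B - 1, and a
  // reversed window.
  const bool rhs_transpose =
      wd.size == b && wd.stride == b && wd.base_dilation == b - 1 &&
      wd.window_dilation == 1 && wd.padding_low == b - 1 &&
      wd.padding_high == b - 1 && wd.window_reversal;
  return canonical || rhs_transpose;
}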