From 59215389ceed525e3a5e739780590a4f1bdf19f4 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Wed, 15 Jul 2020 20:37:36 -0700 Subject: [PATCH] [XLA] Extend dot_as_convolution to detect the rhs transpose rule of the forward conv as batch dot dimension. Also do not swap conv operations based purely on inputs of size 1. PiperOrigin-RevId: 321492318 Change-Id: Ieb7f24c1347f9d2cbcfaba5e3bf3a62d8898b01c --- .../compiler/xla/service/algebraic_simplifier.cc | 6 ++++-- .../xla/service/dot_as_convolution_util.cc | 14 +++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 130661bf1cd..741edfc7c35 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4697,15 +4697,17 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( for (int64 spatial_dim = 0; spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { const int64 kernel_size = window_dims[spatial_dim].size(); - kernel_product *= kernel_size; const int64 dilated_kernel_size = 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); const int64 input_size = input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); - swapped_kernel_product *= input_size; const int64 dilated_input_size = 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); + // Don't decide to swap if the input size is one, since many convolution + // implementations can easily handle that special case efficiently. + kernel_product *= kernel_size; + swapped_kernel_product *= input_size == 1 ? 
kernel_size : input_size; auto new_dim = swapped_window.add_dimensions(); new_dim->set_size(input_size); diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc index fcdf85d5ecb..576d9d48ab8 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -49,15 +49,23 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { int64 rhs_size = conv->operand(1)->shape().dimensions(rhs); int64 output = conv_dims.output_spatial_dimensions(i); const auto& wd = conv->window().dimensions(i); - if (lhs_size == wd.size() && - std::max(1, lhs_size - 1) == wd.stride() && - lhs_size == wd.base_dilation() && wd.window_dilation() == 1 && + if (lhs_size == wd.size() && lhs_size == wd.base_dilation() && + ((std::max(1, lhs_size - 1) == wd.stride() && + wd.window_dilation() == 1) || + (std::max(1, lhs_size - 1) == wd.window_dilation() && + wd.stride() == 1)) && wd.padding_high() == 0 && wd.padding_low() == 0 && !wd.window_reversal()) { // A batch dimension in DotGeneral is represented as a spatial dimension // with window size B (batch dimension size), stride B - 1, and base // dilation B. dims.batch_dims.push_back({lhs, rhs, output, i}); + } else if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 && + wd.padding_low() == lhs_size - 1 && wd.window_reversal() && + wd.window_dilation() == 1 && wd.stride() == lhs_size && + wd.base_dilation() == lhs_size - 1) { + // Alternative representation of a batch dimension. + dims.batch_dims.push_back({lhs, rhs, output, i}); } else if (lhs_size == wd.size() && wd.base_dilation() == 1 && wd.window_dilation() == 1 && wd.padding_high() == 0 && wd.padding_low() == 0 && !wd.window_reversal()) {