Use cudnn for grouped backward input convolution.

Previously we didn't enable this code because it looked like
we had some benchmark regressions. It turns out this was just
noise.

PiperOrigin-RevId: 315456845
Change-Id: I32404152ad35692461808e2d8d449e21e55ac95c
This commit is contained in:
Adrian Kuegel 2020-06-09 04:10:32 -07:00 committed by TensorFlower Gardener
parent 55e93450d7
commit f76e9c2915
1 changed files with 0 additions and 27 deletions

View File

@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
// TODO: Theoretically cuDNN supports grouped convolutions also
// for the backward input convolution, but based on the cudnn's current state
// there is not much performance improvement when using the
// cudnn backward input API for grouped conv.
// This needs to be re-evaluated for future cuDNN versions.
// Note that we already have the necessary code down below, the only thing to
// enable it is to remove the following early return.
if (conv->feature_group_count() > 1) {
return no_match_result;
}
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
// Match BackwardInput for a depthwise convolution and thunk it to forward
// convolution Output feature dimension and input feature dimension has been
// swapped in the bridge. Hence to get the actual input features we need to
// query the output feature dimension
auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
auto kernel_out_features =
reverse_filter->shape().dimensions(kernel_out_feature_dim);
// For a depthwise convolution, the input features must be equal to the
// feature_group_count. We can leverage this property to match a depthwise
// convolution and thunk it to forward conv
if (conv->feature_group_count() > 1 &&
kernel_out_features == conv->feature_group_count()) {
return no_match_result;
}
// We pattern-match to a backwards input conv if:
//
// - all spatial dims of the filter are reversed