Use cudnn for grouped backward input convolution.
Previously this code path was left disabled because it appeared to cause benchmark regressions. Those regressions turned out to be measurement noise.
PiperOrigin-RevId: 315456845
Change-Id: I32404152ad35692461808e2d8d449e21e55ac95c
This commit is contained in:
parent
55e93450d7
commit
f76e9c2915
|
@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) {
|
|||
const auto no_match_result =
|
||||
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
|
||||
|
||||
// TODO: Theoretically cuDNN supports grouped convolutions also
|
||||
// for the backward input convolution, but based on cuDNN's current state
|
||||
// there is not much performance improvement when using the
|
||||
// cudnn backward input API for grouped conv.
|
||||
// This needs to be re-evaluated for future cuDNN versions.
|
||||
// Note that we already have the necessary code down below, the only thing to
|
||||
// enable it is to remove the following early return.
|
||||
if (conv->feature_group_count() > 1) {
|
||||
return no_match_result;
|
||||
}
|
||||
|
||||
// Match instruction pattern.
|
||||
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
|
||||
HloInstruction* reverse_filter = conv->mutable_operand(1);
|
||||
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
|
||||
|
||||
// Match BackwardInput for a depthwise convolution and thunk it to forward
|
||||
// convolution. Output feature dimension and input feature dimension have been
|
||||
// swapped in the bridge. Hence to get the actual input features we need to
|
||||
// query the output feature dimension
|
||||
auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
|
||||
auto kernel_out_features =
|
||||
reverse_filter->shape().dimensions(kernel_out_feature_dim);
|
||||
|
||||
// For a depthwise convolution, the input features must be equal to the
|
||||
// feature_group_count. We can leverage this property to match a depthwise
|
||||
// convolution and thunk it to forward conv
|
||||
if (conv->feature_group_count() > 1 &&
|
||||
kernel_out_features == conv->feature_group_count()) {
|
||||
return no_match_result;
|
||||
}
|
||||
|
||||
// We pattern-match to a backwards input conv if:
|
||||
//
|
||||
// - all spatial dims of the filter are reversed
|
||||
|
|
Loading…
Reference in New Issue