Use cudnn for grouped backward input convolution.
Previously we didn't enable this code because it looked like we had some benchmark regressions. It turns out this was just noise. PiperOrigin-RevId: 315456845 Change-Id: I32404152ad35692461808e2d8d449e21e55ac95c
This commit is contained in:
parent
55e93450d7
commit
f76e9c2915
|
@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) {
|
||||||
const auto no_match_result =
|
const auto no_match_result =
|
||||||
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
|
std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
|
||||||
|
|
||||||
// TODO: Theoretically cuDNN supports grouped convolutions also
|
|
||||||
// for the backward input convolution, but based on the cudnn's current state
|
|
||||||
// there is not much performance improvement when using the
|
|
||||||
// cudnn backward input API for grouped conv.
|
|
||||||
// This needs to be re-evaluated for future cuDNN versions.
|
|
||||||
// Note that we already have the necessary code down below, the only thing to
|
|
||||||
// enable it is to remove the following early return.
|
|
||||||
if (conv->feature_group_count() > 1) {
|
|
||||||
return no_match_result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Match instruction pattern.
|
// Match instruction pattern.
|
||||||
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
|
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
|
||||||
HloInstruction* reverse_filter = conv->mutable_operand(1);
|
HloInstruction* reverse_filter = conv->mutable_operand(1);
|
||||||
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
|
ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
|
||||||
|
|
||||||
// Match BackwardInput for a depthwise convolution and thunk it to forward
|
|
||||||
// convolution Output feature dimension and input feature dimension has been
|
|
||||||
// swapped in the bridge. Hence to get the actual input features we need to
|
|
||||||
// query the output feature dimension
|
|
||||||
auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension();
|
|
||||||
auto kernel_out_features =
|
|
||||||
reverse_filter->shape().dimensions(kernel_out_feature_dim);
|
|
||||||
|
|
||||||
// For a depthwise convolution, the input features must be equal to the
|
|
||||||
// feature_group_count. We can leverage this property to match a depthwise
|
|
||||||
// convolution and thunk it to forward conv
|
|
||||||
if (conv->feature_group_count() > 1 &&
|
|
||||||
kernel_out_features == conv->feature_group_count()) {
|
|
||||||
return no_match_result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We pattern-match to a backwards input conv if:
|
// We pattern-match to a backwards input conv if:
|
||||||
//
|
//
|
||||||
// - all spatial dims of the filter are reversed
|
// - all spatial dims of the filter are reversed
|
||||||
|
|
Loading…
Reference in New Issue