diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc index 9d34bb39ba8..fb8c05798d8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc @@ -321,38 +321,11 @@ MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO: Theoretically cuDNN supports grouped convolutions also - // for the backward input convolution, but based on the cudnn's current state - // there is not much performance improvement when using the - // cudnn backward input API for grouped conv. - // This needs to be re-evaluated for future cuDNN versions. - // Note that we already have the necessary code down below, the only thing to - // enable it is to remove the following early return. - if (conv->feature_group_count() > 1) { - return no_match_result; - } - // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - // Match BackwardInput for a depthwise convolution and thunk it to forward - // convolution Output feature dimension and input feature dimension has been - // swapped in the bridge. Hence to get the actual input features we need to - // query the output feature dimension - auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension(); - auto kernel_out_features = - reverse_filter->shape().dimensions(kernel_out_feature_dim); - - // For a depthwise convolution, the input features must be equal to the - // feature_group_count. We can leverage this property to match a depthwise - // convolution and thunk it to forward conv - if (conv->feature_group_count() > 1 && - kernel_out_features == conv->feature_group_count()) { - return no_match_result; - } - // We pattern-match to a backwards input conv if: // // - all spatial dims of the filter are reversed