diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 5ddcf1d816b..de472d5d4fe 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -570,8 +570,17 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
     // For in_depth == 1 and grouped convolutions.
     use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
-    use_cudnn_grouped_conv_ = false;
     dtype_ = DataTypeToEnum<T>::value;
+    // Use CuDNN grouped conv (input gradient) when stride = 1, input/output is
+    // NCHW and float16(half). See cudnn release note 7.6.3:
+    // https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763
+#if CUDNN_VERSION >= 7603
+    use_cudnn_grouped_conv_ = dtype_ == DT_HALF &&
+                              data_format_ == FORMAT_NCHW && stride_ == 1 &&
+                              stride_w == 1;
+#else
+    use_cudnn_grouped_conv_ = false;
+#endif
   }
 
   void Compute(OpKernelContext* context) override {
@@ -605,7 +614,13 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
 
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+    bool use_cudnn =
+        use_cudnn_ && (in_depth == 1 ||
+                       (use_cudnn_grouped_conv_ &&
+                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                   /*filter_cols=*/filter_cols,
+                                                   /*in_depth=*/in_depth,
+                                                   /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
@@ -1044,7 +1059,6 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     // For in_depth == 1 and grouped convolutions.
     use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
-    use_cudnn_grouped_conv_ = false;
 
     if (std::is_same<T, Eigen::half>::value) {
       dtype_ = DT_HALF;
@@ -1055,6 +1069,14 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
     } else {
       LOG(ERROR) << "Only half, float, and double are supported.";
     }
+    // Use CuDNN grouped conv (filter gradients) when input/output is
+    // float16(half). See cudnn release note 7.6.3:
+    // https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763
+#if CUDNN_VERSION >= 7603
+    use_cudnn_grouped_conv_ = dtype_ == DT_HALF;
+#else
+    use_cudnn_grouped_conv_ = false;
+#endif
   }
 
   void Compute(OpKernelContext* context) override {
@@ -1087,7 +1109,13 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
 
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+    bool use_cudnn =
+        use_cudnn_ && (in_depth == 1 ||
+                       (use_cudnn_grouped_conv_ &&
+                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                   /*filter_cols=*/filter_cols,
+                                                   /*in_depth=*/in_depth,
+                                                   /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index a7a0088fd3d..7777c18ddc5 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -296,8 +296,15 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     // For in_depth == 1 and grouped convolutions.
     use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
     cudnn_use_autotune_ = CudnnUseAutotune();
-    use_cudnn_grouped_conv_ = false;
     dtype_ = DataTypeToEnum<T>::value;
+    // Use CuDNN grouped conv only when input/output is NCHW and float16(half).
+    // See cudnn release note 7.6.3:
+    // https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763
+#if CUDNN_VERSION >= 7603
+    use_cudnn_grouped_conv_ = dtype_ == DT_HALF && data_format_ == FORMAT_NCHW;
+#else
+    use_cudnn_grouped_conv_ = false;
+#endif
   }
 
   void Compute(OpKernelContext* context) override {
@@ -376,7 +383,13 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
     // TODO(csigg): Have autotune decide if native is faster than cuDNN.
     // If in_depth==1, this operation is just a standard convolution.
     // Depthwise convolution is a special case of cuDNN's grouped convolution.
-    bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
+    bool use_cudnn =
+        use_cudnn_ && (in_depth == 1 ||
+                       (use_cudnn_grouped_conv_ &&
+                        IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
+                                                   /*filter_cols=*/filter_cols,
+                                                   /*in_depth=*/in_depth,
+                                                   /*out_depth=*/out_depth)));
 
     VLOG(2) << "DepthwiseConv2dNative: "
             << " Input: [" << batch << ", " << input_rows << ", " << input_cols
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
index 144967a53ef..fc9988df5fd 100644
--- a/tensorflow/core/util/use_cudnn.cc
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -90,4 +90,12 @@ FP16ConvMode CudnnConvComputeMode() {
   return FP16ConvMode::kAccurate;
 }
 
+bool IsCudnnSupportedFilterSize(const int32 filter_rows,
+                                const int32 filter_cols, const int32 in_depth,
+                                const int32 out_depth) {
+  // Needs in_depth == out_depth and a square filter of side 1, 3, 5 or 7.
+  const bool odd_side = filter_rows % 2 == 1 && filter_rows <= 7;
+  return odd_side && filter_cols == filter_rows && out_depth == in_depth;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h
index f8cc5944d71..bbacd349daf 100644
--- a/tensorflow/core/util/use_cudnn.h
+++ b/tensorflow/core/util/use_cudnn.h
@@ -39,6 +39,13 @@ FP16ConvMode CudnnConvComputeMode();
 bool DebugCudnnRnn();
 bool DebugCudnnRnnUseTensorOps();
 int64 DebugCudnnRnnAlgo();
+
+// Filter-shape check for using the CuDNN depthwise convolution: returns true
+// iff in_depth == out_depth and the filter is square with side 1, 3, 5 or 7.
+// (https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html)
+bool IsCudnnSupportedFilterSize(const int32 filter_rows,
+                                const int32 filter_cols, const int32 in_depth,
+                                const int32 out_depth);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 4b207ae4290..0509fcad283 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -186,6 +186,26 @@ class DepthwiseConv2DTest(test.TestCase):
     self.assertShapeEqual(native_result, conv_native)
     self.assertShapeEqual(native_result, conv_interface)
 
+  @test_util.run_v1_only("b/120545219")
+  @test_util.run_cuda_only
+  def testDepthwiseConv2DCudnn(self):
+    for index, (input_size, filter_size, _, stride,
+                padding) in enumerate(ConfigsToTest()):
+      # The CuDNN depthwise conv is turned on only when input/output is NCHW and
+      # float16(half). See cudnn release note 7.6.3.
+      tf_logging.info(
+          "Testing DepthwiseConv2DCudnn, %dth config: %r * %r, stride: %d, "
+          "padding: %s", index, input_size, filter_size, stride, padding)
+      data_type = dtypes.float16
+      self._VerifyValues(
+          input_size,
+          filter_size,
+          stride,
+          padding,
+          data_type,
+          use_gpu=True,
+          data_format="NCHW")
+
   @test_util.run_v1_only("b/120545219")
   def testDepthwiseConv2D(self):
     for index, (input_size, filter_size, _, stride,
@@ -438,6 +458,32 @@ class DepthwiseConv2DTest(test.TestCase):
           use_gpu, grouped_conv, err)
       self.assertLess(err, tolerance)
 
+  @test_util.run_v1_only("b/120545219")
+  @test_util.run_cuda_only
+  def testDepthwiseConv2DInputGradCudnn(self):
+    for index, (input_size, filter_size, output_size, stride,
+                padding) in enumerate(CheckGradConfigsToTest()):
+      # The CuDNN depthwise conv (input gradient) is turned on only when
+      # stride = 1, input/output is NCHW and float16(half). See cudnn release
+      # note 7.6.3.
+      if stride != 1:
+        continue
+      tf_logging.info(
+          "Testing DepthwiseConv2DInputGradCudnn, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      data_type = dtypes.float16
+      self._ConstructAndTestGradient(
+          input_size,
+          filter_size,
+          output_size,
+          stride,
+          padding,
+          data_type,
+          test_input=True,
+          use_gpu=True,
+          data_format="NCHW")
+
   @test_util.run_v1_only("b/120545219")
   def testDepthwiseConv2DInputGrad(self):
     for index, (input_size, filter_size, output_size, stride,
@@ -495,6 +541,39 @@ class DepthwiseConv2DTest(test.TestCase):
             use_gpu=True,
             data_format="NCHW")
 
+  @test_util.run_v1_only("b/120545219")
+  @test_util.run_cuda_only
+  def testDepthwiseConv2DFilterGradCudnn(self):
+    for index, (input_size, filter_size, output_size, stride,
+                padding) in enumerate(CheckGradConfigsToTest()):
+      # The CuDNN depthwise conv (filter gradient) is turned on only when
+      # input/output is float16(half). See cudnn release note 7.6.3.
+      tf_logging.info(
+          "Testing DepthwiseConv2DFilterGradCudnn, %dth config: %r * %r, "
+          "stride: %d, padding: %s", index, input_size, filter_size, stride,
+          padding)
+      data_type = dtypes.float16
+      self._ConstructAndTestGradient(
+          input_size,
+          filter_size,
+          output_size,
+          stride,
+          padding,
+          data_type,
+          test_input=False,
+          use_gpu=True,
+          data_format="NCHW")
+      self._ConstructAndTestGradient(
+          input_size,
+          filter_size,
+          output_size,
+          stride,
+          padding,
+          data_type,
+          test_input=False,
+          use_gpu=True,
+          data_format="NHWC")
+
   @test_util.run_v1_only("b/120545219")
   def testDepthwiseConv2DFilterGrad(self):
     for index, (input_size, filter_size, output_size, stride,