From ee82131dbccd4e99decb8c05c43bc2bb387ad6ac Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne <reedwm@google.com>
Date: Fri, 19 Apr 2019 14:18:19 -0700
Subject: [PATCH] Support explicit padding on CPU for tf.nn.conv2d.
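
With this change, tf.nn.conv2d accepts an explicit padding list on the CPU as
well as on the GPU. A usage sketch (shapes are illustrative, not taken from
this change):

    import tensorflow as tf
    out = tf.nn.conv2d(tf.ones([1, 5, 4, 3]), tf.ones([3, 3, 3, 2]),
                       strides=[1, 1, 1, 1],
                       padding=[[0, 0], [1, 2], [3, 4], [0, 0]])  # NHWC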

PiperOrigin-RevId: 244421216
---
 tensorflow/core/kernels/conv_2d.h             |  43 +-
 .../core/kernels/conv_grad_filter_ops.cc      |  27 +-
 .../core/kernels/conv_grad_input_ops.cc       |  27 +-
 tensorflow/core/kernels/conv_ops.cc           |  38 +-
 .../kernels/eigen_spatial_convolutions-inl.h  |  86 +-
 .../python/kernel_tests/conv_ops_test.py      | 824 +++++++++---------
 tensorflow/python/ops/nn_ops.py               |   2 +-
 7 files changed, 547 insertions(+), 500 deletions(-)

diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 1bac2a18c30..b735f78c2e3 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -57,11 +57,16 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input,
                             Filter filter, int row_stride, int col_stride,
                             int row_dilation, int col_dilation,
                             const Eigen::PaddingType& padding,
-                            const OutputKernel& output_kernel) {
-  // Need to swap row/col when calling Eigen.
-  output.device(d) =
-      Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding,
-                                col_dilation, row_dilation, output_kernel);
+                            const OutputKernel& output_kernel,
+                            int padding_top = 0, int padding_bottom = 0,
+                            int padding_left = 0, int padding_right = 0) {
+  // Need to swap row/col, padding_top/padding_left, and
+  // padding_bottom/padding_right when calling Eigen. Eigen expects the tensor
+  // in NWHC format, but the tensor given is in NHWC.
+  output.device(d) = Eigen::SpatialConvolution(
+      input, filter, col_stride, row_stride, padding, col_dilation,
+      row_dilation, output_kernel, padding_left, padding_right, padding_top,
+      padding_bottom);
 }
 
 template <typename Device, typename T,
@@ -76,6 +81,18 @@ struct SpatialConvolution {
     SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
                            row_dilation, col_dilation, padding, output_kernel);
   }
+  void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
+                  typename TTypes<T, 4>::ConstTensor input,
+                  typename TTypes<T, 4>::ConstTensor filter, int row_stride,
+                  int col_stride, int row_dilation, int col_dilation,
+                  int padding_top, int padding_bottom, int padding_left,
+                  int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    SpatialConvolutionFunc(
+        d, output, input, filter, row_stride, col_stride, row_dilation,
+        col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
+        padding_top, padding_bottom, padding_left, padding_right);
+  }
 };
 
 template <typename Device, typename OutputKernel>
@@ -93,6 +110,22 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
                                   row_dilation, output_kernel)
             .template cast<Eigen::half>();
   }
+  void operator()(const Device& d,
+                  typename TTypes<Eigen::half, 4>::Tensor output,
+                  typename TTypes<Eigen::half, 4>::ConstTensor input,
+                  typename TTypes<Eigen::half, 4>::ConstTensor filter,
+                  int row_stride, int col_stride, int row_dilation,
+                  int col_dilation, int padding_top, int padding_bottom,
+                  int padding_left, int padding_right,
+                  const OutputKernel& output_kernel = OutputKernel()) {
+    output.device(d) =
+        Eigen::SpatialConvolution(
+            input.cast<float>(), filter.cast<float>(), col_stride, row_stride,
+            Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation,
+            output_kernel, padding_left, padding_right, padding_top,
+            padding_bottom)
+            .template cast<Eigen::half>();
+  }
 };
 
 template <typename Device, typename T>
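
The overloads added above hand the explicit pads to Eigen with each row/col
pair swapped, because Eigen's SpatialConvolution takes width-major arguments.
A schematic of the reordering (a Python sketch; the function name is
illustrative):

    def to_eigen_order(row_stride, col_stride,
                       pad_top, pad_bottom, pad_left, pad_right):
        # Column values precede row values; left/right precede top/bottom.
        return (col_stride, row_stride,
                pad_left, pad_right, pad_top, pad_bottom)
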
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 168a91a312a..e755c3e2041 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -208,14 +208,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                 errors::InvalidArgument(
                     "Row and column strides should be larger than 0."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    OP_REQUIRES(
-        context, padding_ != Padding::EXPLICIT,
-        errors::Unimplemented("Current CPU implementation does not support "
-                              "EXPLICIT padding yet."));
-    std::vector<int64> explicit_paddings;
     OP_REQUIRES_OK(context,
-                   context->GetAttr("explicit_paddings", &explicit_paddings));
-    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
+                   context->GetAttr("explicit_paddings", &explicit_paddings_));
+    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
     OP_REQUIRES(context, dilations_.size() == 4,
@@ -247,11 +242,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                                 filter_sizes.vec<int32>(), &filter_shape));
 
     ConvBackpropDimensions dims;
-    OP_REQUIRES_OK(context,
-                   ConvBackpropComputeDimensions(
-                       "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2,
-                       input.shape(), filter_shape, out_backprop.shape(),
-                       strides_, padding_, data_format_, &dims));
+    OP_REQUIRES_OK(
+        context,
+        ConvBackpropComputeDimensionsV2(
+            "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
+            filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
+            strides_, padding_, explicit_paddings_, data_format_, &dims));
 
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
@@ -264,6 +260,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
 
     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
+    if (padding_ == Padding::EXPLICIT) {
+      pad_top = explicit_paddings_[2];
+      pad_bottom = explicit_paddings_[3];
+      pad_left = explicit_paddings_[4];
+      pad_right = explicit_paddings_[5];
+    }
     OP_REQUIRES_OK(
         context,
         GetWindowedOutputSizeVerbose(
@@ -402,6 +404,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
   std::vector<int32> dilations_;
   std::vector<int32> strides_;
   Padding padding_;
+  std::vector<int64> explicit_paddings_;
   TensorFormat data_format_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
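
For reference, the explicit_paddings_ attr read above is the flattened form of
the Python-level padding list, so with NHWC data indices 2..5 hold top,
bottom, left, and right. A small sketch of that mapping (illustrative, not TF
internals):

    padding = [[0, 0], [1, 2], [3, 4], [0, 0]]   # [[N], [H], [W], [C]]
    explicit_paddings = [p for pair in padding for p in pair]
    pad_top, pad_bottom, pad_left, pad_right = explicit_paddings[2:6]
    assert (pad_top, pad_bottom, pad_left, pad_right) == (1, 2, 3, 4)
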
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 471c73f65a4..4c1a0d9316b 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -299,14 +299,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
                 errors::InvalidArgument(
                     "Current libxsmm and customized CPU implementations do "
                     "not yet support dilation rates larger than 1."));
-    OP_REQUIRES(
-        context, padding_ != Padding::EXPLICIT,
-        errors::Unimplemented("Current CPU implementation does not support "
-                              "EXPLICIT padding yet."));
-    std::vector<int64> explicit_paddings;
     OP_REQUIRES_OK(context,
-                   context->GetAttr("explicit_paddings", &explicit_paddings));
-    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings,
+                   context->GetAttr("explicit_paddings", &explicit_paddings_));
+    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                               /*num_dims=*/4, data_format_));
   }
 
@@ -325,10 +320,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
 
     ConvBackpropDimensions dims;
     OP_REQUIRES_OK(context,
-                   ConvBackpropComputeDimensions(
+                   ConvBackpropComputeDimensionsV2(
                        "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                        input_shape, filter.shape(), out_backprop.shape(),
-                       strides_, padding_, data_format_, &dims));
+                       /*dilations=*/{1, 1, 1, 1}, strides_, padding_,
+                       explicit_paddings_, data_format_, &dims));
 
     Tensor* in_backprop = nullptr;
     OP_REQUIRES_OK(context,
@@ -375,6 +371,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
 #endif
+    if (padding_ == Padding::EXPLICIT) {
+      pad_top = explicit_paddings_[2];
+      pad_bottom = explicit_paddings_[3];
+      pad_left = explicit_paddings_[4];
+      pad_right = explicit_paddings_[5];
+    }
     OP_REQUIRES_OK(
         context,
         GetWindowedOutputSizeVerbose(
@@ -536,6 +538,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
   std::vector<int32> dilations_;
   std::vector<int32> strides_;
   Padding padding_;
+  std::vector<int64> explicit_paddings_;
   TensorFormat data_format_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
@@ -617,12 +620,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
     use_cudnn_ &= CanUseCudnn();
     cudnn_use_autotune_ = CudnnUseAutotune();
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-    if (!std::is_same<Device, GPUDevice>::value) {
-      OP_REQUIRES(
-          context, padding_ != Padding::EXPLICIT,
-          errors::Unimplemented("Current CPU implementation does not support "
-                                "EXPLICIT padding yet."));
-    }
     OP_REQUIRES_OK(context,
                    context->GetAttr("explicit_paddings", &explicit_paddings_));
     OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index f5ec3d91f6a..ec54ece9d7c 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -70,11 +70,12 @@ struct LaunchGeneric {
   void operator()(OpKernelContext* ctx, const Tensor& input,
                   const Tensor& filter, int row_stride, int col_stride,
                   int row_dilation, int col_dilation, const Padding& padding,
-                  Tensor* output, TensorFormat data_format) {
+                  const std::vector<int64>& explicit_paddings, Tensor* output,
+                  TensorFormat data_format) {
     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
                                          "supports NHWC tensor format for now.";
     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
-        col_stride == 1) {
+        col_stride == 1 && (padding == SAME || padding == VALID)) {
       // For 1x1 kernel, the 2D convolution is reduced to matrix
       // multiplication.
       //
@@ -110,10 +111,20 @@ struct LaunchGeneric {
           input.shaped<T, 2>({input.dim_size(0), k}),
           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
     } else {
-      functor::SpatialConvolution<Device, T>()(
-          ctx->eigen_device<Device>(), output->tensor<T, 4>(),
-          input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
-          row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
+      if (padding == EXPLICIT) {
+        functor::SpatialConvolution<Device, T>()(
+            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
+            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
+            row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
+            static_cast<int>(explicit_paddings[3]),
+            static_cast<int>(explicit_paddings[4]),
+            static_cast<int>(explicit_paddings[5]));
+      } else {
+        functor::SpatialConvolution<Device, T>()(
+            ctx->eigen_device<Device>(), output->tensor<T, 4>(),
+            input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
+            row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
+      }
     }
   }
 };
@@ -133,18 +144,19 @@ struct LaunchConv2DOp<CPUDevice, T> {
                                 "NHWC tensor format for now."));
       return;
     }
-    // TODO(reedwm): Enable explicit padding on the CPU.
-    OP_REQUIRES(
-        ctx, padding != Padding::EXPLICIT,
-        errors::Unimplemented("Generic conv implementation does not support "
-                              "EXPLICIT padding yet."));
     const int64 in_depth = GetTensorDim(input, data_format, 'C');
     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
                 errors::Unimplemented("Generic conv implementation does not "
                                       "support grouped convolutions for now."));
+    for (int64 explicit_padding : explicit_paddings) {
+      if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
+        ctx->SetStatus(errors::InvalidArgument("Explicit padding too large"));
+        return;
+      }
+    }
     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
-                                  row_dilation, col_dilation, padding, output,
-                                  data_format);
+                                  row_dilation, col_dilation, padding,
+                                  explicit_paddings, output, data_format);
   }
 };
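
Note that the 1x1 fast path above is now limited to SAME/VALID, since explicit
padding changes the output shape. The reduction that path performs is a plain
matrix multiply over channels; a numpy sketch:

    import numpy as np
    x = np.random.rand(2, 5, 4, 3).astype(np.float32)   # NHWC input
    w = np.random.rand(1, 1, 3, 6).astype(np.float32)   # HWIO, 1x1 kernel
    # 1x1 conv, stride 1: flatten N*H*W rows, multiply by (C_in, C_out).
    y = (x.reshape(-1, 3) @ w.reshape(3, 6)).reshape(2, 5, 4, 6)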
 
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
index a2afab42ec1..4559ac3837c 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
@@ -1313,6 +1313,10 @@ struct gemm_pack_rhs<
  * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
  * pixels.
  *
+ * If padding_top, padding_bottom, padding_left, or padding_right is specified,
+ * then those paddings will be used to pad the input, and padding_type must be
+ * PADDING_VALID.
+ *
  * The result can be assigned to a tensor of rank equal to the rank of the
  * input. The dimensions of the result will be filters, height, width (and
  * others if applicable).
@@ -1360,7 +1364,9 @@ EIGEN_DEVICE_FUNC
                        const PaddingType padding_type = PADDING_SAME,
                        const Index row_in_stride = 1,
                        const Index col_in_stride = 1,
-                       const OutputKernel& output_kernel = OutputKernel()) {
+                       const OutputKernel& output_kernel = OutputKernel(),
+                       Index padding_top = 0, Index padding_bottom = 0,
+                       Index padding_left = 0, Index padding_right = 0) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                    internal::traits<Input>::NumDimensions,
@@ -1402,25 +1408,33 @@ EIGEN_DEVICE_FUNC
       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
   const TensorIndex InputCols =
       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
+  const bool padding_explicit =
+      (padding_top || padding_bottom || padding_left || padding_right);
 
   TensorIndex out_height;
   TensorIndex out_width;
   switch (padding_type) {
-    case PADDING_VALID:
-      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
+    case PADDING_VALID: {
+      const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
+      const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
+      out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) /
                                 static_cast<float>(row_stride));
-      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
+      out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) /
                                static_cast<float>(col_stride));
       break;
-    case PADDING_SAME:
+    }
+    case PADDING_SAME: {
+      eigen_assert(!padding_explicit);
       out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
       out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
       break;
-    default:
+    }
+    default: {
       // Initialize unused variables to avoid a compiler warning
       out_height = 0;
       out_width = 0;
       eigen_assert(false && "unexpected padding");
+    }
   }
 
   // Molds the output of the patch extraction code into a 2d tensor:
@@ -1473,22 +1487,50 @@ EIGEN_DEVICE_FUNC
     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
     kernel_dims[1] = kernelFilters;
   }
-  return choose(
-      Cond<internal::traits<Input>::Layout == ColMajor>(),
-      kernel.reshape(kernel_dims)
-          .contract(input
-                        .extract_image_patches(
-                            kernelRows, kernelCols, row_stride, col_stride,
-                            row_in_stride, col_in_stride, padding_type)
-                        .reshape(pre_contract_dims),
-                    contract_dims, output_kernel)
-          .reshape(post_contract_dims),
-      input
-          .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
-                                 row_in_stride, col_in_stride, padding_type)
-          .reshape(pre_contract_dims)
-          .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
-          .reshape(post_contract_dims));
+  if (padding_explicit) {
+    return choose(
+        Cond<internal::traits<Input>::Layout == ColMajor>(),
+        kernel.reshape(kernel_dims)
+            .contract(input
+                          .extract_image_patches(
+                              kernelRows, kernelCols, row_stride, col_stride,
+                              row_in_stride, col_in_stride,
+                              /*row_inflate_stride=*/1,
+                              /*col_inflate_stride=*/1, padding_top,
+                              padding_bottom, padding_left, padding_right,
+                              /*padding_value=*/0)
+                          .reshape(pre_contract_dims),
+                      contract_dims, output_kernel)
+            .reshape(post_contract_dims),
+        input
+            .extract_image_patches(kernelRows, kernelCols, row_stride,
+                                   col_stride, row_in_stride, col_in_stride,
+                                   /*row_inflate_stride=*/1,
+                                   /*col_inflate_stride=*/1, padding_top,
+                                   padding_bottom, padding_left, padding_right,
+                                   /*padding_value=*/0)
+            .reshape(pre_contract_dims)
+            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
+            .reshape(post_contract_dims));
+  } else {
+    return choose(
+        Cond<internal::traits<Input>::Layout == ColMajor>(),
+        kernel.reshape(kernel_dims)
+            .contract(input
+                          .extract_image_patches(
+                              kernelRows, kernelCols, row_stride, col_stride,
+                              row_in_stride, col_in_stride, padding_type)
+                          .reshape(pre_contract_dims),
+                      contract_dims, output_kernel)
+            .reshape(post_contract_dims),
+        input
+            .extract_image_patches(kernelRows, kernelCols, row_stride,
+                                   col_stride, row_in_stride, col_in_stride,
+                                   padding_type)
+            .reshape(pre_contract_dims)
+            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
+            .reshape(post_contract_dims));
+  }
 }
 
 }  // end namespace Eigen
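
The out_height/out_width computation above folds the explicit pads into the
VALID formula via InputRowsEff/InputColsEff. The same arithmetic in Python (a
sketch; the effective kernel extent accounts for dilation):

    import math

    def valid_out_size(in_size, kernel, stride,
                       pad_before=0, pad_after=0, dilation=1):
        kernel_eff = kernel + (kernel - 1) * (dilation - 1)
        return int(math.ceil(
            (in_size + pad_before + pad_after - kernel_eff + 1) / stride))

    assert valid_out_size(5, 3, 1, pad_before=2, pad_after=2) == 7
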
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index b7290497702..0bec67f5213 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -403,7 +403,6 @@ class Conv2DTest(test.TestCase):
         padding,
         expected,
         dilations,
-        gpu_only=True,
         test_grappler_layout_optimizer=test_grappler_layout_optimizer,
         tol=tol,
         fp16_tol=fp16_tol)
@@ -1429,8 +1428,14 @@ class Conv2DTest(test.TestCase):
                                                 strides,
                                                 padding,
                                                 data_format,
+                                                use_gpu,
                                                 dilations=(1, 1),
                                                 err=2e-5):
+    if use_gpu and not test.is_gpu_available(cuda_only=True):
+      return
+    if not use_gpu and dilations != (1, 1):
+      return  # Non-default dilations are currently not supported on the CPU.
+
     x1 = self._CreateNumpyTensor(filter_sizes)
     x2 = self._CreateNumpyTensor(output_sizes)
     dilations = list(dilations)
@@ -1455,133 +1460,128 @@ class Conv2DTest(test.TestCase):
         padding,
         expected,
         data_format,
-        use_gpu=True,
+        use_gpu=use_gpu,
         err=err,
         dilations=dilations)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding0x0BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 1, 2, 1],
-            strides=[1, 1],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 1, 2, 1],
+          strides=[1, 1],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 4, 2],
-            filter_sizes=[2, 2, 2, 3],
-            output_sizes=[1, 1, 2, 3],
-            strides=[2, 2],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 4, 2],
+          filter_sizes=[2, 2, 2, 3],
+          output_sizes=[1, 1, 2, 3],
+          strides=[2, 2],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding1x1BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 2],
-            output_sizes=[1, 3, 4, 2],
-            strides=[1, 1],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format, err=1e-4)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 2],
+          output_sizes=[1, 3, 4, 2],
+          strides=[1, 1],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 2],
-            filter_sizes=[1, 1, 2, 1],
-            output_sizes=[1, 4, 3, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 2],
+          filter_sizes=[1, 1, 2, 1],
+          output_sizes=[1, 4, 3, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 4, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 4, 2, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            dilations=[2, 2])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 4, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 4, 2, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          dilations=[2, 2], use_gpu=use_gpu)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding2x2BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[2, 3, 1, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[2, 2, 5, 1],
-            strides=[3, 1],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[2, 3, 1, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[2, 2, 5, 1],
+          strides=[3, 1],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 6, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 3, 4, 1],
-            strides=[1, 2],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 3])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 6, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 3, 4, 1],
+          strides=[1, 2],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          dilations=[2, 3],
+          use_gpu=use_gpu)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 10, 8, 1],
-            strides=[1, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format, err=5e-5)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 10, 8, 1],
+          strides=[1, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=5e-5)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 5, 3, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 4, 8, 1],
-            strides=[3, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 5, 3, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 4, 8, 1],
+          strides=[3, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropInput(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 3, 3, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[1, 7, 7, 1],
-            strides=[1, 1],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            err=5e-5)
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 3, 3, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[1, 7, 7, 1],
+          strides=[1, 1],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          err=5e-5,
+          use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropInputExplicitPadding(
-            input_sizes=[1, 4, 2, 1],
-            filter_sizes=[3, 3, 1, 1],
-            output_sizes=[1, 5, 2, 1],
-            strides=[1, 2],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 1])
+      self._RunAndVerifyBackpropInputExplicitPadding(
+          input_sizes=[1, 4, 2, 1],
+          filter_sizes=[3, 3, 1, 1],
+          output_sizes=[1, 5, 2, 1],
+          strides=[1, 2],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          dilations=[2, 1],
+          use_gpu=use_gpu)
 
   def _RunAndVerifyBackpropFilterExplicitPadding(self,
                                                  input_sizes,
@@ -1590,8 +1590,14 @@ class Conv2DTest(test.TestCase):
                                                  strides,
                                                  padding,
                                                  data_format,
+                                                 use_gpu,
                                                  dilations=(1, 1),
                                                  err=1e-5):
+    if use_gpu and not test.is_gpu_available(cuda_only=True):
+      return
+    if not use_gpu and dilations != (1, 1):
+      return  # Non-default dilations are currently not supported on the CPU.
+
     x0 = self._CreateNumpyTensor(input_sizes)
     x2 = self._CreateNumpyTensor(output_sizes)
     dilations = list(dilations)
@@ -1613,135 +1619,127 @@ class Conv2DTest(test.TestCase):
         padding,
         expected,
         data_format,
-        use_gpu=True,
+        use_gpu=use_gpu,
         dilations=dilations,
         err=err)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding0x0BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 1, 2, 1],
-            strides=[1, 1],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 1, 2, 1],
+          strides=[1, 1],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format, use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 4, 2],
-            filter_sizes=[2, 2, 2, 3],
-            output_sizes=[1, 1, 2, 3],
-            strides=[2, 2],
-            padding=[[0, 0], [0, 0]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 4, 2],
+          filter_sizes=[2, 2, 2, 3],
+          output_sizes=[1, 1, 2, 3],
+          strides=[2, 2],
+          padding=[[0, 0], [0, 0]],
+          data_format=data_format, use_gpu=use_gpu)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding1x1BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 2],
-            output_sizes=[1, 3, 4, 2],
-            strides=[1, 1],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            err=5e-5)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 2],
+          output_sizes=[1, 3, 4, 2],
+          strides=[1, 1],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=5e-5)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 2],
-            filter_sizes=[1, 1, 2, 1],
-            output_sizes=[1, 4, 3, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 2],
+          filter_sizes=[1, 1, 2, 1],
+          output_sizes=[1, 4, 3, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          use_gpu=use_gpu,
+          data_format=data_format)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 4, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 4, 2, 1],
-            strides=[1, 2],
-            padding=[[1, 1], [1, 1]],
-            data_format=data_format,
-            dilations=[2, 2])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 4, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 4, 2, 1],
+          strides=[1, 2],
+          padding=[[1, 1], [1, 1]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 2])
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding2x2BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
-
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[2, 3, 1, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[2, 2, 5, 1],
-            strides=[3, 1],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[2, 3, 1, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[2, 2, 5, 1],
+          strides=[3, 1],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 6, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 3, 4, 1],
-            strides=[1, 2],
-            padding=[[2, 2], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 3])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 6, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 3, 4, 1],
+          strides=[1, 2],
+          padding=[[2, 2], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 3])
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_1_8_4_1_BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 2, 3, 1],
-            filter_sizes=[2, 2, 1, 1],
-            output_sizes=[1, 10, 8, 1],
-            strides=[1, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format,
-            err=1e-4)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[1, 10, 8, 1],
+          strides=[1, 1],
+          padding=[[1, 8], [4, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 5, 3, 1],
-            filter_sizes=[3, 2, 1, 1],
-            output_sizes=[1, 4, 8, 1],
-            strides=[3, 1],
-            padding=[[1, 8], [4, 2]],
-            data_format=data_format)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 5, 3, 1],
+          filter_sizes=[3, 2, 1, 1],
+          output_sizes=[1, 4, 8, 1],
+          strides=[3, 1],
+          padding=[[1, 8], [4, 2]],
+          use_gpu=use_gpu,
+          data_format=data_format)
 
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth1Padding_5_0_2_2_BackpropFilter(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 3, 3, 1],
-            filter_sizes=[2, 1, 1, 1],
-            output_sizes=[1, 7, 7, 1],
-            strides=[1, 1],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            err=1e-4)
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 3, 3, 1],
+          filter_sizes=[2, 1, 1, 1],
+          output_sizes=[1, 7, 7, 1],
+          strides=[1, 1],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-4)
 
-        self._RunAndVerifyBackpropFilterExplicitPadding(
-            input_sizes=[1, 4, 2, 1],
-            filter_sizes=[3, 3, 1, 1],
-            output_sizes=[1, 5, 2, 1],
-            strides=[1, 2],
-            padding=[[5, 0], [2, 2]],
-            data_format=data_format,
-            dilations=[2, 1])
+      self._RunAndVerifyBackpropFilterExplicitPadding(
+          input_sizes=[1, 4, 2, 1],
+          filter_sizes=[3, 3, 1, 1],
+          output_sizes=[1, 5, 2, 1],
+          strides=[1, 2],
+          padding=[[5, 0], [2, 2]],
+          data_format=data_format,
+          use_gpu=use_gpu,
+          dilations=[2, 1])
 
   # Gradient checkers
   def ConstructAndTestGradient(self,
@@ -2107,257 +2105,221 @@ class Conv2DTest(test.TestCase):
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient1x1PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu,
-            max_err=0.0025)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu,
+          max_err=0.0025)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient1x1PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient1x1PaddingStrideTwo(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=4,
-            input_cols=5,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=2,
-            stride_cols=2,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=4,
+          input_cols=5,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=2,
+          stride_cols=2,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient1x1PaddingStrideTwo(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=4,
-            input_cols=5,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=2,
-            stride_cols=2,
-            padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=4,
+          input_cols=5,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=2,
+          stride_cols=2,
+          padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient2x2PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient2x2PaddingStrideOne(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=5,
-            input_cols=4,
-            filter_rows=3,
-            filter_cols=3,
-            in_depth=2,
-            out_depth=3,
-            stride_rows=1,
-            stride_cols=1,
-            padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu,
-            max_err=0.003)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=5,
+          input_cols=4,
+          filter_rows=3,
+          filter_cols=3,
+          in_depth=2,
+          out_depth=3,
+          stride_rows=1,
+          stride_cols=1,
+          padding=[[0, 0], [2, 2], [2, 2], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu,
+          max_err=0.003)
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient1_2_3_4PaddingStride3x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=8,
-            input_cols=5,
-            filter_rows=4,
-            filter_cols=2,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=3,
-            stride_cols=2,
-            padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=8,
+          input_cols=5,
+          filter_rows=4,
+          filter_cols=2,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=3,
+          stride_cols=2,
+          padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient1_2_3_4PaddingStride3x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=8,
-            input_cols=5,
-            filter_rows=4,
-            filter_cols=2,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=3,
-            stride_cols=2,
-            padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=8,
+          input_cols=5,
+          filter_rows=4,
+          filter_cols=2,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=3,
+          stride_cols=2,
+          padding=[[0, 0], [1, 2], [3, 4], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient4_3_2_1PaddingStride2x1(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=3,
-            input_rows=5,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=2,
-            in_depth=1,
-            out_depth=2,
-            stride_rows=2,
-            stride_cols=1,
-            padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=3,
+          input_rows=5,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=2,
+          in_depth=1,
+          out_depth=2,
+          stride_rows=2,
+          stride_cols=1,
+          padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient4_3_2_1PaddingStride2x1(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=3,
-            input_rows=5,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=2,
-            in_depth=1,
-            out_depth=2,
-            stride_rows=2,
-            stride_cols=1,
-            padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=3,
+          input_rows=5,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=2,
+          in_depth=1,
+          out_depth=2,
+          stride_rows=2,
+          stride_cols=1,
+          padding=[[0, 0], [4, 3], [2, 1], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testInputGradient0_0_0_5PaddingStride1x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=6,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=4,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=1,
-            stride_cols=2,
-            padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
-            test_input=True,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=6,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=4,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=1,
+          stride_cols=2,
+          padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
+          test_input=True,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testFilterGradient0_0_0_5PaddingStride1x2(self):
-    if not test.is_gpu_available(cuda_only=True):
-      return
     for (data_format, use_gpu) in GetTestConfigs():
-      if use_gpu:
-        self.ConstructAndTestGradient(
-            batch=2,
-            input_rows=6,
-            input_cols=7,
-            filter_rows=3,
-            filter_cols=4,
-            in_depth=3,
-            out_depth=2,
-            stride_rows=1,
-            stride_cols=2,
-            padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
-            test_input=False,
-            data_format=data_format,
-            use_gpu=use_gpu)
+      self.ConstructAndTestGradient(
+          batch=2,
+          input_rows=6,
+          input_cols=7,
+          filter_rows=3,
+          filter_cols=4,
+          in_depth=3,
+          out_depth=2,
+          stride_rows=1,
+          stride_cols=2,
+          padding=[[0, 0], [0, 0], [0, 5], [0, 0]],
+          test_input=False,
+          data_format=data_format,
+          use_gpu=use_gpu)
 
   @test_util.deprecated_graph_mode_only
   def testShapeFunctionEdgeCases(self):
@@ -2505,31 +2467,29 @@ class Conv2DTest(test.TestCase):
                 strides=[1, 1, 1, 1],
                 padding=[[0, 0], [2, 2], [2, 2], [0, 0]]))
 
-    if test.is_gpu_available(cuda_only=True):
-      with self.test_session(use_gpu=True):
-        # Negative padding during backprop.
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     "nonnegative"):
-          sess.run(
-              nn_ops.conv2d_backprop_input([32, 20, 20, 3],
-                                           array_ops.placeholder(
-                                               dtypes.float32,
-                                               shape=[18, 18, 3, 2]),
-                                           array_ops.placeholder(
-                                               dtypes.float32,
-                                               shape=[32, 3, 2, 2]),
-                                           strides=[1, 1, 1, 1],
-                                           padding=[[0, 0], [-1, 0], [0, 0],
-                                                    [0, 0]]))
-        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
-                                     "nonnegative"):
-          sess.run(
-              nn_ops.conv2d_backprop_filter(
-                  array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
-                  [18, 18, 3, 2],
-                  array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]),
-                  strides=[1, 1, 1, 1],
-                  padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]))
+      # Negative padding during backprop.
+      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                   "nonnegative"):
+        sess.run(
+            nn_ops.conv2d_backprop_input([32, 20, 20, 3],
+                                         array_ops.placeholder(
+                                             dtypes.float32,
+                                             shape=[18, 18, 3, 2]),
+                                         array_ops.placeholder(
+                                             dtypes.float32,
+                                             shape=[32, 3, 2, 2]),
+                                         strides=[1, 1, 1, 1],
+                                         padding=[[0, 0], [-1, 0], [0, 0],
+                                                  [0, 0]]))
+      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                   "nonnegative"):
+        sess.run(
+            nn_ops.conv2d_backprop_filter(
+                array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+                [18, 18, 3, 2],
+                array_ops.placeholder(dtypes.float32, shape=[32, 3, 2, 2]),
+                strides=[1, 1, 1, 1],
+                padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]))
 
 
 class DepthwiseConv2DTest(test.TestCase):
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 97cbe55403e..50583c2a893 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1904,7 +1904,7 @@ def conv2d(  # pylint: disable=redefined-builtin,dangerous-default-value
       value is given it is replicated in the `H` and `W` dimension. By default
       the `N` and `C` dimensions are set to 1. The dimension order is determined
       by the value of `data_format`, see below for details.
-    padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
+    padding: Either the `string` `"SAME"` or `"VALID"` indicating the type of
       padding algorithm to use, or a list indicating the explicit paddings at
       the start and end of each dimension. When explicit padding is used and
       data_format is `"NHWC"`, this should be in the form `[[0, 0], [pad_top,