From c2b3222ac552e9698968c9a212095dbc8b9ca40b Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Wed, 12 Sep 2018 14:17:17 -0700
Subject: [PATCH] Use Eigen::CuboidConvolutionBackwardKernel in
 Conv3DBackpropFilter.

Instead of multiple primitive Eigen ops in Conv3DBackpropFilter, call directly into Eigen function.

Modest ~10-25% latency improvement and ~10-20% peak memory reduction.

PiperOrigin-RevId: 212701797
---
 tensorflow/core/kernels/conv_3d.h             |  21 ++
 tensorflow/core/kernels/conv_grad_ops_3d.cc   |  76 +----
 .../eigen_backward_cuboid_convolutions.h      | 295 ++++++++++++------
 ...igen_backward_spatial_convolutions_test.cc |  31 +-
 4 files changed, 251 insertions(+), 172 deletions(-)

diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index e5054e062e5..b819c6f9103 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -33,6 +33,10 @@ struct CuboidConvolution;
 template <typename Device, typename T>
 struct CuboidConvolutionBackwardInput;
 
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 template <typename T>
@@ -64,6 +68,23 @@ struct CuboidConvolutionBackwardInput<CPUDevice, T> {
   }
 };
 
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+  void operator()(const CPUDevice& d,
+                  typename TTypes<T, 5>::Tensor filter_backward,
+                  typename TTypes<T, 5>::ConstTensor input,
+                  typename TTypes<T, 5>::ConstTensor output_backward,
+                  int stride_planes, int stride_rows, int stride_cols) {
+    // Need to swap the order of plane/row/col strides when calling Eigen.
+    filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+        input, output_backward,
+        filter_backward.dimension(2),  // kernel_planes
+        filter_backward.dimension(1),  // kernel_rows
+        filter_backward.dimension(0),  // kernel_cols
+        stride_cols, stride_rows, stride_planes);
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index ec7c02ac2bf..78e83750621 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -322,70 +322,18 @@ class Conv3DBackpropFilterOp : public OpKernel {
       return;
     }
 
-    // For the backprop of the filter, we need to also transpose the
-    // out_backprop.
-    // The shape of backprop is
-    //   [batch, out_z, out_y, out_x, out_depth]
-    // And we need to change it to
-    //   [out_depth, out_x, out_y, out_z, batch]
-    Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
-    TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
-                                  padded_out_cols, batch});
-    Tensor padded_output;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::v(),
-                                          padded_out_shape, &padded_output));
-    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
-                                                      strides[2], 1};
-    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
-        eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
-    const Tensor& padded_output_cref = padded_output;
-
-    // For the backprop of the filter, we need to transpose the input.
-    // The shape of input is
-    //   [batch, in_z, in_y, in_x, in_depth]
-    // And we need to change it to
-    //   [in_z, in_y, in_x, batch, in_depth]
-    Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
-    TensorShape in_shuffle_shape(
-        {input_size[0], input_size[1], input_size[2], batch, in_depth});
-    Tensor in_shuffle;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::v(),
-                                          in_shuffle_shape, &in_shuffle));
-    // No need for reversing this time.
-    Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
-    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
-        no_reverse, in_shuffle.tensor<T, 5>());
-    const Tensor& in_shuffle_cref = in_shuffle;
-
-    // The output of the conv_3d would be
-    //   [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
-    // and we need to shuffle it back to
-    //   [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
-    // And we need to reverse the filter backprops.
-    // So we need to allocate (sigh) yet another piece of memory to hold the
-    // output.
-    TensorShape filter_shuffle_shape(
-        {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
-    Tensor filter_shuffle;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::v(),
-                                        filter_shuffle_shape, &filter_shuffle));
-    functor::CuboidConvolution<Device, T>()(
-        context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
-        padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
-        1, BrainPadding2EigenPadding(VALID));
-
-    // Now copy the filter_backprop back to the destination.
-    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
-    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
-    const Tensor& filter_shuffle_cref = filter_shuffle;
-    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
-        filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+    // There is no need to explicitly compute padding values (and pad
+    // out_backprop), because Eigen uses the same padding inference mechanism as
+    // Tensorflow.
+    functor::CuboidConvolutionBackwardFilter<Device, T>()(
+        context->eigen_device<Device>(),
+        filter_backprop->tensor<T, 5>(),  // filter_backward
+        input.tensor<T, 5>(),             // input
+        out_backprop.tensor<T, 5>(),      // output_backward
+        // Order of strides will be reversed before passing to Eigen.
+        static_cast<int>(strides[0]),   // stride_planes
+        static_cast<int>(strides[1]),   // stride_rows
+        static_cast<int>(strides[2]));  // stride_cols
   }
 
  private:
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index f12c8d943d8..8edf7d4a2c4 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -59,12 +59,12 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
                     const array<
                         typename internal::traits<OutputBackward>::Index, 5>,
                     const TensorReverseOp<const Eigen::array<bool, 5>,
-                                          const Kernel> > > >,
+                                          const Kernel>>>>,
             const TensorReshapingOp<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
                              2>,
                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const OutputBackward> > > >,
+                                          const OutputBackward>>>>,
     TensorReshapingOp<
         const DSizes<typename internal::traits<OutputBackward>::Index,
                      internal::traits<OutputBackward>::NumDimensions>,
@@ -75,7 +75,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
                              2>,
                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const OutputBackward> >,
+                                          const OutputBackward>>,
             const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
                              2>,
@@ -83,7 +83,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
                     const array<
                         typename internal::traits<OutputBackward>::Index, 5>,
                     const TensorReverseOp<const Eigen::array<bool, 5>,
-                                          const Kernel> > > > > > >::type
+                                          const Kernel>>>>>>>::type
 CuboidConvolutionBackwardInput(
     const Kernel& kernel, const OutputBackward& output_backward,
     typename internal::traits<OutputBackward>::Index inputPlanes,
@@ -94,12 +94,12 @@ CuboidConvolutionBackwardInput(
   typedef typename internal::traits<OutputBackward>::Index TensorIndex;
   const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
                                internal::traits<Kernel>::NumDimensions,
-                               internal::traits<Kernel>::Layout, TensorIndex> >
+                               internal::traits<Kernel>::Layout, TensorIndex>>
       kern(kernel);
   const TensorRef<
       const Tensor<typename internal::traits<OutputBackward>::Scalar,
                    internal::traits<OutputBackward>::NumDimensions,
-                   internal::traits<OutputBackward>::Layout, TensorIndex> >
+                   internal::traits<OutputBackward>::Layout, TensorIndex>>
       out(output_backward);
 
   EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -323,29 +323,69 @@ CuboidConvolutionBackwardInput(
  */
 template <typename OutputBackward, typename Input>
 EIGEN_ALWAYS_INLINE static const typename internal::conditional<
-    internal::traits<OutputBackward>::Layout == ColMajor,
-    TensorReshapingOp<
-        const DSizes<typename internal::traits<Input>::Index, 5>,
-        const TensorContractionOp<
-            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<Input>::Index, 2>,
-                const OutputBackward>,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<Input>::Index, 2>,
-                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const Input> > > >,
-    TensorReshapingOp<
-        const DSizes<typename internal::traits<Input>::Index, 5>,
-        const TensorContractionOp<
-            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<Input>::Index, 2>,
-                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const Input> >,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<Input>::Index, 2>,
-                const OutputBackward> > > >::type
+    internal::traits<Input>::Layout == ColMajor,
+    const TensorReverseOp<
+        const Eigen::array<typename internal::traits<Input>::Index,
+                           internal::traits<Input>::NumDimensions>,
+        const Eigen::TensorShufflingOp<
+            const Eigen::array<typename internal::traits<Input>::Index,
+                               internal::traits<Input>::NumDimensions>,
+            const Eigen::TensorReshapingOp<
+                const Eigen::DSizes<typename internal::traits<Input>::Index,
+                                    internal::traits<Input>::NumDimensions>,
+                const TensorContractionOp<
+                    const array<
+                        IndexPair<typename internal::traits<Input>::Index>, 1>,
+                    const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const Eigen::TensorShufflingOp<
+                            const Eigen::array<
+                                typename internal::traits<Input>::Index,
+                                internal::traits<Input>::NumDimensions>,
+                            const OutputBackward>>>,
+                    const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const TensorVolumePatchOp<
+                            Dynamic, Dynamic, Dynamic,
+                            const Eigen::TensorForcedEvalOp<
+                                const Eigen::TensorShufflingOp<
+                                    const Eigen::array<
+                                        typename internal::traits<Input>::Index,
+                                        internal::traits<Input>::NumDimensions>,
+                                    const Input>>>>>>>>,
+    const TensorReverseOp<
+        const Eigen::array<typename internal::traits<Input>::Index,
+                           internal::traits<Input>::NumDimensions>,
+        const Eigen::TensorShufflingOp<
+            const Eigen::array<typename internal::traits<Input>::Index,
+                               internal::traits<Input>::NumDimensions>,
+            const Eigen::TensorReshapingOp<
+                const Eigen::DSizes<typename internal::traits<Input>::Index,
+                                    internal::traits<Input>::NumDimensions>,
+                const TensorContractionOp<
+                    const array<
+                        IndexPair<typename internal::traits<Input>::Index>, 1>,
+                    const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const TensorVolumePatchOp<
+                            Dynamic, Dynamic, Dynamic,
+                            const Eigen::TensorForcedEvalOp<
+                                const Eigen::TensorShufflingOp<
+                                    const Eigen::array<
+                                        typename internal::traits<Input>::Index,
+                                        internal::traits<Input>::NumDimensions>,
+                                    const Input>>>>,
+                    const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const Eigen::TensorShufflingOp<
+                            const Eigen::array<
+                                typename internal::traits<Input>::Index,
+                                internal::traits<Input>::NumDimensions>,
+                            const OutputBackward>>>>>>>>::type
 CuboidConvolutionBackwardKernel(
     const Input& input, const OutputBackward& output_backward,
     typename internal::traits<Input>::Index kernelPlanes,
@@ -356,11 +396,11 @@ CuboidConvolutionBackwardKernel(
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                    internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex> >
+                   internal::traits<Input>::Layout, TensorIndex>>
       in(input);
   TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
                    internal::traits<OutputBackward>::NumDimensions,
-                   internal::traits<OutputBackward>::Layout, TensorIndex> >
+                   internal::traits<OutputBackward>::Layout, TensorIndex>>
       out(output_backward);
 
   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -374,6 +414,13 @@ CuboidConvolutionBackwardKernel(
                           internal::traits<OutputBackward>::NumDimensions,
                       YOU_MADE_A_PROGRAMMING_MISTAKE);
 
+  // We do not support higher dimensional backward convolutions, or convolutions
+  // without batch dimension.
+  // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+  // dimension in eigen_backward_cuboid_convolutions_test.cc.
+  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+                      YOU_MADE_A_PROGRAMMING_MISTAKE);
+
   const TensorIndex inputPlanes =
       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
   const TensorIndex inputRows =
@@ -395,6 +442,10 @@ CuboidConvolutionBackwardKernel(
   const TensorIndex kernelChannels =
       isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
 
+  // Number of batches in the input tensor.
+  const TensorIndex batch =
+      isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
   // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
   // effective kernel planes/rows/cols are always the same as the kernel itself
   // (see eigen_spatial_convolutions for details).
@@ -402,6 +453,7 @@ CuboidConvolutionBackwardKernel(
   const TensorIndex kernelRowsEff = kernelRows;
   const TensorIndex kernelColsEff = kernelCols;
 
+  // Compute forward padding from input and output_backward dimensions.
   const TensorIndex padPlanes = numext::maxi<Index>(
       0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
   const TensorIndex padRows = numext::maxi<Index>(
@@ -410,94 +462,147 @@ CuboidConvolutionBackwardKernel(
       0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
 
   const TensorIndex padding_top_z = padPlanes / 2;
-  const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
   const TensorIndex padding_top = padRows / 2;
-  const TensorIndex padding_bottom = padRows - padding_top;
   const TensorIndex padding_left = padCols / 2;
-  const TensorIndex padding_right = padCols - padding_left;
 
-  // Reshaped output_backward before contraction.
-  DSizes<TensorIndex, 2> output_dims;
+  // Compute paddings for output_backward before extracting patches.
+  const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+  const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+  const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+  const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+  const auto padded_out_rows = inputRows + kernelRows - 1;
+  const auto padded_out_cols = inputCols + kernelCols - 1;
+  const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+  const auto top_pad_rows = kernelRows - 1 - padding_top;
+  const auto left_pad_cols = kernelCols - 1 - padding_left;
+  const auto bottom_pad_planes =
+      padded_out_planes - expanded_out_planes - top_pad_planes;
+  const auto bottom_pad_rows =
+      padded_out_rows - expanded_out_rows - top_pad_rows;
+  const auto right_pad_cols =
+      padded_out_cols - expanded_out_cols - left_pad_cols;
+
+  // Reorder output_backward dimensions.
+  array<TensorIndex, 5> output_backward_shuffle;
   if (isColMajor) {
-    output_dims[0] = kernelFilters;
-    output_dims[1] = outputPlanes * outputRows * outputCols;
-    for (int i = 4; i < NumDims; ++i) {
-      output_dims[1] *= out.dimension(i);
-    }
+    // From: [out_depth, out_planes, out_rows, out_cols, batch]
+    // To:   [batch, out_planes, out_rows, out_cols, out_depth]
+    output_backward_shuffle = {4, 1, 2, 3, 0};
   } else {
-    output_dims[1] = kernelFilters;
-    output_dims[0] = outputCols * outputRows * outputPlanes;
-    for (int i = 0; i < NumDims - 4; ++i) {
-      output_dims[0] *= out.dimension(i);
-    }
+    // From: [batch, out_cols, out_rows, out_planes, out_depth]
+    // To:   [out_depth, out_cols, out_rows, out_planes, batch]
+    output_backward_shuffle = {4, 1, 2, 3, 0};
   }
 
-  // Reshaped extract_volume_patches(in)
+  // Reorder input dimensions.
+  array<TensorIndex, 5> input_shuffle;
+  if (isColMajor) {
+    // From: [in_depth, in_planes, in_rows, in_cols, batch]
+    // To:   [in_depth, batch, in_planes, in_rows, in_cols]
+    input_shuffle = {0, 4, 1, 2, 3};
+  } else {
+    // From: [batch, in_cols, in_rows, in_planes, in_depth]
+    // To:   [in_cols, in_rows, in_planes, batch, in_depth]
+    input_shuffle = {1, 2, 3, 0, 4};
+  }
+
+  // Input is playing the role of a "kernel" in this convolution.
+  DSizes<TensorIndex, 2> input_dims;
+  if (isColMajor) {
+    input_dims[0] = kernelChannels;
+    input_dims[1] = batch * inputPlanes * inputRows * inputCols;
+  } else {
+    input_dims[1] = kernelChannels;
+    input_dims[0] = inputCols * inputRows * inputPlanes * batch;
+  }
+
+  // Molds the output of the patch extraction result into a 2D tensor:
+  // - the first dimension (dims[0]): the patch values to be multiplied with the
+  // kernels
+  // - the second dimension (dims[1]): everything else
   DSizes<TensorIndex, 2> pre_contract_dims;
   if (isColMajor) {
-    pre_contract_dims[0] =
-        kernelChannels * kernelPlanes * kernelRows * kernelCols;
-    pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
-    for (int i = 4; i < NumDims; ++i) {
-      pre_contract_dims[1] *= in.dimension(i);
-    }
-    eigen_assert(output_dims[1] == pre_contract_dims[1]);
-  } else {
+    pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
     pre_contract_dims[1] =
-        kernelCols * kernelRows * kernelPlanes * kernelChannels;
-    pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
-    for (int i = 0; i < NumDims - 4; ++i) {
-      pre_contract_dims[0] *= in.dimension(i);
-    }
-    eigen_assert(output_dims[0] == pre_contract_dims[0]);
+        kernelPlanes * kernelRows * kernelCols * kernelFilters;
+  } else {
+    pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+    pre_contract_dims[0] =
+        kernelFilters * kernelCols * kernelRows * kernelPlanes;
   }
 
   // We will contract along the collapsed dimension that contains the
-  // outputCols, outputRows, outputPlanes and OTHERS.
+  // batch, inputPlanes, inputRows and inputCols.
   array<IndexPair<TensorIndex>, 1> contract_dims;
+  contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+
+  // Dimensions after contraction.
+  DSizes<TensorIndex, NumDims> post_contract_dims;
   if (isColMajor) {
-    // col-major: output_backward.contract(input.patches)
-    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+    post_contract_dims[0] = kernelChannels;
+    post_contract_dims[1] = kernelPlanes;
+    post_contract_dims[2] = kernelRows;
+    post_contract_dims[3] = kernelCols;
+    post_contract_dims[4] = kernelFilters;
   } else {
-    // row-major: input.patches.contract(output_backward)
-    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+    post_contract_dims[0] = kernelFilters;
+    post_contract_dims[1] = kernelCols;
+    post_contract_dims[2] = kernelRows;
+    post_contract_dims[3] = kernelPlanes;
+    post_contract_dims[4] = kernelChannels;
   }
 
-  DSizes<TensorIndex, 5> kernel_dims;
+  // Reorder output of contraction to valid filter shape.
+  array<TensorIndex, 5> kernel_shuffle;
   if (isColMajor) {
-    kernel_dims[0] = kernelFilters;
-    kernel_dims[1] = kernelChannels;
-    kernel_dims[2] = kernelPlanes;
-    kernel_dims[3] = kernelRows;
-    kernel_dims[4] = kernelCols;
+    // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+    // To:   [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+    kernel_shuffle = {4, 0, 1, 2, 3};
   } else {
-    kernel_dims[4] = kernelFilters;
-    kernel_dims[3] = kernelChannels;
-    kernel_dims[2] = kernelPlanes;
-    kernel_dims[1] = kernelRows;
-    kernel_dims[0] = kernelCols;
+    // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+    // To:   [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+    kernel_shuffle = {1, 2, 3, 4, 0};
   }
 
-  return choose(
-      Cond<internal::traits<Input>::Layout == ColMajor>(),
-      output_backward.reshape(output_dims)
-          .contract(input
+  // Reverse kernel backprop dimensions.
+  array<TensorIndex, 5> kernel_reverse;
+  if (isColMajor) {
+    kernel_reverse = {false, false, true, true, true};
+  } else {
+    kernel_reverse = {true, true, true, false, false};
+  }
+
+  // Create convolution input (aka source of patches) from output backward
+  // tensor by shuffling dimensions.
+  const auto the_input =
+      output_backward.shuffle(output_backward_shuffle).eval();
+
+  // Create convolution kernel (aka filter) from input by shuffling and
+  // reshaping.
+  const auto the_kernel =
+      input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+  return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+                the_kernel.contract(
+                    the_input
                         .extract_volume_patches(
-                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
-                            strideRows, strideCols, 1, 1, 1, padding_top_z,
-                            padding_bottom_z, padding_top, padding_bottom,
-                            padding_left, padding_right)
+                            inputPlanes, inputRows, inputCols, 1, 1, 1,
+                            stridePlanes, strideRows, strideCols,
+                            top_pad_planes, bottom_pad_planes, top_pad_rows,
+                            bottom_pad_rows, left_pad_cols, right_pad_cols)
                         .reshape(pre_contract_dims),
-                    contract_dims)
-          .reshape(kernel_dims),
-      input
-          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
-                                  stridePlanes, strideRows, strideCols, 1, 1, 1,
-                                  padding_top_z, padding_bottom_z, padding_top,
-                                  padding_bottom, padding_left, padding_right)
-          .reshape(pre_contract_dims)
-          .contract(output_backward.reshape(output_dims), contract_dims)
-          .reshape(kernel_dims));
+                    contract_dims),
+                the_input
+                    .extract_volume_patches(
+                        inputPlanes, inputRows, inputCols, 1, 1, 1,
+                        stridePlanes, strideRows, strideCols, top_pad_planes,
+                        bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+                        left_pad_cols, right_pad_cols)
+                    .reshape(pre_contract_dims)
+                    .contract(the_kernel, contract_dims))
+      .reshape(post_contract_dims)
+      .shuffle(kernel_shuffle)
+      .reverse(kernel_reverse);
 }
 
 }  // end namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec96594..673ec1458b8 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
   const int output_cols = input_cols - patch_cols + 1;
   const int output_planes = input_planes - patch_planes + 1;
 
-  Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+  // TODO(ezhulenev): Support backward kernel convolution without batch
+  // dimension.
+  Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+                         /*num_batches*/ 1);
   Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
                           patch_cols);
-  Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
-                                   output_cols);
+  Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+                                   output_cols, /*num_batches*/ 1);
 
   output_backward = output_backward.constant(11.0f) + output_backward.random();
   input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
                   if (output_i >= 0 && output_i < output_planes &&
                       output_j >= 0 && output_j < output_rows &&
                       output_k >= 0 && output_k < output_cols) {
-                    expected +=
-                        input(id, i, j, k) *
-                        output_backward(od, output_i, output_j, output_k);
+                    expected += input(id, i, j, k, /*batch*/ 0) *
+                                output_backward(od, output_i, output_j,
+                                                output_k, /*batch*/ 0);
                   }
                 }
               }
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
   const int output_cols = input_cols - patch_cols + 1;
   const int output_planes = input_planes - patch_planes + 1;
 
-  Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
-                                   input_depth);
+  // TODO(ezhulenev): Support backward kernel convolution without batch
+  // dimension.
+  Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+                                   input_planes, input_depth);
   Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
                                     input_depth, output_depth);
-  Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
-                                             output_planes, output_depth);
+  Tensor<float, 5, RowMajor> output_backward(
+      /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
 
   output_backward = output_backward.constant(11.0f) + output_backward.random();
   input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
                   if (output_i >= 0 && output_i < output_planes &&
                       output_j >= 0 && output_j < output_rows &&
                       output_k >= 0 && output_k < output_cols) {
-                    expected +=
-                        input(k, j, i, id) *
-                        output_backward(output_k, output_j, output_i, od);
+                    expected += input(/*batch*/ 0, k, j, i, id) *
+                                output_backward(/*batch*/ 0, output_k, output_j,
+                                                output_i, od);
                   }
                 }
               }