From cc5ea8469641b6680971eb76020407f81ab3f573 Mon Sep 17 00:00:00 2001
From: Anna R <annarev@google.com>
Date: Wed, 9 Dec 2020 16:13:53 -0800
Subject: [PATCH] Remove changes made to support TFRT-based OpKernel classes
 in the Conv3D kernel

This is essentially a rollback of
https://github.com/tensorflow/tensorflow/commit/2f10fa781040595d6801abc1dcbeb43d91b1565a
since we decided not to follow that approach.
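
For reference, the shape of the change: the TFRT-parameterized aliases in
numeric_op.h, e.g.

    template <class T>
    using BinaryOp = BinaryOpBase<T, OpKernel, OpKernelConstruction>;

become plain OpKernel subclasses again:

    template <class T>
    class BinaryOp : public OpKernel {
     public:
      explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
        const DataType dt = DataTypeToEnum<T>::v();
        OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
      }
    };

Correspondingly, kernel registrations drop the extra OpKernel template
arguments:

    REGISTER_KERNEL_BUILDER(
        Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"),
        Conv3DOp<CPUDevice, T>);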

PiperOrigin-RevId: 346660110
Change-Id: I1c03d47d4f4b42f74ef23f4c8e1d6d3f8207b240
---
 tensorflow/core/framework/BUILD             |   3 -
 tensorflow/core/framework/numeric_op.h      |  22 ++-
 tensorflow/core/framework/numeric_op_base.h |  49 -----
 tensorflow/core/kernels/BUILD               |  46 +----
 tensorflow/core/kernels/conv_ops_3d.cc      | 154 ++++++++++++++--
 tensorflow/core/kernels/conv_ops_3d.h       | 187 --------------------
 6 files changed, 163 insertions(+), 298 deletions(-)
 delete mode 100644 tensorflow/core/framework/numeric_op_base.h
 delete mode 100644 tensorflow/core/kernels/conv_ops_3d.h

diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 33034bb4e92..994b91c78d0 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -68,7 +68,6 @@ exports_files(
         "model.h",
         "node_def_builder.h",
         "numeric_op.h",
-        "numeric_op_base.h",
         "op_kernel.h",
         "op_requires.h",
         "op_segment.h",
@@ -204,7 +203,6 @@ filegroup(
         "node_def_util.h",
         "node_properties.h",
         "numeric_op.h",
-        "numeric_op_base.h",
         "numeric_types.h",
         "op.h",
         "op_def_builder.h",
@@ -305,7 +303,6 @@ filegroup(
         "kernel_shape_util.h",
         "log_memory.cc",
         "log_memory.h",
-        "numeric_op_base.h",
         "numeric_types.h",
         "op_requires.h",
         "ops_util.cc",
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
index 9f8ceed2968..0167e21f113 100644
--- a/tensorflow/core/framework/numeric_op.h
+++ b/tensorflow/core/framework/numeric_op.h
@@ -12,22 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+
 #ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
 #define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
 
-#include "tensorflow/core/framework/numeric_op_base.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
+// One input and one output, both the same type.
 template <class T>
-using UnaryOp = UnaryOpBase<T, OpKernel, OpKernelConstruction>;
+class UnaryOp : public OpKernel {
+ public:
+  explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt}));
+  }
+};
 
+// Two inputs and one output, all the same type.
 template <class T>
-using BinaryOp = BinaryOpBase<T, OpKernel, OpKernelConstruction>;
+class BinaryOp : public OpKernel {
+ public:
+  explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
+    const DataType dt = DataTypeToEnum<T>::v();
+    OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
+  }
+};
 
 // For operations where the input and output are the same shape.
 //
diff --git a/tensorflow/core/framework/numeric_op_base.h b/tensorflow/core/framework/numeric_op_base.h
deleted file mode 100644
index be7d3bf8f9e..00000000000
--- a/tensorflow/core/framework/numeric_op_base.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
-#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
-
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-
-namespace tensorflow {
-
-// One input and one output, both the same type.
-template <class T, class OpKernelT, class OpKernelConstructionT>
-class UnaryOpBase : public OpKernelT {
- public:
-  explicit UnaryOpBase(OpKernelConstructionT* construction) :
-      OpKernelT(construction) {
-    const DataType dt = DataTypeToEnum<T>::v();
-    OP_REQUIRES_OK(construction, construction->MatchSignature({dt}, {dt}));
-  }
-};
-
-// Two inputs and one output, all the same type.
-template <class T, class OpKernelT, class OpKernelConstructionT>
-class BinaryOpBase : public OpKernelT {
- public:
-  explicit BinaryOpBase(OpKernelConstructionT* construction) :
-      OpKernelT(construction) {
-    const DataType dt = DataTypeToEnum<T>::v();
-    OP_REQUIRES_OK(construction, construction->MatchSignature({dt, dt}, {dt}));
-  }
-};
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_BASE_H_
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 423bd427534..75b10a3a26c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3297,48 +3297,6 @@ cc_library(
     }),
 )
 
-# TODO(annarev): conv_ops_3d_headers currently depends on android target build
-# from selected sources. We should switch to use granular dependencies instead.
-# Then, we can just depend on "conv3d".
-cc_library(
-    name = "conv_3d_mobile",
-    hdrs = [
-        "conv_3d.h",
-        "eigen_backward_cuboid_convolutions.h",
-        "eigen_convolution_helpers.h",
-        "eigen_cuboid_convolution.h",
-        "eigen_volume_patch.h",
-    ],
-    deps = [
-        ":eigen_spatial_convolutions-inl",
-    ] + select({
-        "//tensorflow:android": [
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
-cc_library(
-    name = "conv_ops_3d_headers",
-    hdrs = [
-        "conv_ops_3d.h",
-    ],
-    deps = select({
-        "//tensorflow:android": [
-            ":conv_3d_mobile",
-            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
-        ],
-        "//conditions:default": [
-            ":conv_3d",
-            "//third_party/eigen3",
-            "//tensorflow/core:framework",
-        ],
-    }),
-)
-
 tf_kernel_library(
     name = "argmax_op",
     prefix = "argmax_op",
@@ -3810,6 +3768,7 @@ tf_kernel_library(
         "deep_conv2d.h",
         "gemm_functors.h",
         "winograd_transform.h",
+        "conv_ops_fused_impl.h",
     ] + select({
         ":xsmm_convolutions": ["xsmm_conv2d.h"],
         "//conditions:default": [],
@@ -3824,7 +3783,6 @@ tf_kernel_library(
     prefix = "conv_ops",
     deps = [
         ":conv_grad_shape_utils",
-        ":conv_ops_3d_headers",
         ":conv_2d",
         ":conv_3d",
         ":eigen_contraction_kernel",
@@ -5948,7 +5906,6 @@ filegroup(
         "conv_2d.h",
         "conv_3d.h",
         "conv_ops.h",
-        "conv_ops_3d.h",
         "conv_ops_gpu.h",
         "data_format_ops.h",
         "depthtospace_op.h",
@@ -6445,7 +6402,6 @@ filegroup(
         "stateful_random_ops_cpu_gpu.h",
         # Allows conv_3d ops for android but excluded from *_3d* rule above.
         "conv_3d.h",
-        "conv_ops_3d.h",
         "conv_ops_3d.cc",
         "conv_ops_gpu.h",
     ],
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 4ca5f514b7a..f6b30ced4ea 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -16,8 +16,7 @@ limitations under the License.
 #define USE_EIGEN_TENSOR
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/conv_ops_3d.h"
-
+#include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -51,11 +50,147 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
+template <typename Device, typename T>
+struct LaunchConvOp;
+
+template <typename T>
+struct LaunchConvOp<CPUDevice, T> {
+  static void launch(OpKernelContext* context, bool cudnn_use_autotune,
+                     const Tensor& input, const Tensor& filter,
+                     const std::array<int64, 3>& dilations,
+                     const std::array<int64, 3>& strides, const Padding padding,
+                     TensorFormat data_format, Tensor* output) {
+    OP_REQUIRES(context, data_format == FORMAT_NHWC,
+                errors::InvalidArgument("CPU implementation of Conv3D "
+                                        "currently only supports the NHWC "
+                                        "tensor format."));
+    OP_REQUIRES(context,
+                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
+                errors::InvalidArgument("CPU implementation of Conv3D "
+                                        "currently only supports dilated rates "
+                                        "of 1."));
+    functor::CuboidConvolution<CPUDevice, T>()(
+        context->eigen_device<CPUDevice>(), output->tensor<T, 5>(),
+        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
+        strides[0], BrainPadding2EigenPadding(padding));
+  }
+};
+
+template <typename Device, typename T>
+class Conv3DOp : public BinaryOp<T> {
+ public:
+  explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+    string data_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+    OP_REQUIRES(context, stride_.size() == 5,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
+         GetTensorDim(stride_, data_format_, 'C') == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, '0') > 0 &&
+         GetTensorDim(stride_, data_format_, '1') > 0 &&
+         GetTensorDim(stride_, data_format_, '2') > 0),
+        errors::InvalidArgument("Spatial strides should be larger than 0."));
+    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+    OP_REQUIRES(context, dilation_.size() == 5,
+                errors::InvalidArgument("Dilation rates field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
+                 GetTensorDim(dilation_, data_format_, 'C') == 1),
+                errors::InvalidArgument(
+                    "Current implementation does not yet support "
+                    "dilation rates in the batch and depth dimensions."));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
+         GetTensorDim(dilation_, data_format_, '1') > 0 &&
+         GetTensorDim(dilation_, data_format_, '2') > 0),
+        errors::InvalidArgument("Dilated rates should be larger than 0."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+    cudnn_use_autotune_ = CudnnUseAutotune();
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Input tensor is of the following dimensions:
+    // [ batch, in_z, in_y, in_x, in_channels ]
+    const Tensor& input = context->input(0);
+
+    // Input filter is of the following dimensions:
+    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
+    const Tensor& filter = context->input(1);
+
+    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
+    // kept consistent between input/filter/output.
+    OP_REQUIRES(context, input.dims() == 5,
+                errors::InvalidArgument("input must be 5-dimensional"));
+    OP_REQUIRES(context, filter.dims() == 5,
+                errors::InvalidArgument("filter must be 5-dimensional"));
+
+    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
+    const int64 in_batch = GetTensorDim(input, data_format_, 'N');
+
+    const int64 filter_depth = filter.dim_size(3);
+    const int64 out_depth = filter.dim_size(4);
+
+    OP_REQUIRES(context, in_depth % filter_depth == 0,
+                errors::InvalidArgument(
+                    "Input depth must be evenly divisible by filter depth: ",
+                    in_depth, " vs ", filter_depth));
+
+    // Dimension order for these arrays is: z, y, x.
+    std::array<int64, 3> input_size = {
+        {GetTensorDim(input, data_format_, '0'),
+         GetTensorDim(input, data_format_, '1'),
+         GetTensorDim(input, data_format_, '2')}};
+    std::array<int64, 3> filter_size = {
+        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
+    std::array<int64, 3> dilations = {
+        {GetTensorDim(dilation_, data_format_, '0'),
+         GetTensorDim(dilation_, data_format_, '1'),
+         GetTensorDim(dilation_, data_format_, '2')}};
+    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
+                                     GetTensorDim(stride_, data_format_, '1'),
+                                     GetTensorDim(stride_, data_format_, '2')}};
+    std::array<int64, 3> out, padding;
+
+    OP_REQUIRES_OK(
+        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
+                                   padding_, &out, &padding));
+    TensorShape out_shape = ShapeFromFormat(
+        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
+    Tensor* output;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    // Return early if nothing to do.
+    if (out_shape.num_elements() == 0) return;
+
+    LaunchConvOp<Device, T>::launch(context, cudnn_use_autotune_, input, filter,
+                                    dilations, strides, padding_, data_format_,
+                                    output);
+  }
+
+ private:
+  std::vector<int32> dilation_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+  bool cudnn_use_autotune_;
+};
+
 #define REGISTER_CPU_KERNEL(T)                                  \
   REGISTER_KERNEL_BUILDER(                                      \
       Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      Conv3DOp<CPUDevice, T, OpKernel, OpKernelConstruction,    \
-               OpKernelContext>);
+      Conv3DOp<CPUDevice, T>);
 TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -73,7 +208,7 @@ typedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters,
 
 // TODO(mjanusz): Share logic with 2d implementation as much as possible.
 template <typename T>
-struct LaunchConvOp<GPUDevice, T, OpKernelContext> {
+struct LaunchConvOp<GPUDevice, T> {
   static void launch(OpKernelContext* ctx, bool cudnn_use_autotune,
                      const Tensor& input_param, const Tensor& filter,
                      const std::array<int64, 3>& dilations,
@@ -548,16 +683,13 @@ DECLARE_GPU_SPEC(double);
 // Registration of the GPU implementations.
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
-    Conv3DOp<GPUDevice, Eigen::half, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, Eigen::half>);
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
-    Conv3DOp<GPUDevice, float, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, float>);
 REGISTER_KERNEL_BUILDER(
     Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
-    Conv3DOp<GPUDevice, double, OpKernel, OpKernelConstruction,
-             OpKernelContext>);
+    Conv3DOp<GPUDevice, double>);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_3d.h b/tensorflow/core/kernels/conv_ops_3d.h
deleted file mode 100644
index 9dcdea5b18f..00000000000
--- a/tensorflow/core/kernels/conv_ops_3d.h
+++ /dev/null
@@ -1,187 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_
-#define TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_
-
-#include <vector>
-
-#define USE_EIGEN_TENSOR
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/framework/numeric_op_base.h"
-#include "tensorflow/core/framework/kernel_shape_util.h"
-#include "tensorflow/core/framework/op_requires.h"
-#include "tensorflow/core/framework/ops_util.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/kernels/conv_3d.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/util/padding.h"
-#include "tensorflow/core/util/tensor_format.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/util/use_cudnn.h"
-#endif
-
-namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
-template <typename Device, typename T, class OpKernelContextT>
-struct LaunchConvOp;
-
-template <typename T, class OpKernelContextT>
-struct LaunchConvOp<CPUDevice, T, OpKernelContextT> {
-  static void launch(OpKernelContextT* context, bool cudnn_use_autotune,
-                     const Tensor& input, const Tensor& filter,
-                     const std::array<int64, 3>& dilations,
-                     const std::array<int64, 3>& strides, const Padding padding,
-                     TensorFormat data_format, Tensor* output) {
-    OP_REQUIRES(context, data_format == FORMAT_NHWC,
-                errors::InvalidArgument("CPU implementation of Conv3D "
-                                        "currently only supports the NHWC "
-                                        "tensor format."));
-    OP_REQUIRES(context,
-                dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1,
-                errors::InvalidArgument("CPU implementation of Conv3D "
-                                        "currently only supports dilated rates "
-                                        "of 1."));
-    functor::CuboidConvolution<CPUDevice, T>()(
-        context->template eigen_device<CPUDevice>(), output->tensor<T, 5>(),
-        input.tensor<T, 5>(), filter.tensor<T, 5>(), strides[2], strides[1],
-        strides[0], BrainPadding2EigenPadding(padding));
-  }
-};
-
-template <typename Device, typename T, class OpKernelT,
-          class OpKernelConstructionT, class OpKernelContextT>
-class Conv3DOp : public BinaryOpBase<T, OpKernelT, OpKernelConstructionT> {
- public:
-  explicit Conv3DOp(OpKernelConstructionT* context) :
-      BinaryOpBase<T, OpKernelT, OpKernelConstructionT>(context) {
-    string data_format;
-    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
-    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
-                errors::InvalidArgument("Invalid data format"));
-    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
-    OP_REQUIRES(context, stride_.size() == 5,
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify 5 dimensions"));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(stride_, data_format_, 'N') == 1 &&
-         GetTensorDim(stride_, data_format_, 'C') == 1),
-        errors::InvalidArgument("Current implementation does not yet support "
-                                "strides in the batch and depth dimensions."));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(stride_, data_format_, '0') > 0 &&
-         GetTensorDim(stride_, data_format_, '1') > 0 &&
-         GetTensorDim(stride_, data_format_, '2') > 0),
-        errors::InvalidArgument("Spatial strides should be larger than 0."));
-    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
-    OP_REQUIRES(context, dilation_.size() == 5,
-                errors::InvalidArgument("Dilation rates field must "
-                                        "specify 5 dimensions"));
-    OP_REQUIRES(context,
-                (GetTensorDim(dilation_, data_format_, 'N') == 1 &&
-                 GetTensorDim(dilation_, data_format_, 'C') == 1),
-                errors::InvalidArgument(
-                    "Current implementation does not yet support "
-                    "dilation rates in the batch and depth dimensions."));
-    OP_REQUIRES(
-        context,
-        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
-         GetTensorDim(dilation_, data_format_, '1') > 0 &&
-         GetTensorDim(dilation_, data_format_, '2') > 0),
-        errors::InvalidArgument("Dilated rates should be larger than 0."));
-    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
-#if GOOGLE_CUDA
-    cudnn_use_autotune_ = CudnnUseAutotune();
-#else
-    cudnn_use_autotune_ = false;
-#endif
-  }
-
-  void Compute(OpKernelContextT* context) override {
-    // Input tensor is of the following dimensions:
-    // [ batch, in_z, in_y, in_x, in_channels ]
-    const Tensor& input = context->input(0);
-
-    // Input filter is of the following dimensions:
-    // [ filter_z, filter_y, filter_x, in_channels, out_channels]
-    const Tensor& filter = context->input(1);
-
-    // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be
-    // kept consistent between input/filter/output.
-    OP_REQUIRES(context, input.dims() == 5,
-                errors::InvalidArgument("input must be 5-dimensional"));
-    OP_REQUIRES(context, filter.dims() == 5,
-                errors::InvalidArgument("filter must be 5-dimensional"));
-
-    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
-    const int64 in_batch = GetTensorDim(input, data_format_, 'N');
-
-    const int64 filter_depth = filter.dim_size(3);
-    const int64 out_depth = filter.dim_size(4);
-
-    OP_REQUIRES(context, in_depth % filter_depth == 0,
-                errors::InvalidArgument(
-                    "Input depth must be evenly divisible by filter depth: ",
-                    in_depth, " vs ", filter_depth));
-
-    // Dimension order for these arrays is: z, y, x.
-    std::array<int64, 3> input_size = {
-        {GetTensorDim(input, data_format_, '0'),
-         GetTensorDim(input, data_format_, '1'),
-         GetTensorDim(input, data_format_, '2')}};
-    std::array<int64, 3> filter_size = {
-        {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}};
-    std::array<int64, 3> dilations = {
-        {GetTensorDim(dilation_, data_format_, '0'),
-         GetTensorDim(dilation_, data_format_, '1'),
-         GetTensorDim(dilation_, data_format_, '2')}};
-    std::array<int64, 3> strides = {{GetTensorDim(stride_, data_format_, '0'),
-                                     GetTensorDim(stride_, data_format_, '1'),
-                                     GetTensorDim(stride_, data_format_, '2')}};
-    std::array<int64, 3> out, padding;
-
-    OP_REQUIRES_OK(
-        context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,
-                                   padding_, &out, &padding));
-    TensorShape out_shape = ShapeFromFormat(
-        data_format_, in_batch, {{out[0], out[1], out[2]}}, out_depth);
-    Tensor* output;
-    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
-
-    // Return early if nothing to do.
-    if (out_shape.num_elements() == 0) return;
-
-    LaunchConvOp<Device, T, OpKernelContextT>::launch(
-        context, cudnn_use_autotune_, input, filter,
-        dilations, strides, padding_, data_format_,
-        output);
-  }
-
- private:
-  std::vector<int32> dilation_;
-  std::vector<int32> stride_;
-  Padding padding_;
-  TensorFormat data_format_;
-  bool cudnn_use_autotune_;
-};
-
-}  // namespace tensorflow
-
-
-#endif  // TENSORFLOW_CORE_KERNELS_CONV_OPS_3D_H_