From 3a5cb493f5e9f6357367842fefeadd0663f61112 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Thu, 9 May 2019 14:59:16 -0700
Subject: [PATCH] Support NHWC Conv2D with cuDNN for fp16 (aka Eigen::half and
 DT_HALF)

PiperOrigin-RevId: 247502362
---
 tensorflow/core/kernels/BUILD                 |   2 +-
 tensorflow/core/kernels/conv_2d.h             |  60 +++--
 tensorflow/core/kernels/conv_2d_gpu.h         |  21 +-
 tensorflow/core/kernels/conv_ops.cc           | 118 +++++++--
 .../core/kernels/conv_ops_benchmark_test.cc   | 245 ++++++++++++------
 tensorflow/core/util/tensor_format.cc         |   2 +
 tensorflow/core/util/tensor_format.h          |   5 +-
 7 files changed, 325 insertions(+), 128 deletions(-)
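
Note (illustrative only, not part of the diff below): the HWIO->OHWI filter
transform added in conv_2d.h and conv_2d_gpu.h is a pure index permutation. A
minimal standalone C++ sketch of the same mapping, with hypothetical filter
sizes, might look like this:

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical filter sizes: H=2, W=2, I=3, O=4 (HWIO source layout).
      const int H = 2, W = 2, I = 3, O = 4;
      std::vector<float> hwio(H * W * I * O);
      for (int i = 0; i < static_cast<int>(hwio.size()); ++i) hwio[i] = i;

      // Destination buffer in OHWI layout, as used by cuDNN for NHWC Conv2D.
      std::vector<float> ohwi(hwio.size());
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          for (int in = 0; in < I; ++in)
            for (int out = 0; out < O; ++out) {
              const int src = ((h * W + w) * I + in) * O + out;  // HWIO index
              const int dst = ((out * H + h) * W + w) * I + in;  // OHWI index
              ohwi[dst] = hwio[src];
            }

      std::printf("ohwi[0]=%g ohwi[1]=%g\n", ohwi[0], ohwi[1]);
      return 0;
    }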

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0a4bfc98bb4..e377571b3d9 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1606,10 +1606,10 @@ tf_cuda_cc_test(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/stream_executor/cuda:cudnn_plugin",
     ],
 )
 
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index b735f78c2e3..22b10ade4db 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -179,42 +179,50 @@ struct MatMulConvFunctor {
 
 // Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
 //
-// Note: Currently OIHW is the only supported destination format. Support for
-// OHWI format will be added in a follow-up change.
+// Note: Currently supports OIHW and OHWI destination formats.
 template <typename Device, typename T, typename IndexType, int NDIMS>
 struct TransformFilter {
   void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
                   typename TTypes<T, NDIMS, IndexType>::Tensor out) {
+    // NOTE: Source filter format is always HWIO.
+    Eigen::DSizes<IndexType, NDIMS - 2> spatial_dims;
+    for (int i = 0; i < spatial_dims.rank(); ++i) {
+      spatial_dims[i] = in.dimension(i);
+    }
+
     // Merge the spatial dimensions together to speed up the shuffle operation.
     Eigen::DSizes<IndexType, 3> merged_dims;
-    merged_dims[0] = in.dimension(0);  // spatial dimensions
-    for (int i = 1; i < NDIMS - 2; ++i) {
-      merged_dims[0] *= in.dimension(i);
-    }
-    merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
-    merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
-
-    DCHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported destination filter format: "
-        << ToString(dst_filter_format);
-    // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
-    // in the beginning.
-    Eigen::DSizes<IndexType, 3> shuffling_perm =
-        Eigen::DSizes<IndexType, 3>(2, 1, 0);
+    merged_dims[0] = spatial_dims.TotalSize();  // product of spatial dims [H*W]
+    merged_dims[1] = in.dimension(NDIMS - 2);   // input filters           [I]
+    merged_dims[2] = in.dimension(NDIMS - 1);   // output filters          [O]
 
+    // Shuffle tensor with merged spatial dimensions.
+    Eigen::DSizes<IndexType, 3> shuffling_perm;
+    // Expand shuffled tensor into final dimensions.
     Eigen::DSizes<IndexType, NDIMS> expanded_dims;
-    int out_index = 0;
-    for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
-      if (shuffling_perm[merged_dim] == 0) {
-        for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
-          expanded_dims[out_index++] = in.dimension(spatial_dim);
-        }
-      } else {
-        constexpr int kLastSpatialDim = NDIMS - 3;
-        expanded_dims[out_index++] =
-            in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+
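+    // The merged source dims are [H*W, I, O]. Eigen's shuffle maps output
+    // dimension i to input dimension shuffling_perm[i], so (2, 1, 0) yields
+    // [O, I, H*W] for OIHW and (2, 0, 1) yields [O, H*W, I] for OHWI.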
+    if (dst_filter_format == FORMAT_OIHW) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
+      expanded_dims[0] = merged_dims[2];  // [O]
+      expanded_dims[1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[2 + i] = spatial_dims[i];
       }
+
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      shuffling_perm = Eigen::DSizes<IndexType, 3>(2, 0, 1);
+
+      expanded_dims[0] = merged_dims[2];          // [O]
+      expanded_dims[NDIMS - 1] = merged_dims[1];  // [I]
+      for (int i = 0; i < spatial_dims.rank(); ++i) {
+        expanded_dims[1 + i] = spatial_dims[i];
+      }
+
+    } else {
+      DCHECK(false) << "Unsupported destination filter format: "
+                    << ToString(dst_filter_format);
     }
 
     out.device(d) =
diff --git a/tensorflow/core/kernels/conv_2d_gpu.h b/tensorflow/core/kernels/conv_2d_gpu.h
index 820a92b0f09..a8c218a7321 100644
--- a/tensorflow/core/kernels/conv_2d_gpu.h
+++ b/tensorflow/core/kernels/conv_2d_gpu.h
@@ -434,13 +434,22 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
     combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
 
-    CHECK(dst_filter_format == FORMAT_OIHW)
-        << "Unsupported output layout: " << ToString(dst_filter_format);
+    if (dst_filter_format == FORMAT_OIHW) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
+                                   config.block_count, config.thread_per_block,
+                                   0, d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
 
-    TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
-                                 config.block_count, config.thread_per_block, 0,
-                                 d.stream(), config.virtual_thread_count,
-                                 in.data(), combined_dims, out.data()));
+    } else if (dst_filter_format == FORMAT_OHWI) {
+      TF_CHECK_OK(CudaLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
+                                   config.block_count, config.thread_per_block,
+                                   0, d.stream(), config.virtual_thread_count,
+                                   in.data(), combined_dims, out.data()));
+
+    } else {
+      LOG(ERROR) << "Unsupported filter format: "
+                 << ToString(dst_filter_format);
+    }
   }
 };
 
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ec54ece9d7c..8050320e441 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/conv_ops.h"
 
 #include <string.h>
+
 #include <map>
 #include <vector>
 
@@ -561,6 +562,15 @@ template struct LaunchConv2DOp<CPUDevice, float>;
 template struct LaunchConv2DOp<CPUDevice, double>;
 
 #if GOOGLE_CUDA
+// Returns true if the given StreamExecutor is for a Volta or newer NVIDIA GPU.
+bool IsVoltaOrLater(const se::StreamExecutor& stream_exec) {
+  int major, minor;
+  CHECK(stream_exec  // Crash OK
+            .GetDeviceDescription()
+            .cuda_compute_capability(&major, &minor));
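+  // Volta GPUs report CUDA compute capability 7.0, the first generation with
+  // Tensor Cores.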
+  return major >= 7;
+}
+
 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
                            int64 default_value_in_bytes) {
   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
@@ -676,6 +686,23 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     return;
   }
 
+  // Tensor Cores (on NVIDIA Volta and newer GPUs) support efficient fp16
+  // convolution in the NHWC data layout. In all other configurations it is
+  // more efficient to run the computation in the NCHW data format.
+  const bool compute_in_nhwc =
+      DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());
+
+  // We only do a one-directional conversion: NHWC->NCHW. We never convert in
+  // the other direction, because the Grappler layout optimizer selects the
+  // preferred layout and adds the necessary annotations to the graph.
+  // TODO(ezhulenev): Convert in other direction for fp16?
+  const TensorFormat compute_data_format =
+      compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
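+  // In effect, NHWC inputs stay in NHWC only for fp16 on Volta+ GPUs; in all
+  // other cases the computation runs in NCHW.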
+
+  VLOG(3) << "Compute Conv2D with cuDNN:"
+          << " data_format=" << ToString(data_format)
+          << " compute_data_format=" << ToString(compute_data_format);
+
   const int64 out_batch = GetTensorDim(*output, data_format, 'N');
   const int64 out_rows = GetTensorDim(*output, data_format, 'H');
   const int64 out_cols = GetTensorDim(*output, data_format, 'W');
@@ -708,6 +735,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     // cuDNN only supports padding the same amount on the left and right sides,
     // and on the top and bottom sides. So we manually create a new padded
     // input tensor such that we can pass it to cuDNN.
+    VLOG(4) << "Pad input tensor:"
+            << " padding_top=" << padding_top
+            << " padding_bottom=" << padding_bottom
+            << " padding_left=" << padding_left
+            << " padding_right=" << padding_right;
 
     // TODO(reedwm): In some cases, we can avoid an allocation even if the two
     // padding sides are different. For example, if the input is 2x2, the filter
@@ -750,8 +782,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     in_cols = new_in_cols;
   }
 
-  if (data_format == FORMAT_NHWC) {
-    // Convert the input tensor from NHWC to NCHW.
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the input tensor from NHWC to NCHW.";
+
     TensorShape nchw_shape =
         ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
     if (in_depths > 1) {
@@ -767,28 +800,48 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       // If depth <= 1, then just reshape.
       CHECK(input.CopyFrom(input, nchw_shape));
     }
+  } else {
+    CHECK(data_format == compute_data_format)  // Crash OK
+        << "Illegal data and compute format pair:"
+        << " data_format=" << ToString(data_format)
+        << " compute_data_format=" << ToString(compute_data_format);
   }
 
   CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
       << "Negative row or col paddings: (" << common_padding_rows << ", "
       << common_padding_cols << ")";
+
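+  // (data layout, filter layout) pairs understood by StreamExecutor: NHWC
+  // pairs kBatchYXDepth activations with kOutputYXInput (OHWI) filters, and
+  // NCHW pairs kBatchDepthYX activations with kOutputInputYX (OIHW) filters.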
+  constexpr auto kComputeInNHWC =
+      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
+                      se::dnn::FilterLayout::kOutputYXInput);
+  constexpr auto kComputeInNCHW =
+      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
+                      se::dnn::FilterLayout::kOutputInputYX);
+
+  se::dnn::DataLayout compute_data_layout;
+  se::dnn::FilterLayout filter_layout;
+
+  std::tie(compute_data_layout, filter_layout) =
+      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;
+
   se::dnn::BatchDescriptor input_desc;
   input_desc.set_count(in_batch)
       .set_feature_map_count(in_depths)
       .set_height(in_rows)
       .set_width(in_cols)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
   se::dnn::BatchDescriptor output_desc;
   output_desc.set_count(out_batch)
       .set_height(out_rows)
       .set_width(out_cols)
       .set_feature_map_count(out_depths)
-      .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+      .set_layout(compute_data_layout);
   se::dnn::FilterDescriptor filter_desc;
   filter_desc.set_input_filter_height(patch_rows)
       .set_input_filter_width(patch_cols)
       .set_input_feature_map_count(patch_depths)
-      .set_output_feature_map_count(filter.dim_size(3));
+      .set_output_feature_map_count(filter.dim_size(3))
+      .set_layout(filter_layout);
   se::dnn::ConvolutionDescriptor conv_desc;
   conv_desc.set_vertical_dilation_rate(row_dilation)
       .set_horizontal_dilation_rate(col_dilation)
@@ -799,22 +852,42 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       .set_group_count(in_depths / patch_depths);
 
   Tensor transformed_filter;
-  OP_REQUIRES_OK(ctx, ctx->allocate_temp(
-                          DataTypeToEnum<T>::value,
-                          TensorShape({filter.dim_size(3), filter.dim_size(2),
-                                       filter.dim_size(0), filter.dim_size(1)}),
-                          &transformed_filter));
-  functor::TransformFilter<GPUDevice, T, int, 4>()(
-      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
-      To32Bit(filter.tensor<T, 4>()),
-      To32Bit(transformed_filter.tensor<T, 4>()));
+
+  const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
+    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
+            << " to " << ToString(dst_format);
+
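+    // HWIO filter dims are [H, W, I, O]; OIHW reorders them to [O, I, H, W]
+    // and OHWI to [O, H, W, I].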
+    TensorShape dst_shape =
+        dst_format == FORMAT_OIHW
+            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
+                           filter.dim_size(0), filter.dim_size(1)})
+            : TensorShape({filter.dim_size(3), filter.dim_size(0),
+                           filter.dim_size(1), filter.dim_size(2)});
+
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
+                                           &transformed_filter));
+    functor::TransformFilter<GPUDevice, T, int, 4>()(
+        ctx->eigen_device<GPUDevice>(), dst_format,
+        To32Bit(filter.tensor<T, 4>()),
+        To32Bit(transformed_filter.tensor<T, 4>()));
+  };
+
+  if (compute_data_format == FORMAT_NCHW) {
+    transform_filter(FORMAT_OIHW);
+  } else if (compute_data_format == FORMAT_NHWC) {
+    transform_filter(FORMAT_OHWI);
+  } else {
+    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
+                                           ToString(compute_data_format)));
+    return;
+  }
 
   Tensor transformed_output;
-  if (data_format == FORMAT_NHWC) {
-    // Only allocate temporary memory when a layout transformation is needed.
+  if (data_format != compute_data_format) {
+    VLOG(4) << "Allocate temporary memory for output in compute data format";
     OP_REQUIRES_OK(
         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                ShapeFromFormat(FORMAT_NCHW, out_batch,
+                                ShapeFromFormat(compute_data_format, out_batch,
                                                 out_rows, out_cols, out_depths),
                                 &transformed_output));
   } else {
@@ -842,7 +915,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
       in_depths,                // in_depths
       {{in_rows,                // in_rows
         in_cols}},              // in_cols
-      FORMAT_NCHW,              // compute_data_format
+      compute_data_format,      // compute_data_format
       out_depths,               // out_depths
       {{patch_rows,             // filter_rows
         patch_cols,             // filter_cols
@@ -901,6 +974,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
     AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
   }
 
+  VLOG(4) << "Convolution Algorithm: "
+          << algorithm_config.algorithm()->algo_id();
+  VLOG(4) << "tensor_ops_enabled: "
+          << algorithm_config.algorithm()->tensor_ops_enabled();
+
   DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
   bool cudnn_launch_status =
       stream
@@ -916,8 +994,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
         ") filter shape(", filter.shape().DebugString(), ")"));
   }
 
-  // Convert the output tensor back from NCHW to NHWC.
-  if (data_format == FORMAT_NHWC) {
+  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
+    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
     functor::NCHWToNHWC<GPUDevice, T, 4>()(
         ctx->eigen_device<GPUDevice>(),
         const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
diff --git a/tensorflow/core/kernels/conv_ops_benchmark_test.cc b/tensorflow/core/kernels/conv_ops_benchmark_test.cc
index 259a2f2e570..a03f62b80b1 100644
--- a/tensorflow/core/kernels/conv_ops_benchmark_test.cc
+++ b/tensorflow/core/kernels/conv_ops_benchmark_test.cc
@@ -29,7 +29,7 @@ limitations under the License.
 namespace tensorflow {
 
 ////////////////////////////////////////////////////////////////////////////////
-// Performance benchmarks for the FusedConv2Op.                               //
+// Performance benchmarks for the Conv2DOp and FusedConv2Op.                  //
 ////////////////////////////////////////////////////////////////////////////////
 
 struct Conv2DGraph {
@@ -63,19 +63,27 @@ struct Conv2DWithBatchNormAndActivationGraph {
   Node* activation;
 };
 
+template <typename T>
 static Tensor MakeRandomTensor(const TensorShape& shape) {
-  Tensor tensor(DT_FLOAT, TensorShape(shape));
-  tensor.flat<float>() = tensor.flat<float>().setRandom();
+  Tensor tensor(DataTypeToEnum<T>::value, TensorShape(shape));
+  tensor.flat<T>() = tensor.flat<T>().setRandom();
   return tensor;
 }
 
 // Creates a simple Tensorflow graph with single Conv2D node.
+template <typename T>
 static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
-                          int filter_w, int filter_h, int out_depth) {
+                          int filter_w, int filter_h, int out_depth,
+                          TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());
 
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
 
   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
@@ -84,33 +92,35 @@ static Conv2DGraph Conv2D(int batch, int height, int width, int in_depth,
   TF_CHECK_OK(NodeBuilder(graph->NewName("conv"), "Conv2D")
                   .Input(images)
                   .Input(filter)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &conv2d));
 
   return {graph, conv2d};
 }
 
 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd.
-static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
-                                          int in_depth, int filter_w,
-                                          int filter_h, int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBiasGraph Conv2DWithBias(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);
 
   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
 
-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});
   Node* bias = test::graph::Constant(graph, bias_t, "bias");
 
   Node* out;
   TF_CHECK_OK(NodeBuilder(graph->NewName("bias"), "BiasAdd")
                   .Input(conv2d)
                   .Input(bias)
-                  .Attr("T", DT_FLOAT)
-                  .Attr("data_format", "NHWC")
+                  .Attr("T", DataTypeToEnum<T>::value)
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &out));
 
   return {graph, conv2d, out};
@@ -118,11 +128,14 @@ static Conv2DWithBiasGraph Conv2DWithBias(int batch, int height, int width,
 
 // Creates a Tensorflow graph with a Conv2D node followed by BiasAdd and
 // activation (Relu, Relu6, etc...).
+template <typename T>
 static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBiasGraph conv_graph = Conv2DWithBias(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBiasGraph conv_graph =
+      Conv2DWithBias<T>(batch, height, width, in_depth, filter_w, filter_h,
+                        out_depth, data_format);
 
   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
@@ -131,27 +144,27 @@ static Conv2DWithBiasAndActivationGraph Conv2DWithBiasAndActivation(
   Node* activation;
   TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
                   .Input(bias)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Finalize(graph, &activation));
 
   return {graph, conv2d, bias, activation};
 }
 
 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm.
-static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
-                                                    int width, int in_depth,
-                                                    int filter_w, int filter_h,
-                                                    int out_depth) {
-  Conv2DGraph conv_graph =
-      Conv2D(batch, height, width, in_depth, filter_w, filter_h, out_depth);
+template <typename T>
+static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(
+    int batch, int height, int width, int in_depth, int filter_w, int filter_h,
+    int out_depth, TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DGraph conv_graph = Conv2D<T>(batch, height, width, in_depth, filter_w,
+                                     filter_h, out_depth, data_format);
 
   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
 
-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});
 
   Node* scale = test::graph::Constant(graph, scale_t, "scale");
   Node* offset = test::graph::Constant(graph, offset_t, "offset");
@@ -165,8 +178,9 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
                   .Input(offset)
                   .Input(mean)
                   .Input(variance)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("is_training", false)
+                  .Attr("data_format", ToString(data_format))
                   .Finalize(graph, &out));
 
   return {graph, conv2d, out};
@@ -174,11 +188,14 @@ static Conv2DWithBatchNormGraph Conv2DWithBatchNorm(int batch, int height,
 
 // Creates a Tensorflow graph with a Conv2D node followed by FusedBatchNorm and
 // activation (Relu, Relu6, etc...).
+template <typename T>
 static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const string& activation_type) {
-  Conv2DWithBatchNormGraph conv_graph = Conv2DWithBatchNorm(
-      batch, height, width, in_depth, filter_w, filter_h, out_depth);
+    int out_depth, const string& activation_type,
+    TensorFormat data_format = FORMAT_NHWC) {
+  Conv2DWithBatchNormGraph conv_graph =
+      Conv2DWithBatchNorm<T>(batch, height, width, in_depth, filter_w, filter_h,
+                             out_depth, data_format);
 
   Graph* graph = conv_graph.graph;
   Node* conv2d = conv_graph.conv2d;
@@ -187,7 +204,7 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
   Node* activation;
   TF_CHECK_OK(NodeBuilder(graph->NewName("activation"), activation_type)
                   .Input(batch_norm)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Finalize(graph, &activation));
 
   return {graph, conv2d, batch_norm, activation};
@@ -195,15 +212,22 @@ static Conv2DWithBatchNormAndActivationGraph Conv2DWithBatchNormAndActivation(
 
 // Creates a tensorflow graph with a single FusedConv2D (with BiasAdd) node and
 // fuses into it additional computations (e.g. Relu).
+template <typename T>
 static Graph* FusedConv2DWithBias(int batch, int height, int width,
                                   int in_depth, int filter_w, int filter_h,
                                   int out_depth,
-                                  const std::vector<string>& fused_ops = {}) {
+                                  const std::vector<string>& fused_ops = {},
+                                  TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());
 
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor bias_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+  Tensor bias_t = MakeRandomTensor<T>({out_depth});
 
   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
@@ -217,7 +241,7 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,
                   .Input(filter)
                   .Attr("num_args", 1)
                   .Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
                   .Attr("fused_ops", fused_ops)
@@ -228,17 +252,24 @@ static Graph* FusedConv2DWithBias(int batch, int height, int width,
 
 // Creates a tensorflow graph with a single FusedConv2D (with FusedBatchNorm)
 // node and fuses into it additional computations (e.g. Relu).
+template <typename T>
 static Graph* FusedConv2DWithBatchNorm(
     int batch, int height, int width, int in_depth, int filter_w, int filter_h,
-    int out_depth, const std::vector<string>& fused_ops = {}) {
+    int out_depth, const std::vector<string>& fused_ops = {},
+    TensorFormat data_format = FORMAT_NHWC) {
   Graph* graph = new Graph(OpRegistry::Global());
 
-  Tensor images_t = MakeRandomTensor({batch, height, width, in_depth});
-  Tensor filter_t = MakeRandomTensor({filter_w, filter_h, in_depth, out_depth});
-  Tensor scale_t = MakeRandomTensor({out_depth});
-  Tensor offset_t = MakeRandomTensor({out_depth});
-  Tensor mean_t = MakeRandomTensor({out_depth});
-  Tensor variance_t = MakeRandomTensor({out_depth});
+  Tensor images_t = data_format == FORMAT_NHWC
+                        ? MakeRandomTensor<T>({batch, height, width, in_depth})
+                        : MakeRandomTensor<T>({batch, in_depth, height, width});
+
+  // Filter is always in HWIO.
+  Tensor filter_t =
+      MakeRandomTensor<T>({filter_w, filter_h, in_depth, out_depth});
+  Tensor scale_t = MakeRandomTensor<T>({out_depth});
+  Tensor offset_t = MakeRandomTensor<T>({out_depth});
+  Tensor mean_t = MakeRandomTensor<T>({out_depth});
+  Tensor variance_t = MakeRandomTensor<T>({out_depth});
 
   Node* images = test::graph::Constant(graph, images_t, "images");
   Node* filter = test::graph::Constant(graph, filter_t, "filter");
@@ -255,7 +286,7 @@ static Graph* FusedConv2DWithBatchNorm(
                   .Input(filter)
                   .Attr("num_args", 4)
                   .Input(args)
-                  .Attr("T", DT_FLOAT)
+                  .Attr("T", DataTypeToEnum<T>::value)
                   .Attr("strides", {1, 1, 1, 1})
                   .Attr("padding", "SAME")
                   .Attr("fused_ops", fused_ops)
@@ -273,6 +304,10 @@ static Graph* FusedConv2DWithBatchNorm(
 //   FH: filter height
 //   FW: filter width
 
+// -------------------------------------------------------------------------- //
+// The following benchmarks always use the 'float' data type with NHWC layout.
+// -------------------------------------------------------------------------- //
+
 #define BM_SETUP(N, H, W, C, type, LABEL, NAME)                               \
   testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * (C)); \
   testing::SetLabel(LABEL);
@@ -280,39 +315,41 @@ static Graph* FusedConv2DWithBatchNorm(
 #define BM_NAME(name, type, N, H, W, C, FW, FH, FC) \
   name##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
 
-#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL)                       \
-  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) {  \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                               \
-    test::Benchmark(#type, Conv2D(N, H, W, C, FW, FH, FC).graph).Run(iters); \
-  }                                                                          \
+#define BM_Conv2D(N, H, W, C, FW, FH, FC, type, LABEL)                      \
+  static void BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC)(int iters) { \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                              \
+    test::Benchmark(#type, Conv2D<float>(N, H, W, C, FW, FH, FC).graph)     \
+        .Run(iters);                                                        \
+  }                                                                         \
   BENCHMARK(BM_NAME(BM_Conv2D, type, N, H, W, C, FW, FH, FC));
 
 #define BM_Conv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)           \
   static void BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH,       \
                       FC)(int iters) {                                   \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                           \
-    test::Benchmark(#type, Conv2DWithBias(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type,                                               \
+                    Conv2DWithBias<float>(N, H, W, C, FW, FH, FC).graph) \
         .Run(iters);                                                     \
   }                                                                      \
   BENCHMARK(BM_NAME(BM_Conv2DWithBias, type, N, H, W, C, FW, FH, FC));
 
-#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)      \
-  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH,  \
-                      FC)(int iters) {                                     \
-    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                             \
-    test::Benchmark(                                                       \
-        #type,                                                             \
-        Conv2DWithBiasAndActivation(N, H, W, C, FW, FH, FC, "Relu").graph) \
-        .Run(iters);                                                       \
-  }                                                                        \
+#define BM_Conv2DWithBiasAndRelu(N, H, W, C, FW, FH, FC, type, LABEL)         \
+  static void BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH,     \
+                      FC)(int iters) {                                        \
+    BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
+    test::Benchmark(#type, Conv2DWithBiasAndActivation<float>(N, H, W, C, FW, \
+                                                              FH, FC, "Relu") \
+                               .graph)                                        \
+        .Run(iters);                                                          \
+  }                                                                           \
   BENCHMARK(BM_NAME(BM_Conv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, FC));
 
 #define BM_FusedConv2DWithBias(N, H, W, C, FW, FH, FC, type, LABEL)           \
   static void BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH,       \
                       FC)(int iters) {                                        \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
-    test::Benchmark(#type,                                                    \
-                    FusedConv2DWithBias(N, H, W, C, FW, FH, FC, {"BiasAdd"})) \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC, \
+                                                      {"BiasAdd"}))           \
         .Run(iters);                                                          \
   }                                                                           \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBias, type, N, H, W, C, FW, FH, FC));
@@ -321,8 +358,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBiasAndRelu, type, N, H, W, C, FW, FH, \
                       FC)(int iters) {                                         \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                 \
-    test::Benchmark(#type, FusedConv2DWithBias(N, H, W, C, FW, FH, FC,         \
-                                               {"BiasAdd", "Relu"}))           \
+    test::Benchmark(#type, FusedConv2DWithBias<float>(N, H, W, C, FW, FH, FC,  \
+                                                      {"BiasAdd", "Relu"}))    \
         .Run(iters);                                                           \
   }                                                                            \
   BENCHMARK(                                                                   \
@@ -332,7 +369,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH,       \
                       FC)(int iters) {                                        \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
-    test::Benchmark(#type, Conv2DWithBatchNorm(N, H, W, C, FW, FH, FC).graph) \
+    test::Benchmark(#type,                                                    \
+                    Conv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC).graph) \
         .Run(iters);                                                          \
   }                                                                           \
   BENCHMARK(BM_NAME(BM_Conv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@@ -341,8 +379,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_Conv2DWithBatchNormAndRelu, type, N, H, W, C, FW, FH, \
                       FC)(int iters) {                                         \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                 \
-    test::Benchmark(#type, Conv2DWithBatchNormAndActivation(N, H, W, C, FW,    \
-                                                            FH, FC, "Relu")    \
+    test::Benchmark(#type, Conv2DWithBatchNormAndActivation<float>(            \
+                               N, H, W, C, FW, FH, FC, "Relu")                 \
                                .graph)                                         \
         .Run(iters);                                                           \
   }                                                                            \
@@ -353,8 +391,8 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, \
                       FC)(int iters) {                                       \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                               \
-    test::Benchmark(#type, FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC,  \
-                                                    {"FusedBatchNorm"}))     \
+    test::Benchmark(#type, FusedConv2DWithBatchNorm<float>(                  \
+                               N, H, W, C, FW, FH, FC, {"FusedBatchNorm"}))  \
         .Run(iters);                                                         \
   }                                                                          \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNorm, type, N, H, W, C, FW, FH, FC));
@@ -364,9 +402,9 @@ static Graph* FusedConv2DWithBatchNorm(
   static void BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C,   \
                       FW, FH, FC)(int iters) {                                \
     BM_SETUP(N, H, W, C, type, LABEL, Conv2D);                                \
-    test::Benchmark(#type,                                                    \
-                    FusedConv2DWithBatchNorm(N, H, W, C, FW, FH, FC,          \
-                                             {"FusedBatchNorm", "Relu"}))     \
+    test::Benchmark(                                                          \
+        #type, FusedConv2DWithBatchNorm<float>(N, H, W, C, FW, FH, FC,        \
+                                               {"FusedBatchNorm", "Relu"}))   \
         .Run(iters);                                                          \
   }                                                                           \
   BENCHMARK(BM_NAME(BM_FusedConv2DWithBatchNormAndRelu, type, N, H, W, C, FW, \
@@ -500,4 +538,63 @@ BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
 BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");
 #endif
 
+// Macro argument names: ---------------------------------------------------- //
+//      T: data type
+// FORMAT: data format (NHWC or NCHW)
+//      N: batch size
+//      H: height
+//      W: width
+//      C: channels
+//     FC: filter count
+//     FH: filter height
+//     FW: filter width
+
+// -------------------------------------------------------------------------- //
+// The following benchmarks compare the performance of different data formats
+// for different data types. They make sense only when CUDA is enabled,
+// because on CPU we only support data in the NHWC format.
+// -------------------------------------------------------------------------- //
+
+#define BM_LONG_NAME(name, type, T, FORMAT, N, H, W, C, FW, FH, FC) \
+  name##_##T##_##FORMAT##_##type##_##N##_##H##_##W##_##C##_##FW##_##FH##_##FC
+
+#define BM_Conv2DFmt(T, FORMAT, N, H, W, C, FW, FH, FC, type)                 \
+  static void BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH,    \
+                           FC)(int iters) {                                   \
+    BM_SETUP(N, H, W, C, type, "", Conv2D);                                   \
+    test::Benchmark(#type,                                                    \
+                    Conv2D<T>(N, H, W, C, FW, FH, FC, FORMAT_##FORMAT).graph) \
+        .Run(iters);                                                          \
+  }                                                                           \
+  BENCHMARK(BM_LONG_NAME(BM_Conv2D, type, T, FORMAT, N, H, W, C, FW, FH, FC));
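+
+// For example, BM_Conv2DFmt(fp16, NHWC, 32, 56, 56, 64, 1, 1, 64, gpu)
+// defines a benchmark named BM_Conv2D_fp16_NHWC_gpu_32_56_56_64_1_1_64.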
+
+#if GOOGLE_CUDA
+using fp32 = float;
+using fp16 = Eigen::half;
+
+// ResNet50-ish convolutions.
+#define BENCHMARK_DTYPE(BATCH, T)                             \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 64, gpu);    \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 1, 1, 256, gpu);   \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 256, 1, 1, 64, gpu);   \
+  BM_Conv2DFmt(T, NHWC, BATCH, 56, 56, 64, 3, 3, 64, gpu);    \
+                                                              \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 128, gpu);  \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 128, 1, 1, 512, gpu);  \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 1, 1, 128, gpu);  \
+  BM_Conv2DFmt(T, NHWC, BATCH, 28, 28, 512, 3, 3, 128, gpu);  \
+                                                              \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 256, gpu);  \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \
+  BM_Conv2DFmt(T, NHWC, BATCH, 14, 14, 256, 3, 3, 256, gpu);
+
+BENCHMARK_DTYPE(32, fp32);
+BENCHMARK_DTYPE(32, fp16);
+
+BENCHMARK_DTYPE(64, fp32);
+BENCHMARK_DTYPE(64, fp16);
+
+#endif  // GOOGLE_CUDA
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index f331973f5ce..5dbd8ef318f 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -63,6 +63,8 @@ string ToString(FilterTensorFormat format) {
       return "HWIO";
     case FORMAT_OIHW:
       return "OIHW";
+    case FORMAT_OHWI:
+      return "OHWI";
     case FORMAT_OIHW_VECT_I:
       return "OIHW_VECT_I";
     default:
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 643e14e0b56..82af5c545f7 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -80,6 +80,9 @@ enum FilterTensorFormat {
   // FORMAT_OIHW often improves performance on GPUs.
   FORMAT_OIHW = 1,
 
+  // FORMAT_OHWI is used by cuDNN for NHWC convolutions.
+  FORMAT_OHWI = 2,
+
   // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
   // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
   // data format. It is laid out in the same order as OIHW, except that the size
@@ -88,7 +91,7 @@ enum FilterTensorFormat {
   // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
   // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
   // A pre-condition of this format is that I must be a multiple of 4.
-  FORMAT_OIHW_VECT_I = 2,
+  FORMAT_OIHW_VECT_I = 3,
 };
 
 // Parse tensor format from the given string.