From 201d45cea27c1792a86b3fc7eb688fb2dd1d0df1 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne <reedwm@google.com>
Date: Fri, 7 Aug 2020 14:27:14 -0700
Subject: [PATCH] Improve performance of fp16 DepthwiseConv2DBackpropFilter.

When cuDNN is not used, performance and numeric stability are improved by casting inputs to float32 and outputs back to float16. The original implementation does accumulation in float16 and is slow for unknown reasons.

Running the benchmark in [this comment](https://github.com/tensorflow/tensorflow/issues/41715#issuecomment-664705080) on my machine with a Titan V, I get the following numbers. All numbers are in seconds.

```
bench                         before    after
float16 NHWC backprop_filter  7.6379    0.0098
float16 NCHW backprop_filter  4.1965    0.0449
float32 NHWC backprop_filter  0.0094    0.0094
float32 NCHW backprop_filter  0.0449    0.0444
```

Fixes https://github.com/tensorflow/tensorflow/issues/41715.

PiperOrigin-RevId: 325508729
Change-Id: I694a62dcdd8731bc90e98d2a09486160d8740b5f
---
 tensorflow/core/kernels/BUILD                 |  1 +
 .../core/kernels/depthwise_conv_grad_op.cc    | 42 +++++++++++++++++--
 .../kernel_tests/depthwise_conv_op_test.py    |  2 +-
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f98b510b96f..99970a9558c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4422,6 +4422,7 @@ tf_kernel_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        ":cast_op",
     ] + if_cuda([
         "@local_config_cuda//cuda:cudnn_header",
     ]),
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 310bd73ba65..b809e1d1065 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/conv_grad_ops.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -1180,12 +1181,45 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
       return;
     }
 
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto input_ptr = input.template flat<T>().data();
-    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
-    LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
+    // For GPU inputs with type half, we cast inputs to float and outputs back
+    // to half, as half implementation is slow and does not use full precision
+    // accumulation in some cases.
+    constexpr bool cast_to_float = std::is_same<T, Eigen::half>::value &&
+                                   std::is_same<Device, GPUDevice>::value;
+    using U = typename std::conditional<cast_to_float, float, T>::type;
+    Tensor casted_out_backprop = out_backprop;
+    Tensor casted_input = input;
+    Tensor casted_filter_backprop = *filter_backprop;
+    const Device& device = context->template eigen_device<Device>();
+    if (cast_to_float) {
+      functor::CastFunctor<Device, float, Eigen::half> cast;
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, out_backprop.shape(),
+                                            &casted_out_backprop));
+      cast(device, casted_out_backprop.template flat<float>(),
+           out_backprop.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, input.shape(),
+                                                     &casted_input));
+      cast(device, casted_input.template flat<float>(),
+           input.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, filter_backprop->shape(),
+                                            &casted_filter_backprop));
+    }
+
+    auto out_backprop_ptr = casted_out_backprop.template flat<U>().data();
+    auto input_ptr = casted_input.template flat<U>().data();
+    auto filter_backprop_ptr = casted_filter_backprop.template flat<U>().data();
+    LaunchDepthwiseConvBackpropFilterOp<Device, U>()(
         context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
         data_format_);
+
+    if (cast_to_float) {
+      functor::CastFunctor<Device, Eigen::half, float> cast;
+      const Tensor& casted_filter_backprop_const = casted_filter_backprop;
+      cast(device, filter_backprop->template flat<Eigen::half>(),
+           casted_filter_backprop_const.template flat<float>());
+    }
   }
 
  protected:
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 093de720b53..266a0f8d0fb 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -832,7 +832,7 @@ class DepthwiseConv2DTest(test.TestCase):
       # double datatype is currently not supported for convolution ops
       # on the ROCm platform
       optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
-      for data_type in ([dtypes.float32] + optional_float64):
+      for data_type in ([dtypes.float16, dtypes.float32] + optional_float64):
         self._ConstructAndTestGradient(
             input_size,
             filter_size,