Improve performance of fp16 DepthwiseConv2DBackpropFilter.

When cuDNN is not used, performance and numeric stability are improved by casting the inputs to float32, running the kernel in float32, and casting the output back to float16. The original implementation accumulates in float16 and is slow for unknown reasons.
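
The same cast-compute-cast pattern can be sketched from Python with the public TensorFlow API. This is a minimal sketch of the idea under made-up shapes, not the kernel change itself:

```python
import tensorflow as tf

# Hypothetical shapes: NHWC input, HWIO depthwise filter.
x = tf.random.normal([1, 64, 64, 32], dtype=tf.float16)
f = tf.Variable(tf.random.normal([3, 3, 32, 1], dtype=tf.float16))

# Direct fp16 path: dispatches the fp16 backprop_filter kernel.
with tf.GradientTape() as tape:
  y = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
df_fp16 = tape.gradient(y, f)

# The pattern this change applies inside the kernel: cast up to float32,
# compute the gradient there, cast the result back down to float16.
f32 = tf.Variable(tf.cast(f, tf.float32))
with tf.GradientTape() as tape:
  y32 = tf.nn.depthwise_conv2d(tf.cast(x, tf.float32), f32,
                               strides=[1, 1, 1, 1], padding="SAME")
df_fp32 = tf.cast(tape.gradient(y32, f32), tf.float16)
```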

Running the benchmark in [this comment](https://github.com/tensorflow/tensorflow/issues/41715#issuecomment-664705080) on my machine with a Titan V, I get the following results. All times are in seconds.

```
bench                         before    after
float16 NHWC backprop_filter  7.6379    0.0098
float16 NCHW backprop_filter  4.1965    0.0449
float32 NHWC backprop_filter  0.0094    0.0094
float32 NCHW backprop_filter  0.0449    0.0444
```
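
For reference, a minimal timing sketch of the op under test. This is not the linked benchmark; the shapes, iteration count, and the `tf.raw_ops` entry point are illustrative choices:

```python
import time
import tensorflow as tf

def bench_backprop_filter(dtype, iters=100):
  # Illustrative shapes; the benchmark linked above defines its own.
  x = tf.random.normal([4, 128, 128, 32], dtype=dtype)
  dy = tf.random.normal([4, 128, 128, 32], dtype=dtype)

  def grad():
    return tf.raw_ops.DepthwiseConv2dNativeBackpropFilter(
        input=x, filter_sizes=[3, 3, 32, 1], out_backprop=dy,
        strides=[1, 1, 1, 1], padding="SAME")

  grad()  # warm-up run
  start = time.time()
  for _ in range(iters):
    df = grad()
  df.numpy()  # block until the queued GPU work has finished
  return (time.time() - start) / iters

for dtype in (tf.float16, tf.float32):
  print(f"{dtype.name} NHWC backprop_filter: {bench_backprop_filter(dtype):.4f}")
```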

Fixes https://github.com/tensorflow/tensorflow/issues/41715.

PiperOrigin-RevId: 325508729
Change-Id: I694a62dcdd8731bc90e98d2a09486160d8740b5f
Authored by Reed Wanderman-Milne on 2020-08-07 14:27:14 -07:00, committed by TensorFlower Gardener.
Parent 4684e40f18, commit 201d45cea2.
3 changed files with 40 additions and 5 deletions.

@@ -4422,6 +4422,7 @@ tf_kernel_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        ":cast_op",
     ] + if_cuda([
         "@local_config_cuda//cuda:cudnn_header",
     ]),

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/conv_grad_ops.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -1180,12 +1181,45 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
       return;
     }
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto input_ptr = input.template flat<T>().data();
-    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
-    LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
+    // For GPU inputs with type half, we cast inputs to float and outputs back
+    // to half, as half implementation is slow and does not use full precision
+    // accumulation in some cases.
+    constexpr bool cast_to_float = std::is_same<T, Eigen::half>::value &&
+                                   std::is_same<Device, GPUDevice>::value;
+    using U = typename std::conditional<cast_to_float, float, T>::type;
+    Tensor casted_out_backprop = out_backprop;
+    Tensor casted_input = input;
+    Tensor casted_filter_backprop = *filter_backprop;
+    const Device& device = context->template eigen_device<Device>();
+    if (cast_to_float) {
+      functor::CastFunctor<Device, float, Eigen::half> cast;
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, out_backprop.shape(),
+                                            &casted_out_backprop));
+      cast(device, casted_out_backprop.template flat<float>(),
+           out_backprop.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, input.shape(),
+                                                     &casted_input));
+      cast(device, casted_input.template flat<float>(),
+           input.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, filter_backprop->shape(),
+                                            &casted_filter_backprop));
+    }
+
+    auto out_backprop_ptr = casted_out_backprop.template flat<U>().data();
+    auto input_ptr = casted_input.template flat<U>().data();
+    auto filter_backprop_ptr = casted_filter_backprop.template flat<U>().data();
+    LaunchDepthwiseConvBackpropFilterOp<Device, U>()(
         context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
         data_format_);
+
+    if (cast_to_float) {
+      functor::CastFunctor<Device, Eigen::half, float> cast;
+      const Tensor& casted_filter_backprop_const = casted_filter_backprop;
+      cast(device, filter_backprop->template flat<Eigen::half>(),
+           casted_filter_backprop_const.template flat<float>());
+    }
   }

 protected:

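The "does not use full precision accumulation" comment in the new code refers to a standard half-precision pitfall: once a running fp16 sum grows large enough, further small addends fall below half a unit in the last place and are rounded away entirely. A standalone NumPy illustration (not the kernel's exact code path):

```python
import numpy as np

# fp16 values in [2048, 4096) are spaced 2.0 apart, so adding 1.0 to a
# running sum that has reached 2048 rounds back to 2048 every time.
acc = np.float16(0.0)
for _ in range(3000):
  acc += np.float16(1.0)
print(acc)  # 2048.0 -- the last ~1000 additions were lost

# The same loop accumulated in float32 and cast down at the end is exact.
acc32 = np.float32(0.0)
for _ in range(3000):
  acc32 += np.float32(1.0)
print(np.float16(acc32))  # 3000.0
```
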
@@ -832,7 +832,7 @@ class DepthwiseConv2DTest(test.TestCase):
       # double datatype is currently not supported for convolution ops
       # on the ROCm platform
       optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
-      for data_type in ([dtypes.float32] + optional_float64):
+      for data_type in ([dtypes.float16, dtypes.float32] + optional_float64):
         self._ConstructAndTestGradient(
             input_size,
             filter_size,