diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f98b510b96f..99970a9558c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4422,6 +4422,7 @@ tf_kernel_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        ":cast_op",
     ] + if_cuda([
         "@local_config_cuda//cuda:cudnn_header",
     ]),
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index 310bd73ba65..b809e1d1065 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/conv_grad_ops.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -1180,12 +1181,45 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
       return;
     }
 
-    auto out_backprop_ptr = out_backprop.template flat<T>().data();
-    auto input_ptr = input.template flat<T>().data();
-    auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
-    LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
+    // For GPU inputs with type half, we cast inputs to float and outputs back
+    // to half, as half implementation is slow and does not use full precision
+    // accumulation in some cases.
+    constexpr bool cast_to_float = std::is_same<T, Eigen::half>::value &&
+                                   std::is_same<Device, GPUDevice>::value;
+    using U = typename std::conditional<cast_to_float, float, T>::type;
+    Tensor casted_out_backprop = out_backprop;
+    Tensor casted_input = input;
+    Tensor casted_filter_backprop = *filter_backprop;
+    const Device& device = context->template eigen_device<Device>();
+    if (cast_to_float) {
+      functor::CastFunctor<Device, float, Eigen::half> cast;
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, out_backprop.shape(),
+                                            &casted_out_backprop));
+      cast(device, casted_out_backprop.template flat<float>(),
+           out_backprop.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, input.shape(),
+                                                     &casted_input));
+      cast(device, casted_input.template flat<float>(),
+           input.template flat<Eigen::half>());
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(DT_FLOAT, filter_backprop->shape(),
+                                            &casted_filter_backprop));
+    }
+
+    auto out_backprop_ptr = casted_out_backprop.template flat<U>().data();
+    auto input_ptr = casted_input.template flat<U>().data();
+    auto filter_backprop_ptr = casted_filter_backprop.template flat<U>().data();
+    LaunchDepthwiseConvBackpropFilterOp<Device, U>()(
         context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
         data_format_);
+
+    if (cast_to_float) {
+      functor::CastFunctor<Device, Eigen::half, float> cast;
+      const Tensor& casted_filter_backprop_const = casted_filter_backprop;
+      cast(device, filter_backprop->template flat<Eigen::half>(),
+           casted_filter_backprop_const.template flat<float>());
+    }
   }
 
  protected:
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 093de720b53..266a0f8d0fb 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -832,7 +832,7 @@ class DepthwiseConv2DTest(test.TestCase):
       # double datatype is currently not supported for convolution ops
       # on the ROCm platform
       optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
-      for data_type in ([dtypes.float32] + optional_float64):
+      for data_type in ([dtypes.float16, dtypes.float32] + optional_float64):
        self._ConstructAndTestGradient(
            input_size,
            filter_size,
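For context, a minimal sketch (not part of the patch) of how the float16 filter-gradient path that the updated test now covers can be driven from the public Python API. The use of `tf.nn.depthwise_conv2d` and `tf.GradientTape` here is illustrative; shapes and values are arbitrary, and a TensorFlow 2.x build with GPU support is assumed.

```python
# Sketch only: exercises DepthwiseConv2dNativeBackpropFilter with half inputs.
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(1, 8, 8, 3), dtype=tf.float16)  # NHWC input
w = tf.Variable(np.random.rand(3, 3, 3, 2), dtype=tf.float16)  # HWIO depthwise filter

with tf.GradientTape() as tape:
  y = tf.nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
  loss = tf.reduce_sum(y)

# On GPU, this filter gradient now runs through the patched kernel, which
# casts half inputs to float for accumulation and casts the result back.
dw = tape.gradient(loss, w)
print(dw.shape, dw.dtype)  # (3, 3, 3, 2) float16
```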