PR #39767: Wider vector for FP16 RELU Grad on GPUs

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/39767

This PR uses a wider vector (8 FP16 values) for loads and stores in the FP16 ReluGrad kernel to improve performance on NVIDIA Ampere GPUs.

On older GPUs, performance is expected to be unchanged.
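
The core of the change can be sketched in standalone CUDA. The snippet below only illustrates the 8-wide load/store pattern and is not the kernel from this PR: the kernel name ReluGradVec8Sketch is hypothetical, plain float math is used in place of the half2 intrinsics for brevity, and it assumes count is a multiple of 8 and that all three pointers are 16-byte aligned (the real kernel in the diff below handles the remainder elements and uses __hgt2/__hmul2 where available).

#include <cuda_fp16.h>

// Hypothetical sketch: each thread loads and stores 8 FP16 values as one
// float4 (a single 128-bit load/store per thread) and applies the ReLU-grad
// rule "pass the gradient through only where the feature is positive".
// Assumes count % 8 == 0 and 16-byte-aligned pointers.
__global__ void ReluGradVec8Sketch(const __half* __restrict__ gradient,
                                   const __half* __restrict__ feature,
                                   __half* __restrict__ backprop, int count) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < count / 8) {
    float4 g4 = reinterpret_cast<const float4*>(gradient)[index];  // 8 halves
    float4 f4 = reinterpret_cast<const float4*>(feature)[index];   // 8 halves
    float4 b4;
    const half2* g2 = reinterpret_cast<const half2*>(&g4);
    const half2* f2 = reinterpret_cast<const half2*>(&f4);
    half2* b2 = reinterpret_cast<half2*>(&b4);
    for (int i = 0; i < 4; ++i) {  // 4 half2 pairs = 8 values
      float2 f = __half22float2(f2[i]);
      float2 g = __half22float2(g2[i]);
      b2[i] = __float22half2_rn(
          make_float2(f.x > 0.f ? g.x : 0.f, f.y > 0.f ? g.y : 0.f));
    }
    reinterpret_cast<float4*>(backprop)[index] = b4;  // one 128-bit store
  }
}

CUDA vector accesses must be naturally aligned (16 bytes for float4), which is why the dispatch code in the diff checks the alignment of all three pointers and falls back to the existing half2 kernel otherwise.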

fyi @nluehr
Copybara import of the project:

--
ab809024a4 by Kaixi Hou <kaixih@nvidia.com>:

Enable wider vector for reluGrad...

PiperOrigin-RevId: 316735014
Change-Id: Ic4d93a211e52844f9804ed6c2c4a0346052ceb1e
Authored by A. Unique TensorFlower on 2020-06-16 12:31:05 -07:00; committed by TensorFlower Gardener
Parent: aee694363c
Commit: 1a03ea7e61
1 changed file with 7 additions and 80 deletions


@@ -35,7 +35,6 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
static constexpr int VectorSizeElements = 8;
namespace functor {
// This kernel computes ReluGrad by processing one half2, two fp16, at a time.
@@ -94,65 +93,6 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
}
}
__global__ void ReluGradHalfKernelVector(
    const Eigen::half* __restrict__ gradient,
    const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
    int32 count) {
  int32 half8_count = count / VectorSizeElements;
  int32 index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < half8_count) {
    // Cast to xx_h8 for vector load and store.
    float4 gradient_h8 = reinterpret_cast<const float4*>(gradient)[index];
    float4 feature_h8 = reinterpret_cast<const float4*>(feature)[index];
    float4* p_backprop_h8 = reinterpret_cast<float4*>(backprop) + index;

    half2* gradient_h2 = reinterpret_cast<half2*>(&gradient_h8);
    half2* feature_h2 = reinterpret_cast<half2*>(&feature_h8);
    float4 backprop_h8;
    half2* p_backprop_h2 = reinterpret_cast<half2*>(&backprop_h8);

    // Fast path, when half2 primitives are available.
#if __CUDA_ARCH__ >= 530
    const half2 kZeroH2 = __float2half2_rn(0.f);
#endif
    for (int i = 0; i < VectorSizeElements / 2; i++) {
#if __CUDA_ARCH__ >= 530
      // mask = (feature > 0)
      half2 mask_h2 = __hgt2(feature_h2[i], kZeroH2);
      // backprop = mask * gradient
      half2 backprop_h2 = __hmul2(mask_h2, gradient_h2[i]);
#else
      // Fall back: convert half2 to float2 for processing.
      float2 feature_f2 = __half22float2(feature_h2[i]);
      float2 gradient_f2 = __half22float2(gradient_h2[i]);
      float2 backprop_f2 = make_float2((feature_f2.x > 0) ? gradient_f2.x : 0,
                                       (feature_f2.y > 0) ? gradient_f2.y : 0);
      // Convert back to half2.
      half2 backprop_h2 = __float22half2_rn(backprop_f2);
#endif
      p_backprop_h2[i] = backprop_h2;
    }
    // Write back the result.
    *p_backprop_h8 = backprop_h8;
  }

  int remaining_count = (count % VectorSizeElements);

  if (index < remaining_count) {
    // Use first threads to process the remaining elements.
    Eigen::half grad_h = gradient[half8_count * VectorSizeElements + index];
    Eigen::half feature_h = feature[half8_count * VectorSizeElements + index];

    float grad_f = static_cast<float>(grad_h);
    float feature_f = static_cast<float>(feature_h);
    float backprop_f = (feature_f > 0) ? grad_f : 0;

    Eigen::half backprop_h(backprop_f);
    backprop[half8_count * VectorSizeElements + index] = backprop_h;
  }
}
template <typename Device>
struct ReluGrad<Device, Eigen::half> {
  // Computes ReluGrad backprop.
@@ -168,28 +108,15 @@ struct ReluGrad<Device, Eigen::half> {
    // NOTE: When the activation is exactly zero, we do not propagate the
    // associated gradient value. This allows the output of the Relu to be used,
    // as well as its input.
    auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
    auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
    auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
    bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
                   backprop_ptr % 16 == 0;
    int32 count = gradient.size();
    constexpr int32 kThreadInBlock = 512;
    if (count == 0) return;
    if (aligned) {
      int32 half8_count = Eigen::divup(count, VectorSizeElements);
      int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
      TF_CHECK_OK(GpuLaunchKernel(
          ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
          gradient.data(), feature.data(), backprop.data(), count));
    } else {
      int32 half2_count = Eigen::divup(count, 2);
      GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
          half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
      TF_CHECK_OK(GpuLaunchKernel(
          ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
          d.stream(), gradient.data(), feature.data(), backprop.data(), count));
    }
    int32 half2_count = Eigen::divup(count, 2);
    constexpr int32 kThreadInBlock = 512;
    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
        half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
    TF_CHECK_OK(GpuLaunchKernel(
        ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
        d.stream(), gradient.data(), feature.data(), backprop.data(), count));
  }
};