Merge pull request #40552 from kaixih:relu_grad_vect_pr
PiperOrigin-RevId: 317199014 Change-Id: I8b8d8aae8f3549a34bcb32df7ee5d37544426058
This commit is contained in:
commit
3d904e9c83
@ -35,6 +35,7 @@ namespace tensorflow {
|
||||
|
||||
// Shorthand for the Eigen GPU device type.
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Number of Eigen::half elements each thread of ReluGradHalfKernelVector
// handles in one vectorized step; the kernel moves them as a single float4.
static constexpr int VectorSizeElements = 8;
|
||||
namespace functor {
|
||||
|
||||
// This kernel computes ReluGrad by processing one half2, two fp16, at a time.
|
||||
@ -93,6 +94,66 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
|
||||
}
|
||||
}
|
||||
|
||||
// Computes ReluGrad for fp16 data, VectorSizeElements values per thread:
//   backprop[i] = (feature[i] > 0) ? gradient[i] : 0
//
// Each thread whose global index is below count / VectorSizeElements loads one
// float4-sized chunk of `gradient` and `feature`, processes it as half2 pairs,
// and stores one float4-sized chunk of `backprop`. The count % VectorSizeElements
// trailing elements are handled one per thread by the lowest-indexed threads.
//
// The float4 accesses require all three pointers to be suitably aligned for
// float4; the caller is expected to check this before launching.
//
// NOTE(review): on the half2 fast path the result is formed as mask * gradient
// with a 0/1 mask, so a non-finite gradient with feature <= 0 presumably yields
// NaN, while the float fallback yields 0 — confirm this difference is intended.
__global__ void ReluGradHalfKernelVector(
    const Eigen::half* __restrict__ gradient,
    const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
    int32 count) {
  constexpr int kPairsPerVector = VectorSizeElements / 2;
  const int32 num_vectors = count / VectorSizeElements;
  const int32 tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid < num_vectors) {
    // Vector load: pull one whole chunk of each input into locals.
    float4 grad_vec = reinterpret_cast<const float4*>(gradient)[tid];
    float4 feat_vec = reinterpret_cast<const float4*>(feature)[tid];
    float4 out_vec;

    // View the local chunks as arrays of half2 pairs.
    const half2* grad_pairs = reinterpret_cast<const half2*>(&grad_vec);
    const half2* feat_pairs = reinterpret_cast<const half2*>(&feat_vec);
    half2* out_pairs = reinterpret_cast<half2*>(&out_vec);

#if __CUDA_ARCH__ >= 530
    // Fast path: native half2 compare and multiply are available.
    const half2 zero_pair = __float2half2_rn(0.f);
#endif
    for (int pair = 0; pair < kPairsPerVector; ++pair) {
#if __CUDA_ARCH__ >= 530
      // positive = (feature > 0), then backprop = positive * gradient.
      const half2 positive = __hgt2(feat_pairs[pair], zero_pair);
      const half2 result = __hmul2(positive, grad_pairs[pair]);
#else
      // Fallback: widen to float2, select, then narrow back to half2.
      const float2 feat_f = __half22float2(feat_pairs[pair]);
      const float2 grad_f = __half22float2(grad_pairs[pair]);
      const float lo = (feat_f.x > 0.0f) ? grad_f.x : 0.0f;
      const float hi = (feat_f.y > 0.0f) ? grad_f.y : 0.0f;
      const half2 result = __float22half2_rn(make_float2(lo, hi));
#endif
      out_pairs[pair] = result;
    }

    // Vector store of the finished chunk.
    reinterpret_cast<float4*>(backprop)[tid] = out_vec;
  }

  // Scalar tail: the first (count % VectorSizeElements) threads each handle
  // one of the elements that did not fill a whole vector.
  const int32 tail_count = count % VectorSizeElements;
  if (tid < tail_count) {
    const int32 pos = num_vectors * VectorSizeElements + tid;
    const float grad_val = static_cast<float>(gradient[pos]);
    const float feat_val = static_cast<float>(feature[pos]);
    const float out_val = (feat_val > 0) ? grad_val : 0;
    backprop[pos] = Eigen::half(out_val);
  }
}
|
||||
|
||||
template <typename Device>
|
||||
struct ReluGrad<Device, Eigen::half> {
|
||||
// Computes ReluGrad backprop.
|
||||
@ -108,15 +169,28 @@ struct ReluGrad<Device, Eigen::half> {
|
||||
// NOTE: When the activation is exactly zero, we do not propagate the
|
||||
// associated gradient value. This allows the output of the Relu to be used,
|
||||
// as well as its input.
|
||||
auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
|
||||
auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
|
||||
auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
|
||||
bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
|
||||
backprop_ptr % 16 == 0;
|
||||
int32 count = gradient.size();
|
||||
if (count == 0) return;
|
||||
int32 half2_count = Eigen::divup(count, 2);
|
||||
constexpr int32 kThreadInBlock = 512;
|
||||
GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
|
||||
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), gradient.data(), feature.data(), backprop.data(), count));
|
||||
if (count == 0) return;
|
||||
if (aligned) {
|
||||
int32 half8_count = Eigen::divup(count, VectorSizeElements);
|
||||
int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
|
||||
gradient.data(), feature.data(), backprop.data(), count));
|
||||
} else {
|
||||
int32 half2_count = Eigen::divup(count, 2);
|
||||
GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
|
||||
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
|
||||
TF_CHECK_OK(GpuLaunchKernel(
|
||||
ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
|
||||
d.stream(), gradient.data(), feature.data(), backprop.data(), count));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user