Merge pull request #28758 from ThisIsIsaac:master
PiperOrigin-RevId: 256383659
This commit is contained in:
commit
81acfa851e
@ -31,6 +31,87 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Vectorized bilinear resize using half-pixel-center sampling. Four channels
// are moved at a time through 16-byte float4 loads and stores.
//
// Work layout: every (x, y) output pixel is shared by `num_channel_threads`
// threads. A thread starts at channel `c_start * kChannelsPerThread`, advances
// by `kChannelsPerThread * num_channel_threads`, and repeats that walk for
// every batch entry.
//
// Args:
//   num_channel_threads: number of threads cooperating on one output pixel.
//   images: input buffer, indexed as [batch, in_height, in_width, channels].
//   height_scale, width_scale: factors mapping output to input coordinates.
//   batch, in_height, in_width, channels, out_height, out_width: extents.
//   output: result buffer, indexed as [batch, out_height, out_width, channels].
//
// Preconditions (not checked in the kernel):
//   * `images` and `output` are reinterpreted as float4 arrays, so the kernel
//     only produces meaningful results for T == float, and both buffers must
//     satisfy the 16-byte alignment float4 accesses require.
//   * `channels` must be a multiple of kChannelsPerThread so that every vector
//     access covers a whole, in-bounds group of channels.
template <typename T>
__global__ void ResizeBilinearKernel_faster(
    const int num_channel_threads, const T* __restrict__ images,
    float height_scale, float width_scale, int batch, int in_height,
    int in_width, int channels, int out_height, int out_width,
    float* __restrict__ output) {
  constexpr int kChannelsPerThread = 16 / sizeof(T);
  // Number of floats carried by one float4 load/store.
  constexpr int kFloatsPerVector = sizeof(float4) / sizeof(float);

  // Vector views of the buffers. The input view keeps its const qualifier.
  const float4* images_vec = reinterpret_cast<const float4*>(images);
  float4* output_vec = reinterpret_cast<float4*>(output);

  for (int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
       out_idx < out_width * out_height * num_channel_threads;
       out_idx += blockDim.x * gridDim.x) {
    // Decompose the flat index into (channel slot, x, y).
    int idx = out_idx;
    const int c_start = idx % num_channel_threads;
    idx /= num_channel_threads;
    const int x = idx % out_width;
    idx /= out_width;
    const int y = idx % out_height;

    // Source row coordinate, the two clamped neighbour rows and the blend
    // weight between them.
    const float in_y = (static_cast<float>(y) + 0.5f) * height_scale - 0.5f;
    const int top_y_index = in_y > 0.0f ? static_cast<int>(floorf(in_y)) : 0;
    const int bottom_y_index =
        (in_y < in_height - 1) ? static_cast<int>(ceilf(in_y)) : in_height - 1;
    const float y_lerp = in_y - floorf(in_y);

    // Same for the source column.
    const float in_x = (static_cast<float>(x) + 0.5f) * width_scale - 0.5f;
    const int left_x_index = in_x > 0.0f ? static_cast<int>(floorf(in_x)) : 0;
    const int right_x_index =
        (in_x < in_width - 1) ? static_cast<int>(ceilf(in_x)) : in_width - 1;
    const float x_lerp = in_x - left_x_index;

    // Blends the four neighbouring samples of a single channel.
    const auto blend = [x_lerp, y_lerp](float top_left, float top_right,
                                        float bottom_left,
                                        float bottom_right) -> float {
      const float top = top_left + (top_right - top_left) * x_lerp;
      const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
      return top + (bottom - top) * y_lerp;
    };

    for (int b = 0; b < batch; b++) {
      // Element offsets are computed in 64 bits so that the flattened index
      // cannot overflow a 32-bit int on large tensors. They do not depend on
      // the channel, so they are computed once per batch entry.
      const long long top_row =
          (static_cast<long long>(b) * in_height + top_y_index) * in_width;
      const long long bottom_row =
          (static_cast<long long>(b) * in_height + bottom_y_index) * in_width;
      const long long top_left_base = (top_row + left_x_index) * channels;
      const long long top_right_base = (top_row + right_x_index) * channels;
      const long long bottom_left_base = (bottom_row + left_x_index) * channels;
      const long long bottom_right_base =
          (bottom_row + right_x_index) * channels;
      const long long out_base =
          ((static_cast<long long>(b) * out_height + y) * out_width + x) *
          channels;

      for (int c = c_start * kChannelsPerThread; c < channels;
           c += kChannelsPerThread * num_channel_threads) {
        // 16 byte reads from global memory, held in float4 locals so the
        // values carry the alignment the vector type requires.
        const float4 top_left =
            images_vec[(top_left_base + c) / kFloatsPerVector];
        const float4 top_right =
            images_vec[(top_right_base + c) / kFloatsPerVector];
        const float4 bottom_left =
            images_vec[(bottom_left_base + c) / kFloatsPerVector];
        const float4 bottom_right =
            images_vec[(bottom_right_base + c) / kFloatsPerVector];

        float4 result;
        result.x = blend(top_left.x, top_right.x, bottom_left.x, bottom_right.x);
        result.y = blend(top_left.y, top_right.y, bottom_left.y, bottom_right.y);
        result.z = blend(top_left.z, top_right.z, bottom_left.z, bottom_right.z);
        result.w = blend(top_left.w, top_right.w, bottom_left.w, bottom_right.w);

        // Single 16 byte write of the four interpolated channels.
        output_vec[(out_base + c) / kFloatsPerVector] = result;
      }
    }
  }
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
|
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
|
||||||
float height_scale, float width_scale,
|
float height_scale, float width_scale,
|
||||||
@ -279,21 +360,36 @@ struct ResizeBilinear<GPUDevice, T> {
|
|||||||
if (total_count == 0) return;
|
if (total_count == 0) return;
|
||||||
|
|
||||||
GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
|
GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
|
||||||
|
void (*kernel)(const int num_threads, const T* __restrict__ images,
|
||||||
|
float height_scale, float width_scale, int batch,
|
||||||
|
int in_height, int in_width, int channels, int out_height,
|
||||||
|
int out_width, float* __restrict__ output) =
|
||||||
|
LegacyResizeBilinearKernel<T>;
|
||||||
|
|
||||||
if (half_pixel_centers) {
|
if (half_pixel_centers) {
|
||||||
TF_CHECK_OK(
|
// If centers are not at half-pixel, use the legacy kernel instead.
|
||||||
GpuLaunchKernel(ResizeBilinearKernel<T>, dim3(config.block_count),
|
kernel = ResizeBilinearKernel<T>;
|
||||||
dim3(config.thread_per_block), 0, d.stream(),
|
|
||||||
config.virtual_thread_count, images.data(),
|
// 16 bytes per thread and 8 threads for coalesced 128 bytes global memory
|
||||||
height_scale, width_scale, batch, in_height, in_width,
|
// access.
|
||||||
channels, out_height, out_width, output.data()));
|
constexpr int max_num_threads_per_pixel = 8;
|
||||||
} else {
|
constexpr int channels_per_thread = 16 / sizeof(T);
|
||||||
TF_CHECK_OK(GpuLaunchKernel(
|
if (channels % channels_per_thread == 0 &&
|
||||||
LegacyResizeBilinearKernel<T>, dim3(config.block_count),
|
std::is_same<float, T>::value) {
|
||||||
dim3(config.thread_per_block), 0, d.stream(),
|
int num_threads_per_pixel =
|
||||||
config.virtual_thread_count, images.data(), height_scale, width_scale,
|
std::min(max_num_threads_per_pixel, channels / channels_per_thread);
|
||||||
batch, in_height, in_width, channels, out_height, out_width,
|
config = GetGpuLaunchConfig(
|
||||||
output.data()));
|
out_height * out_width * num_threads_per_pixel, d);
|
||||||
|
config.virtual_thread_count = num_threads_per_pixel;
|
||||||
|
kernel = ResizeBilinearKernel_faster<T>;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_CHECK_OK(
|
||||||
|
GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0,
|
||||||
|
d.stream(), config.virtual_thread_count, images.data(),
|
||||||
|
height_scale, width_scale, batch, in_height, in_width,
|
||||||
|
channels, out_height, out_width, output.data()));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user