added
This commit is contained in:
parent
1fb966eb22
commit
35c466bf4f
@ -25,12 +25,78 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
|
// auxilary 16-byte datatype for ResizeBilinearKernel_faster
|
||||||
|
// the fields are not important. The only purpose of this is to read 16 bytes
|
||||||
|
// from GPU gloal memory
|
||||||
|
struct four_floats{
|
||||||
|
float a;
|
||||||
|
float b;
|
||||||
|
float c;
|
||||||
|
float d;
|
||||||
|
};
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, int C_UNROLL>
|
||||||
|
__global__ void ResizeBilinearKernel_faster(const int num_channel_thread, const T* __restrict__ images,
|
||||||
|
float height_scale, float width_scale,
|
||||||
|
int batch, int in_height, int in_width,
|
||||||
|
int channels, int out_height,
|
||||||
|
int out_width, float* __restrict__ output) {
|
||||||
|
|
||||||
|
for (int out_idx = blockIdx.x * blockDim.x + threadIdx.x; out_idx < out_width*out_height*num_channel_per_thread; out_idx += blockDim.x * gridDim.x){
|
||||||
|
int idx = out_idx;
|
||||||
|
const int c_start = idx % num_channel_thread;
|
||||||
|
idx /= num_channel_thread;
|
||||||
|
const int x = idx % out_width;
|
||||||
|
idx /= out_width;
|
||||||
|
const int y = idx % out_height;
|
||||||
|
|
||||||
|
const float in_y = (static_cast<float>(y) + 0.5f) * height_scale - 0.5f;
|
||||||
|
|
||||||
|
const int top_y_index = in_y > 0.0 ? floorf(in_y) : 0;
|
||||||
|
const int bottom_y_index =
|
||||||
|
(in_y < in_height - 1) ? ceilf(in_y) : in_height - 1;
|
||||||
|
const float y_lerp = in_y - floorf(in_y);
|
||||||
|
|
||||||
|
const float in_x = (static_cast<float>(x) + 0.5f) * width_scale - 0.5f;
|
||||||
|
const int left_x_index = in_x > 0.0 ? floorf(in_x) : 0;
|
||||||
|
const int right_x_index =
|
||||||
|
(in_x < in_width - 1) ? ceilf(in_x) : in_width - 1;
|
||||||
|
const float x_lerp = in_x - left_x_index;
|
||||||
|
|
||||||
|
|
||||||
|
float top_left_reg[C_UNROLL];
|
||||||
|
float top_right_reg[C_UNROLL];
|
||||||
|
float bottom_left_reg[C_UNROLL];
|
||||||
|
float bottom_right_reg[C_UNROLL];
|
||||||
|
float out_reg[C_UNROLL];
|
||||||
|
for (int b =0; b < batch; b++) {
|
||||||
|
for (int c = c_start*C_UNROLL; c < channels; c+= C_UNROLL*num_channel_per_thread) {
|
||||||
|
|
||||||
|
// 16 byte read from global memroy and cache them in registers
|
||||||
|
((four_floats*) top_left_reg)[0] = ((four_floats*) images)[(((b * in_height + top_y_index) * in_width + left_x_index) * channels + c)/4 ];
|
||||||
|
((four_floats*) top_right_reg)[0] = ((four_floats*) images)[(((b * in_height + top_y_index) * in_width + right_x_index) * channels + c)/4];
|
||||||
|
((four_floats*) bottom_left_reg)[0] = ((four_floats*) images)[(((b * in_height + bottom_y_index) * in_width + left_x_index) * channels + c)/4];
|
||||||
|
((four_floats*) bottom_right_reg)[0] = ((four_floats*) images)[(((b * in_height + bottom_y_index) * in_width + right_x_index) * channels +c)/4];
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < C_UNROLL; unroll+=1){
|
||||||
|
const float top = top_left_reg[unroll] + (top_right_reg[unroll] - top_left_reg[unroll]) * x_lerp;
|
||||||
|
const float bottom = bottom_left_reg[unroll] + (bottom_right_reg[unroll] - bottom_left_reg[unroll]) * x_lerp;
|
||||||
|
out_reg[unroll] = top + (bottom - top) * y_lerp;
|
||||||
|
}
|
||||||
|
((four_floats*) output)[(((b *out_height + y) * out_width + x) * channels + c)/4] = ((four_floats*) out_reg)[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
|
__global__ void ResizeBilinearKernel(const int32 nthreads, const T* images,
|
||||||
float height_scale, float width_scale,
|
float height_scale, float width_scale,
|
||||||
@ -278,23 +344,55 @@ struct ResizeBilinear<GPUDevice, T> {
|
|||||||
const int total_count = batch * out_height * out_width * channels;
|
const int total_count = batch * out_height * out_width * channels;
|
||||||
if (total_count == 0) return;
|
if (total_count == 0) return;
|
||||||
|
|
||||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
|
// ResizeBilinearKernel_faster is 30 ~ 50% faster than ResizeBilinearKernel
|
||||||
if (half_pixel_centers) {
|
// but can only be used when channels is multiple of 4 and size of input
|
||||||
|
// elemnt is the same as float
|
||||||
|
if (channels % 4 == 0 && sizeof(float) == sizeof(T) && half_pixel_centers) {
|
||||||
|
// since each thread reads 16 bytes, and we need at most 8 of such threads
|
||||||
|
// to make the full use of 128 bytes of global memroy read & write
|
||||||
|
const int channel_per_thread = 16 / sizeof(float);
|
||||||
|
|
||||||
|
// since each global memroy read from L1 cahce is 128 bytes, and each thread
|
||||||
|
// reads 16 bytes, we need 8 threads to fully coalesce 128 bytes of read & store
|
||||||
|
const int max_num_channel_thread = 8;
|
||||||
|
|
||||||
|
// number of threads that will iterate through the channel dimension
|
||||||
|
const int num_channel_thread = std::min(max_num_channel_per_thread,
|
||||||
|
num_channels/channel_per_thread);
|
||||||
|
|
||||||
|
GpuLaunchConfig config = GetCudaLaunchConfig(out_height * out_width *
|
||||||
|
num_channel_thread, d);
|
||||||
|
|
||||||
|
TF_CHECK_OK(CudaLaunchKernel(
|
||||||
|
ResizeBilinearKernel_faster<T, channel_per_thread>,
|
||||||
|
config.block_count, config.thread_per_block, 0, d.stream(),
|
||||||
|
num_channel_thread, images.data(), height_scale, width_scale, batch,
|
||||||
|
in_height, in_width, channels, out_height, out_width, output.data()));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);
|
||||||
|
|
||||||
|
if (half_pixel_centers) {
|
||||||
TF_CHECK_OK(CudaLaunchKernel(
|
TF_CHECK_OK(CudaLaunchKernel(
|
||||||
ResizeBilinearKernel<T>, config.block_count, config.thread_per_block,
|
ResizeBilinearKernel<T>, config.block_count, config.thread_per_block,
|
||||||
0, d.stream(), config.virtual_thread_count, images.data(),
|
0, d.stream(), config.virtual_thread_count, images.data(),
|
||||||
height_scale, width_scale, batch, in_height, in_width, channels,
|
height_scale, width_scale, batch, in_height, in_width, channels,
|
||||||
out_height, out_width, output.data()));
|
out_height, out_width, output.data()));
|
||||||
} else {
|
|
||||||
TF_CHECK_OK(CudaLaunchKernel(
|
} else {
|
||||||
LegacyResizeBilinearKernel<T>, config.block_count,
|
TF_CHECK_OK(CudaLaunchKernel(
|
||||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
LegacyResizeBilinearKernel<T>, config.block_count,
|
||||||
images.data(), height_scale, width_scale, batch, in_height, in_width,
|
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||||
channels, out_height, out_width, output.data()));
|
images.data(), height_scale, width_scale, batch, in_height, in_width,
|
||||||
|
channels, out_height, out_width, output.data()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
|
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ResizeBilinearGrad<GPUDevice, T> {
|
struct ResizeBilinearGrad<GPUDevice, T> {
|
||||||
@ -312,19 +410,19 @@ struct ResizeBilinearGrad<GPUDevice, T> {
|
|||||||
const int resized_width = input_grad.dimension(2);
|
const int resized_width = input_grad.dimension(2);
|
||||||
|
|
||||||
int total_count;
|
int total_count;
|
||||||
CudaLaunchConfig config;
|
GpuLaunchConfig config;
|
||||||
|
|
||||||
// Initialize output_grad with all zeros.
|
// Initialize output_grad with all zeros.
|
||||||
total_count = batch * original_height * original_width * channels;
|
total_count = batch * original_height * original_width * channels;
|
||||||
if (total_count == 0) return;
|
if (total_count == 0) return;
|
||||||
config = GetCudaLaunchConfig(total_count, d);
|
config = GetGpuLaunchConfig(total_count, d);
|
||||||
TF_CHECK_OK(CudaLaunchKernel(
|
TF_CHECK_OK(CudaLaunchKernel(
|
||||||
SetZero<T>, config.block_count, config.thread_per_block, 0, d.stream(),
|
SetZero<T>, config.block_count, config.thread_per_block, 0, d.stream(),
|
||||||
config.virtual_thread_count, output_grad.data()));
|
config.virtual_thread_count, output_grad.data()));
|
||||||
|
|
||||||
// Accumulate.
|
// Accumulate.
|
||||||
total_count = batch * resized_height * resized_width * channels;
|
total_count = batch * resized_height * resized_width * channels;
|
||||||
config = GetCudaLaunchConfig(total_count, d);
|
config = GetGpuLaunchConfig(total_count, d);
|
||||||
if (half_pixel_centers) {
|
if (half_pixel_centers) {
|
||||||
TF_CHECK_OK(CudaLaunchKernel(
|
TF_CHECK_OK(CudaLaunchKernel(
|
||||||
ResizeBilinearGradKernel<T>, config.block_count,
|
ResizeBilinearGradKernel<T>, config.block_count,
|
||||||
|
Loading…
Reference in New Issue
Block a user