In the CUDA path of depthwise_conv2d, optimize the backward filter convolution for images 2x or 4x smaller than 16x16. Also initialize in_cols from blockDim, to fix the regression caused by CL 157906773.

PiperOrigin-RevId: 158296136
Author: A. Unique TensorFlower, 2017-06-07 11:45:30 -07:00 (committed by TensorFlower Gardener)
Parent: 492afc2e37
Commit: f105df0478

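For context: a full 16x16 image has 256 pixels, so the "2x or 4x smaller" cases are images of at most 128 or 64 pixels. The new kAccumPixels template parameter lets those smaller images use narrower warp-reduction chunks, subject to the requirement kAccumPixels * 8 >= in_rows * in_cols. A minimal sketch of the selection logic (the helper name SelectAccumPixels is illustrative only; the commit inlines this loop in TryLaunchDepthwiseConv2dBackpropFilterGPUSmall at the bottom of the diff):

    // Illustrative sketch, not part of the commit: picks the accumulator
    // chunk width used by the *Small backward-filter kernels.
    int SelectAccumPixels(int in_rows, int in_cols) {
      const int in_pixels = in_rows * in_cols;
      int accum_pixels = 8;  // narrowest supported reduction chunk
      while (accum_pixels * 8 < in_pixels) {  // must cover all input pixels
        accum_pixels *= 2;
      }
      return accum_pixels;  // 8 for <= 64 pixels, 16 for <= 128, else 32
    }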

@@ -1308,9 +1308,11 @@ __global__ void __launch_bounds__(640, 2)
// a partial convolution for two elements, one each in the lower and upper half
// of a tile. The intermediate result of 4 consecutive columns are then
// accumulated and written to shared memory. Finally, the values in shared
-// memory are warp-accumulated (in chunks of 32 elements) and summed up in
-// global memory using atomics.
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
+// up in global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+// Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
+int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1321,7 +1323,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int batches = args.batch;
const int in_rows = args.in_rows;
-const int in_cols = args.in_cols;
+const int in_cols = blockDim.y; // slower (see b/62280718): args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
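Note on the blockDim change: the launchers further down shape the thread block so that one dimension equals the input width, so the kernel can read in_cols from blockDim instead of from the kernel argument, which b/62280718 flagged as slower. A sketch of the invariant, using the block shapes the launcher constructs:

    // NHWC launcher: block = (depth slices, input columns, block rows),
    // so inside the kernel blockDim.y == args.in_cols.
    dim3 block_nhwc(block_slices, args.in_cols, block_rows);
    // NCHW launcher: block = (input columns, block rows, depth slices),
    // so inside the kernel blockDim.x == args.in_cols.
    dim3 block_nchw(args.in_cols, block_rows, block_slices);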
@@ -1352,8 +1354,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const int tensor_offset = block_rows * in_row_size;
// The accumulator has a fixed number of pixels that can be reduced by one
// warp. Pixels beyond block_pixels/4 are never written.
-const int accum_pixels = 32;
-const int accum_increment = accum_pixels * block_slices;
+const int accum_increment = kAccumPixels * block_slices;
const int accum_size = filter_pixels * accum_increment;
const int thread_depth = threadIdx.x;
@@ -1383,7 +1384,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Position in accumulator (1 per 4 threads, depth major).
const int accum_pix = thread_pix / 4;
-const int accum_idx = thread_depth * accum_pixels + accum_pix;
+const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_depth = in_depth - thread_depth;
const int accum_offset = tile_size + accum_idx;
@@ -1438,19 +1439,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
-const int filter_idx = i / accum_pixels;
+const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / block_slices;
const int filter_depth = filter_idx % block_slices + start_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
-// Sum up the 32 pixels of the same depth from the accumulator.
-val += CudaShuffleDown(val, 16);
-val += CudaShuffleDown(val, 8);
-val += CudaShuffleDown(val, 4);
-val += CudaShuffleDown(val, 2);
-val += CudaShuffleDown(val, 1);
-if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+// Warp-accumulate the pixels of the same depth from the accumulator.
+for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
+val += CudaShuffleDown(val, delta);
+}
+if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
}
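The unrolled five-step shuffle sequence above assumed 32-wide groups; the loop form performs log2(kAccumPixels) steps, and thread_idx & (kAccumPixels - 1) selects the first lane of each kAccumPixels-wide group (for kAccumPixels == 32 it reduces to the old lane_idx == 0 test). A self-contained sketch of the generalized reduction, assuming kAccumPixels is a power of two no larger than the 32-lane warp, and using TensorFlow's CudaShuffleDown wrapper as in the kernel:

    // Sketch only: sums val across each group of kAccumPixels consecutive
    // lanes. After the loop, the first lane of each group holds the group
    // total; the remaining lanes hold partial sums that are discarded.
    template <typename T, int kAccumPixels>
    __device__ T WarpGroupSum(T val) {
      for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
        val += CudaShuffleDown(val, delta);  // add val from lane (lane + delta)
      }
      return val;
    }

With kAccumPixels == 8, a quarter-size image needs three shuffle steps instead of five, and the shared-memory accumulator (filter_pixels * kAccumPixels entries per depth slice) shrinks 4x.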
@@ -1567,9 +1566,11 @@ __global__ void __launch_bounds__(640, 2)
// a partial convolution for two elements, one each in the lower and upper half
// of a tile. The intermediate result of 4 consecutive columns are then
// accumulated and written to shared memory. Finally, the values in shared
-// memory are warp-accumulated (in chunks of 32 elements) and summed up in
-// global memory using atomics.
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed
+// up in global memory using atomics.
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+// Requirement: kAccumPixels * 8 >= args.in_rows * args.in_cols
+int kAccumPixels>
__global__
__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const DepthwiseArgs args, const T* output, const T* input, T* filter) {
@@ -1580,7 +1581,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const int batches = args.batch;
const int in_rows = args.in_rows;
-const int in_cols = args.in_cols;
+const int in_cols = blockDim.x; // slower (see b/62280718): args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
@@ -1610,8 +1611,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const int in_blocks = (in_slices + block_slices - 1) / block_slices;
// The accumulator has a fixed number of pixels that can be reduced by one
// warp. Pixels beyond block_pixels/4 are never written.
-const int accum_pixels = 32;
-const int accum_increment = accum_pixels * block_slices;
+const int accum_increment = kAccumPixels * block_slices;
const int accum_size = filter_pixels * accum_increment;
const int thread_col = threadIdx.x;
@@ -1640,7 +1640,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Position in accumulator (1 per 4 threads, depth major).
const int accum_pix = thread_pix / 4;
-const int accum_idx = thread_depth * accum_pixels + accum_pix;
+const int accum_idx = thread_depth * kAccumPixels + accum_pix;
const int max_slice = in_slices - thread_depth;
const int accum_offset = tile_size + accum_idx;
@@ -1692,19 +1692,17 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
const T* const accum_data = tile_size + shared_data;
for (int i = thread_idx; i < accum_size; i += block_size) {
-const int filter_idx = i / accum_pixels;
+const int filter_idx = i / kAccumPixels;
const int filter_pix = filter_idx / block_slices;
const int filter_depth = (slice + filter_idx % block_slices) % in_depth;
const int filter_offset = filter_pix * in_depth + filter_depth;
if (filter_depth < in_depth) {
T val = accum_data[i];
-// Sum up 32 pixels of the same depth from the accumulator.
-val += CudaShuffleDown(val, 16);
-val += CudaShuffleDown(val, 8);
-val += CudaShuffleDown(val, 4);
-val += CudaShuffleDown(val, 2);
-val += CudaShuffleDown(val, 1);
-if (!(thread_idx & 31) /* i.e. 'lane_idx == 0' */) {
+// Warp-accumulate pixels of the same depth from the accumulator.
+for (int delta = kAccumPixels / 2; delta > 0; delta /= 2) {
+val += CudaShuffleDown(val, delta);
+}
+if (!(thread_idx & kAccumPixels - 1)) {
CudaAtomicAdd(filter_offset + filter, val);
}
}
@@ -1712,7 +1710,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
}
}
-template <typename T, int kKnownFilterWidth, int kKnownFilterHeight>
+template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
+int kAccumPixels>
void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
const GpuDevice& d, const DepthwiseArgs args, int block_rows,
int shared_memory_size, const T* out_backprop, const T* input,
@@ -1724,22 +1723,22 @@ void LaunchDepthwiseConv2dBackpropFilterGPUSmall(
dim3 block_dim = dim3(block_slices, args.in_cols, block_rows);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
-DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
-kKnownFilterHeight>,
+DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
+T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
-DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<T, kKnownFilterWidth,
-kKnownFilterHeight>
+DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall<
+T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else if (data_format == FORMAT_NCHW) {
dim3 block_dim = dim3(args.in_cols, block_rows, block_slices);
CudaLaunchConfig config = GetCudaLaunchConfig(
num_out_backprop, d,
-DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
-kKnownFilterHeight>,
+DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
+T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>,
shared_memory_size, block_dim.x * block_dim.y * block_dim.z);
-DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<T, kKnownFilterWidth,
-kKnownFilterHeight>
+DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall<
+T, kKnownFilterWidth, kKnownFilterHeight, kAccumPixels>
<<<config.block_count, block_dim, shared_memory_size, d.stream()>>>(
args, out_backprop, input, filter_backprop);
} else {
@@ -1759,21 +1758,39 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall(
return false;
}
+const int in_pixels = args.in_rows * args.in_cols;
+int accum_pixels = 8;
+while (accum_pixels * 8 < in_pixels) {
+accum_pixels *= 2;
+}
const int block_slices = 8;
const int tile_cols = args.in_cols + args.filter_cols - 1;
const int tile_rows = block_rows * 2 + args.filter_rows - 1;
const int tile_pixels = tile_rows * tile_cols;
-const int accum_size = args.filter_rows * args.filter_cols * 32;
+const int filter_pixels = args.filter_rows * args.filter_cols;
const int shared_memory_size =
-block_slices * (tile_pixels + accum_size) * sizeof(T);
+block_slices * (tile_pixels + filter_pixels * accum_pixels) * sizeof(T);
if (shared_memory_size > d.sharedMemPerBlock()) {
return false;
}
-LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
-kKnownFilterHeight>(
-d, args, block_rows, shared_memory_size, out_backprop, input,
-filter_backprop, data_format);
+if (accum_pixels == 8) {
+LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+kKnownFilterHeight, 8>(
+d, args, block_rows, shared_memory_size, out_backprop, input,
+filter_backprop, data_format);
+} else if (accum_pixels == 16) {
+LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+kKnownFilterHeight, 16>(
+d, args, block_rows, shared_memory_size, out_backprop, input,
+filter_backprop, data_format);
+} else {
+LaunchDepthwiseConv2dBackpropFilterGPUSmall<T, kKnownFilterWidth,
+kKnownFilterHeight, 32>(
+d, args, block_rows, shared_memory_size, out_backprop, input,
+filter_backprop, data_format);
+}
return true;
}
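For scale, assuming T = float, a 16x16 input, a 3x3 filter, and block_rows = 8 (an illustrative value; the caller chooses block_rows), the shared-memory computation above works out to:

    // Worked example under the stated assumptions:
    //   tile_cols    = 16 + 3 - 1    = 18
    //   tile_rows    = 8 * 2 + 3 - 1 = 18
    //   tile_pixels  = 18 * 18       = 324
    //   accum_pixels = 32            (8 * 8 < 256 and 16 * 8 < 256)
    //   shared_memory_size = 8 * (324 + 9 * 32) * 4 = 19584 bytes
    // For an 8x8 input, accum_pixels drops to 8, shrinking the accumulator
    // term from 9 * 32 to 9 * 8 floats per depth slice.

The three-way branch at the end is there because kAccumPixels is a template parameter: it has to be a compile-time constant so the reduction loop can be fully unrolled and the accumulator layout fixed, hence one kernel instantiation per supported chunk width (8, 16, or 32).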