In the CUDA path of depthwise_conv2d, add __launch_bounds__ and use CUDA runtime helpers to determine the LaunchConfig that maximizes occupancy.
Change: 155347912
commit 3c02d11007
parent f7935b8f8e
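Background on the attribute this commit adds (a minimal sketch for illustration; the Scale kernel below is hypothetical, not TensorFlow code): __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor) promises the compiler that the kernel is never launched with more than maxThreadsPerBlock threads per block, and asks it to budget registers so that at least minBlocksPerMultiprocessor blocks can stay resident on each SM. That is what the (1024, 2) and (640, 2) annotations below request.

// scale.cu -- hypothetical example; build with: nvcc scale.cu
#include <cstdio>

// The attribute caps blocks at 256 threads and asks the compiler to keep
// register usage low enough for at least 4 resident blocks per SM.
__global__ void __launch_bounds__(256, 4)
    Scale(const float* in, float* out, float alpha, int n) {
  // Grid-stride loop: covers all n elements for any grid/block size.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = alpha * in[i];
  }
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  Scale<<<4, 256>>>(in, out, 2.0f, n);  // > 256 threads/block would fail here
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(out);
  return 0;
}

Tighter bounds can force register spills, so the value is a per-kernel tuning choice; the commit uses (1024, 2) for the forward kernels and (640, 2) for the backprop kernels.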
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -38,9 +38,9 @@ using Eigen::GpuDevice;
 // in NHWC format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
-                                             const T* input, const T* filter,
-                                             T* output, int num_outputs) {
+__global__ void __launch_bounds__(1024, 2)
+    DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input,
+                                 const T* filter, T* output, int num_outputs) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -120,9 +120,9 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
 // in NCHW format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
-                                             const T* input, const T* filter,
-                                             T* output, int num_outputs) {
+__global__ void __launch_bounds__(1024, 2)
+    DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input,
+                                 const T* filter, T* output, int num_outputs) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -250,17 +250,34 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
                               TensorFormat data_format) {
   const int num_outputs =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
+  // The compile-time constant version runs faster with a single block.
+  const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
+                                      kKnownDepthMultiplier < 0 ||
+                                      args.out_rows * args.out_cols <= 256
+                                  ? std::numeric_limits<int>::max()
+                                  : d.getNumCudaMultiProcessors();
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_outputs, d,
+        DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
+                                     kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
                                  kKnownDepthMultiplier>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            args, input, filter, output, num_outputs);
+        <<<std::min(max_block_count, config.block_count),
+           config.thread_per_block, 0, d.stream()>>>(args, input, filter,
+                                                     output, num_outputs);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_outputs, d,
+        DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
+                                     kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
                                  kKnownDepthMultiplier>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            args, input, filter, output, num_outputs);
+        <<<std::min(max_block_count, config.block_count),
+           config.thread_per_block, 0, d.stream()>>>(args, input, filter,
+                                                     output, num_outputs);
   } else {
     assert(false);
   }
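A note on why the max_block_count cap above is safe (illustration; FillIota is hypothetical, not TensorFlow code): these kernels consume their outputs through a grid-stride loop (TensorFlow's CUDA_1D_KERNEL_LOOP macro), so every element is processed no matter how few blocks are launched; capping the grid at d.getNumCudaMultiProcessors() (one block per SM) just trades parallel slack for lower scheduling overhead, which per the added comment wins for the fully specialized, small-output case.

#include <cstdio>

// A grid-stride kernel is correct for ANY block count, so capping the grid
// at one block per SM cannot drop work.
__global__ void FillIota(int* out, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = i;  // each index written exactly once, regardless of grid size
  }
}

int main() {
  const int n = 1 << 16;
  int* out;
  cudaMalloc(&out, n * sizeof(int));
  int num_sms = 0;
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, 0);
  // Both launches compute the same result; the one-block-per-SM grid wins
  // when per-element work is tiny and scheduling overhead dominates.
  FillIota<<<256, 256>>>(out, n);
  FillIota<<<num_sms, 256>>>(out, n);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}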
@@ -288,9 +305,11 @@ template struct DepthwiseConv2dGPULaunch<double>;
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
-    const DepthwiseArgs args, const T* out_backprop, const T* filter,
-    T* in_backprop, int num_in_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropInputGPUKernelNHWC(const DepthwiseArgs args,
+                                              const T* out_backprop,
+                                              const T* filter, T* in_backprop,
+                                              int num_in_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
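A quick occupancy calculation for the (640, 2) bound above (assuming the common 64 Ki 32-bit registers per SM; check your device's regsPerMultiprocessor): keeping 2 blocks × 640 threads = 1280 threads resident allows at most ⌊65536 / 1280⌋ = 51 registers per thread, which the allocator's granularity typically rounds down to 48. The forward kernels' (1024, 2) bound is tighter still: 65536 / 2048 = 32 registers per thread.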
@@ -350,7 +369,7 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void __launch_bounds__(1024)
+__global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
                                               const T* out_backprop,
                                               const T* filter, T* in_backprop,
@@ -428,17 +447,22 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
                                            TensorFormat data_format) {
   const int num_in_backprop =
       args.batch * args.in_rows * args.in_cols * args.in_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
-  // Increase block count for when there are more warps/SM than threads/SM.
-  // TODO(csigg): this is pretty arbitraty and should be generalized using
-  // cudaOccupancyMaxPotentialBlockSize().
-  config.block_count *= 4;
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_in_backprop, d,
+        DepthwiseConv2dBackpropInputGPUKernelNHWC<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropInputGPUKernelNHWC<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, filter, in_backprop, num_in_backprop);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_in_backprop, d,
+        DepthwiseConv2dBackpropInputGPUKernelNCHW<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropInputGPUKernelNCHW<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
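Two things change shape in this hunk: the removed TODO is resolved (the arbitrary config.block_count *= 4 heuristic is replaced by the occupancy-based GetCudaLaunchConfig overload added at the end of this commit), and the config computation moves inside each branch because occupancy now depends on the specific kernel's register and shared-memory footprint, so the NHWC and NCHW kernels can receive different configs.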
@@ -475,9 +499,12 @@ template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
-    const DepthwiseArgs args, const T* out_backprop, const T* input,
-    T* filter_backprop, int num_out_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropFilterGPUKernelNHWC(const DepthwiseArgs args,
+                                               const T* out_backprop,
+                                               const T* input,
+                                               T* filter_backprop,
+                                               int num_out_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -566,9 +593,12 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
-    const DepthwiseArgs args, const T* out_backprop, const T* input,
-    T* filter_backprop, int num_out_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropFilterGPUKernelNCHW(const DepthwiseArgs args,
+                                               const T* out_backprop,
+                                               const T* input,
+                                               T* filter_backprop,
+                                               int num_out_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -669,13 +699,22 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
                                             TensorFormat data_format) {
   const int num_out_backprop =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNHWC<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropFilterGPUKernelNHWC<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, input, filter_backprop, num_out_backprop);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNCHW<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropFilterGPUKernelNCHW<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -63,6 +63,28 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
   return config;
 }
 
+// Calculate the Cuda launch config we should use for a kernel launch. This
+// variant takes the resource limits of func into account to maximize occupancy.
+template <typename DeviceFunc>
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+                                            const GPUDevice& d, DeviceFunc func,
+                                            size_t dynamic_shared_memory_size) {
+  int block_count = 0;
+  int thread_per_block = 0;
+  cudaOccupancyMaxPotentialBlockSize(&block_count, &thread_per_block, func,
+                                     dynamic_shared_memory_size,
+                                     work_element_count);
+  block_count =
+      std::min(block_count,
+               (work_element_count + thread_per_block - 1) / thread_per_block);
+
+  CudaLaunchConfig config;
+  config.virtual_thread_count = work_element_count;
+  config.thread_per_block = thread_per_block;
+  config.block_count = block_count;
+  return config;
+}
+
 struct Cuda2DLaunchConfig {
   dim3 virtual_thread_count;
   dim3 thread_per_block;
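For reference, cudaOccupancyMaxPotentialBlockSize (declared in cuda_runtime.h) is the CUDA runtime's occupancy calculator: for a specific kernel it returns the block size that maximizes theoretical occupancy, plus the minimum grid size needed to reach it; the std::min above then shrinks the grid to ceil(work / block) so small workloads don't launch idle blocks. A standalone sketch (AddOne is hypothetical, not TensorFlow code):

#include <algorithm>
#include <cstdio>

__global__ void AddOne(float* data, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] += 1.0f;
  }
}

int main() {
  const int n = 1 << 20;
  float* data;
  cudaMalloc(&data, n * sizeof(float));

  // min_grid_size: smallest grid that can reach maximum occupancy;
  // block_size: the block size that achieves it for AddOne.
  int min_grid_size = 0, block_size = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, AddOne,
                                     /*dynamicSMemSize=*/0,
                                     /*blockSizeLimit=*/n);
  // Don't launch more blocks than there is work for.
  int grid_size = std::min(min_grid_size, (n + block_size - 1) / block_size);

  AddOne<<<grid_size, block_size>>>(data, n);
  cudaDeviceSynchronize();
  printf("grid=%d block=%d\n", grid_size, block_size);
  cudaFree(data);
  return 0;
}

Note that the new overload passes work_element_count as the final blockSizeLimit argument, which additionally caps the block size for very small workloads.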