From 3c02d1100788789b04e04feb93761f0ad898ea77 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sun, 7 May 2017 22:47:19 -0800
Subject: [PATCH] In the CUDA path of depthwise_conv2d, add __launch_bounds__
 and use CUDA runtime helpers to determine LaunchConfig which maximizes
 occupancy.

Change: 155347912
---
 .../core/kernels/depthwise_conv_op_gpu.cu.cc  | 93 +++++++++++++------
 tensorflow/core/util/cuda_kernel_helper.h     | 22 +++++
 2 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index b16adf6102b..051d4772449 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -38,9 +38,9 @@ using Eigen::GpuDevice;
 // in NHWC format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
-                                             const T* input, const T* filter,
-                                             T* output, int num_outputs) {
+__global__ void __launch_bounds__(1024, 2)
+    DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input,
+                                 const T* filter, T* output, int num_outputs) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -120,9 +120,9 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
 // in NCHW format.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
-                                             const T* input, const T* filter,
-                                             T* output, int num_outputs) {
+__global__ void __launch_bounds__(1024, 2)
+    DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input,
+                                 const T* filter, T* output, int num_outputs) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -250,17 +250,34 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
                               TensorFormat data_format) {
   const int num_outputs =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
+  // The compile-time constant version runs faster with a single block.
+  const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 ||
+                                      kKnownDepthMultiplier < 0 ||
+                                      args.out_rows * args.out_cols <= 256
+                                  ? std::numeric_limits<int>::max()
+                                  : d.getNumCudaMultiProcessors();
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_outputs, d,
+        DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
+                                     kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
                                  kKnownDepthMultiplier>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            args, input, filter, output, num_outputs);
+        <<<std::min(max_block_count, config.block_count),
+           config.thread_per_block, 0, d.stream()>>>(args, input, filter,
+                                                     output, num_outputs);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_outputs, d,
+        DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
+                                     kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
                                  kKnownDepthMultiplier>
-        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
-            args, input, filter, output, num_outputs);
+        <<<std::min(max_block_count, config.block_count),
+           config.thread_per_block, 0, d.stream()>>>(args, input, filter,
+                                                     output, num_outputs);
   } else {
     assert(false);
   }
@@ -288,9 +305,11 @@ template struct DepthwiseConv2dGPULaunch;
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
-    const DepthwiseArgs args, const T* out_backprop, const T* filter,
-    T* in_backprop, int num_in_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropInputGPUKernelNHWC(const DepthwiseArgs args,
+                                              const T* out_backprop,
+                                              const T* filter, T* in_backprop,
+                                              int num_in_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -350,7 +369,7 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
 
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void __launch_bounds__(1024)
+__global__ void __launch_bounds__(640, 2)
     DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
                                               const T* out_backprop,
                                               const T* filter, T* in_backprop,
@@ -428,17 +447,22 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
                                            TensorFormat data_format) {
   const int num_in_backprop =
       args.batch * args.in_rows * args.in_cols * args.in_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
-  // Increase block count for when there are more warps/SM than threads/SM.
-  // TODO(csigg): this is pretty arbitraty and should be generalized using
-  // cudaOccupancyMaxPotentialBlockSize().
-  config.block_count *= 4;
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_in_backprop, d,
+        DepthwiseConv2dBackpropInputGPUKernelNHWC<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropInputGPUKernelNHWC<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, filter, in_backprop, num_in_backprop);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_in_backprop, d,
+        DepthwiseConv2dBackpropInputGPUKernelNCHW<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropInputGPUKernelNCHW<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, filter, in_backprop, num_in_backprop);
@@ -475,9 +499,12 @@ template struct DepthwiseConv2dBackpropInputGPULaunch;
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
-    const DepthwiseArgs args, const T* out_backprop, const T* input,
-    T* filter_backprop, int num_out_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropFilterGPUKernelNHWC(const DepthwiseArgs args,
+                                               const T* out_backprop,
+                                               const T* input,
+                                               T* filter_backprop,
+                                               int num_out_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -566,9 +593,12 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
 // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
 template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
           int kKnownDepthMultiplier>
-__global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
-    const DepthwiseArgs args, const T* out_backprop, const T* input,
-    T* filter_backprop, int num_out_backprop) {
+__global__ void __launch_bounds__(640, 2)
+    DepthwiseConv2dBackpropFilterGPUKernelNCHW(const DepthwiseArgs args,
+                                               const T* out_backprop,
+                                               const T* input,
+                                               T* filter_backprop,
+                                               int num_out_backprop) {
   const int in_rows = args.in_rows;
   const int in_cols = args.in_cols;
   const int in_depth = args.in_depth;
@@ -669,13 +699,22 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
                                             TensorFormat data_format) {
   const int num_out_backprop =
       args.batch * args.out_rows * args.out_cols * args.out_depth;
-  CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
   if (data_format == FORMAT_NHWC) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNHWC<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropFilterGPUKernelNHWC<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             args, out_backprop, input, filter_backprop, num_out_backprop);
   } else if (data_format == FORMAT_NCHW) {
+    CudaLaunchConfig config = GetCudaLaunchConfig(
+        num_out_backprop, d,
+        DepthwiseConv2dBackpropFilterGPUKernelNCHW<
+            T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>,
+        0);
     DepthwiseConv2dBackpropFilterGPUKernelNCHW<
         T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 8a3f6c587ed..46ea68687c7 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -63,6 +63,28 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
   return config;
 }
 
+// Calculate the Cuda launch config we should use for a kernel launch. This
+// variant takes the resource limits of func into account to maximize occupancy.
+template <typename DeviceFunc>
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+                                            const GPUDevice& d, DeviceFunc func,
+                                            size_t dynamic_shared_memory_size) {
+  int block_count = 0;
+  int thread_per_block = 0;
+  cudaOccupancyMaxPotentialBlockSize(&block_count, &thread_per_block, func,
+                                     dynamic_shared_memory_size,
+                                     work_element_count);
+  block_count =
+      std::min(block_count,
+               (work_element_count + thread_per_block - 1) / thread_per_block);
+
+  CudaLaunchConfig config;
+  config.virtual_thread_count = work_element_count;
+  config.thread_per_block = thread_per_block;
+  config.block_count = block_count;
+  return config;
+}
+
 struct Cuda2DLaunchConfig {
   dim3 virtual_thread_count;
   dim3 thread_per_block;
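
A note on the technique (illustration, not part of the patch): the change
replaces the hand-tuned "config.block_count *= 4" heuristic with
cudaOccupancyMaxPotentialBlockSize, which asks the CUDA runtime for the block
size that maximizes theoretical occupancy given the kernel's register and
shared-memory usage, and adds __launch_bounds__(max_threads_per_block,
min_blocks_per_multiprocessor) so the compiler bounds register allocation
accordingly. The standalone sketch below shows the same pattern outside
TensorFlow; the kernel AxpyKernel, its grid-stride loop, and main() are
illustrative assumptions, not code from this patch.

// occupancy_demo.cu -- minimal sketch of the launch-config pattern above.
// Build: nvcc -o occupancy_demo occupancy_demo.cu
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Compile for at most 1024 threads per block, and keep register usage low
// enough that at least 2 such blocks can be resident per multiprocessor --
// the same contract the depthwise forward kernels declare above.
__global__ void __launch_bounds__(1024, 2)
    AxpyKernel(float a, const float* x, float* y, int n) {
  // Grid-stride loop: correct for whatever block/grid size the occupancy
  // calculator picks at runtime.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    y[i] = a * x[i] + y[i];
  }
}

int main() {
  const int n = 1 << 20;
  float *x = nullptr, *y = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  for (int i = 0; i < n; ++i) { x[i] = 1.0f; y[i] = 2.0f; }

  // Ask the runtime for the block size that maximizes occupancy for this
  // kernel; min_grid_size is the smallest grid that can fully occupy the GPU.
  int min_grid_size = 0, block_size = 0;
  cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, AxpyKernel,
                                     /*dynamicSMemSize=*/0,
                                     /*blockSizeLimit=*/0);
  // As in the new GetCudaLaunchConfig overload, never launch more blocks
  // than are needed to cover the work.
  const int block_count =
      std::min(min_grid_size, (n + block_size - 1) / block_size);

  AxpyKernel<<<block_count, block_size>>>(2.0f, x, y, n);
  cudaDeviceSynchronize();
  printf("block_size=%d block_count=%d y[0]=%.1f\n", block_size, block_count,
         y[0]);  // expect y[0] == 4.0
  cudaFree(x);
  cudaFree(y);
  return 0;
}

Passing the kernel's own function pointer is what lets the runtime account for
that kernel's actual resource usage; the GetCudaLaunchConfig(work_element_count,
d) variant being replaced had no such information, which is why the old code
fell back on the arbitrary "config.block_count *= 4" fudge factor flagged in
the removed TODO above.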