diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 5377d09ec69..b16adf6102b 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -24,28 +24,32 @@ limitations under the License. #if !defined(_MSC_VER) #define UNROLL _Pragma("unroll") +#define NOUNROLL _Pragma("nounroll") #else #define UNROLL +#define NOUNROLL #endif namespace tensorflow { -namespace { - -typedef Eigen::GpuDevice GPUDevice; +using Eigen::GpuDevice; // A Cuda kernel to compute the depthwise convolution forward pass // in NHWC format. -template +template __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, const T* input, const T* filter, T* output, int num_outputs) { const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = args.depth_multiplier; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -114,16 +118,20 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. 
-template +template __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, const T* input, const T* filter, T* output, int num_outputs) { const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = args.depth_multiplier; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -235,29 +243,41 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, } } -} // namespace +template +void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, + const T* input, const T* filter, T* output, + TensorFormat data_format) { + const int num_outputs = + args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d); + if (data_format == FORMAT_NHWC) { + DepthwiseConv2dGPUKernelNHWC + <<>>( + args, input, filter, output, num_outputs); + } else if (data_format == FORMAT_NCHW) { + DepthwiseConv2dGPUKernelNCHW + <<>>( + args, input, filter, output, num_outputs); + } else { + assert(false); + } +} // A simple launch pad to launch the Cuda kernel for depthwise convolution. template struct DepthwiseConv2dGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input, + static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input, const T* filter, T* output, TensorFormat data_format) { - // In this kernel, each thread is computing the gradients from one element - // in the out_backprop. 
Note that one element in the out_backprop can map - // to multiple filter elements. - const int num_outputs = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d); - if (data_format == FORMAT_NHWC) { - DepthwiseConv2dGPUKernelNHWC - <<>>( - args, input, filter, output, num_outputs); - } else if (data_format == FORMAT_NCHW) { - DepthwiseConv2dGPUKernelNCHW - <<>>( - args, input, filter, output, num_outputs); + if (args.filter_rows == 3 && args.filter_cols == 3 && + args.depth_multiplier == 1) { + LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + data_format); } else { - assert(false); + LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + data_format); } } }; @@ -266,18 +286,20 @@ template struct DepthwiseConv2dGPULaunch; template struct DepthwiseConv2dGPULaunch; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input. -template +template __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC( const DepthwiseArgs args, const T* out_backprop, const T* filter, T* in_backprop, int num_in_backprop) { const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = KNOWN_DEPTH_MULTIPLIER == -1 - ? args.depth_multiplier - : KNOWN_DEPTH_MULTIPLIER; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? 
args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -301,14 +323,12 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC( tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); -#pragma nounroll - for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { + NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { const int f_r = in_r + pad_rows - out_r * stride; const int temp_out_backprop_offset = out_depth * out_cols * (out_r + out_rows * b); const int temp_filter_offset = filter_cols * f_r; -#pragma nounroll - for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { + NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { const int f_c = in_c + pad_cols - out_c * stride; int filter_offset = depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset)); @@ -328,7 +348,8 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC( } } -template +template __global__ void __launch_bounds__(1024) DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args, const T* out_backprop, @@ -337,9 +358,12 @@ __global__ void __launch_bounds__(1024) const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = args.depth_multiplier; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? 
args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -395,34 +419,52 @@ } } +template +void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, + const DepthwiseArgs args, + const T* out_backprop, + const T* filter, T* in_backprop, + TensorFormat data_format) { + const int num_in_backprop = + args.batch * args.in_rows * args.in_cols * args.in_depth; + CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d); + // Increase block count for when there are more warps/SM than threads/SM. + // TODO(csigg): this is pretty arbitrary and should be generalized using + // cudaOccupancyMaxPotentialBlockSize(). + config.block_count *= 4; + if (data_format == FORMAT_NHWC) { + DepthwiseConv2dBackpropInputGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> + <<>>( + args, out_backprop, filter, in_backprop, num_in_backprop); + } else if (data_format == FORMAT_NCHW) { + DepthwiseConv2dBackpropInputGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> + <<>>( + args, out_backprop, filter, in_backprop, num_in_backprop); + } else { + assert(false); + } +} + // A simple launch pad to launch the Cuda kernel for depthwise convolution. template struct DepthwiseConv2dBackpropInputGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, + static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { - const int num_in_backprop = - args.batch * args.in_rows * args.in_cols * args.in_depth; - - CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d); - // Increase block count for when there are more warps/SM than threads/SM.
- config.block_count *= 4; - if (data_format == FORMAT_NHWC) { - if (args.depth_multiplier == 1) { - DepthwiseConv2dBackpropInputGPUKernelNHWC - <<>>( - args, out_backprop, filter, in_backprop, num_in_backprop); + if (args.depth_multiplier == 1) { + if (args.filter_rows == 3 && args.filter_cols == 3) { + LaunchDepthwiseConv2dBackpropInputGPU( + d, args, out_backprop, filter, in_backprop, data_format); } else { - DepthwiseConv2dBackpropInputGPUKernelNHWC - <<>>( - args, out_backprop, filter, in_backprop, num_in_backprop); + LaunchDepthwiseConv2dBackpropInputGPU( + d, args, out_backprop, filter, in_backprop, data_format); } - } else if (data_format == FORMAT_NCHW) { - DepthwiseConv2dBackpropInputGPUKernelNCHW - <<>>( - args, out_backprop, filter, in_backprop, num_in_backprop); } else { - assert(false); + LaunchDepthwiseConv2dBackpropInputGPU( + d, args, out_backprop, filter, in_backprop, data_format); } } }; @@ -431,16 +473,20 @@ template struct DepthwiseConv2dBackpropInputGPULaunch; template struct DepthwiseConv2dBackpropInputGPULaunch; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. -template +template __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC( const DepthwiseArgs args, const T* out_backprop, const T* input, T* filter_backprop, int num_out_backprop) { const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = args.depth_multiplier; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? 
args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -518,16 +564,20 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC( } // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. -template +template __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW( const DepthwiseArgs args, const T* out_backprop, const T* input, T* filter_backprop, int num_out_backprop) { const int in_rows = args.in_rows; const int in_cols = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = args.filter_rows; - const int filter_cols = args.filter_cols; - const int depth_multiplier = args.depth_multiplier; + const int filter_rows = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_cols = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; const int pad_rows = args.pad_rows; const int pad_cols = args.pad_cols; @@ -610,28 +660,44 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW( } } +template +void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, + const DepthwiseArgs args, + const T* out_backprop, + const T* input, T* filter_backprop, + TensorFormat data_format) { + const int num_out_backprop = + args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d); + if (data_format == FORMAT_NHWC) { + DepthwiseConv2dBackpropFilterGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> + <<>>( + args, out_backprop, input, filter_backprop, num_out_backprop); + } else if (data_format == FORMAT_NCHW) { + DepthwiseConv2dBackpropFilterGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> + <<>>( + args, out_backprop, 
input, filter_backprop, num_out_backprop); + } else { + assert(false); + } +} + // A simple launch pad to launch the Cuda kernel for depthwise convolution. template struct DepthwiseConv2dBackpropFilterGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, + static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - // In this kernel, each thread is computing the gradients for one element in - // the out_backprop. - const int num_out_backprop = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d); - - if (data_format == FORMAT_NHWC) { - DepthwiseConv2dBackpropFilterGPUKernelNHWC - <<>>( - args, out_backprop, input, filter_backprop, num_out_backprop); - } else if (data_format == FORMAT_NCHW) { - DepthwiseConv2dBackpropFilterGPUKernelNCHW - <<>>( - args, out_backprop, input, filter_backprop, num_out_backprop); + if (args.filter_rows == 3 && args.filter_cols == 3 && + args.depth_multiplier == 1) { + LaunchDepthwiseConv2dBackpropFilterGPU( + d, args, out_backprop, input, filter_backprop, data_format); } else { - assert(false); + LaunchDepthwiseConv2dBackpropFilterGPU( + d, args, out_backprop, input, filter_backprop, data_format); } } }; diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index a881ed0dc9a..2fc34bd4d17 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -113,10 +113,9 @@ class DepthwiseConv2DTest(test.TestCase): total_size_1 *= s for s in filter_in_sizes: total_size_2 *= s - # Initializes the input tensor with array containing incrementing - # numbers from 1. + # Initializes the input and filter tensor with numbers incrementing from 1. 
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] - x2 = [1.0 for f in range(1, total_size_2 + 1)] + x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] with self.test_session(use_gpu=use_gpu) as sess: t1 = constant_op.constant(x1, shape=tensor_in_sizes) t1.set_shape(tensor_in_sizes) @@ -147,8 +146,9 @@ class DepthwiseConv2DTest(test.TestCase): native_result = sess.run(conv_native) interface_result = sess.run(conv_interface) - print("diff matrix:", - np.amax(np.ravel(native_result) - np.ravel(interface_result))) + print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes, + ", stride:", stride, ", padding: ", padding, ", max diff: ", + np.amax(np.absolute(native_result - interface_result))) self.assertArrayNear( np.ravel(native_result), np.ravel(interface_result), 1e-5) self.assertShapeEqual(native_result, conv_native)