In the CUDA path of depthwise_conv2d, use compile-time constants if the filter size and depth_multiplier match the xception model.

Hardening the depthwise_conv2d forward test by using non-uniform filter values.
Change: 154985456
This commit is contained in:
A. Unique TensorFlower 2017-05-03 11:04:22 -08:00 committed by TensorFlower Gardener
parent 87ffdd2d50
commit 7828637e07
2 changed files with 157 additions and 91 deletions

View File

@ -24,28 +24,32 @@ limitations under the License.
#if !defined(_MSC_VER)
#define UNROLL _Pragma("unroll")
#define NOUNROLL _Pragma("nounroll")
#else
#define UNROLL
#define NOUNROLL
#endif
namespace tensorflow {
namespace {
typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;
// A Cuda kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
const T* input, const T* filter,
T* output, int num_outputs) {
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = args.depth_multiplier;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -114,16 +118,20 @@ __global__ void DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args,
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
const T* input, const T* filter,
T* output, int num_outputs) {
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = args.depth_multiplier;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -235,29 +243,41 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args,
}
}
} // namespace
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dGPUKernelNHWC<T, kKnownFilterWidth, kKnownFilterHeight,
kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dGPUKernelNCHW<T, kKnownFilterWidth, kKnownFilterHeight,
kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input,
static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
const T* filter, T* output, TensorFormat data_format) {
// In this kernel, each thread is computing the gradients from one element
// in the out_backprop. Note that one element in the out_backprop can map
// to multiple filter elements.
const int num_outputs =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dGPUKernelNHWC<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, input, filter, output, num_outputs);
if (args.filter_rows == 3 && args.filter_cols == 3 &&
args.depth_multiplier == 1) {
LaunchDepthwiseConv2dGPU<T, 3, 3, 1>(d, args, input, filter, output,
data_format);
} else {
assert(false);
LaunchDepthwiseConv2dGPU<T, -1, -1, -1>(d, args, input, filter, output,
data_format);
}
}
};
@ -266,18 +286,20 @@ template struct DepthwiseConv2dGPULaunch<float>;
template struct DepthwiseConv2dGPULaunch<double>;
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int KNOWN_DEPTH_MULTIPLIER>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
const DepthwiseArgs args, const T* out_backprop, const T* filter,
T* in_backprop, int num_in_backprop) {
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = KNOWN_DEPTH_MULTIPLIER == -1
? args.depth_multiplier
: KNOWN_DEPTH_MULTIPLIER;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -301,14 +323,12 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride);
const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride);
#pragma nounroll
for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
const int f_r = in_r + pad_rows - out_r * stride;
const int temp_out_backprop_offset =
out_depth * out_cols * (out_r + out_rows * b);
const int temp_filter_offset = filter_cols * f_r;
#pragma nounroll
for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
const int f_c = in_c + pad_cols - out_c * stride;
int filter_offset =
depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset));
@ -328,7 +348,8 @@ __global__ void DepthwiseConv2dBackpropInputGPUKernelNHWC(
}
}
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void __launch_bounds__(1024)
DepthwiseConv2dBackpropInputGPUKernelNCHW(const DepthwiseArgs args,
const T* out_backprop,
@ -337,9 +358,12 @@ __global__ void __launch_bounds__(1024)
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = args.depth_multiplier;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -395,34 +419,52 @@ __global__ void __launch_bounds__(1024)
}
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* filter, T* in_backprop,
TensorFormat data_format) {
const int num_in_backprop =
args.batch * args.in_rows * args.in_cols * args.in_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
// Increase block count for when there are more warps/SM than threads/SM.
// TODO(csigg): this is pretty arbitraty and should be generalized using
// cudaOccupancyMaxPotentialBlockSize().
config.block_count *= 4;
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropInputGPUKernelNHWC<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropInputGPUKernelNCHW<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args,
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format) {
const int num_in_backprop =
args.batch * args.in_rows * args.in_cols * args.in_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_in_backprop, d);
// Increase block count for when there are more warps/SM than threads/SM.
config.block_count *= 4;
if (data_format == FORMAT_NHWC) {
if (args.depth_multiplier == 1) {
DepthwiseConv2dBackpropInputGPUKernelNHWC<T, 1>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
if (args.depth_multiplier == 1) {
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3, 1>(
d, args, out_backprop, filter, in_backprop, data_format);
} else {
DepthwiseConv2dBackpropInputGPUKernelNHWC<T, -1>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, 1>(
d, args, out_backprop, filter, in_backprop, data_format);
}
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropInputGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, filter, in_backprop, num_in_backprop);
} else {
assert(false);
LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1, -1>(
d, args, out_backprop, filter, in_backprop, data_format);
}
}
};
@ -431,16 +473,20 @@ template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
const DepthwiseArgs args, const T* out_backprop, const T* input,
T* filter_backprop, int num_out_backprop) {
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = args.depth_multiplier;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -518,16 +564,20 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNHWC(
}
// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T>
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
__global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
const DepthwiseArgs args, const T* out_backprop, const T* input,
T* filter_backprop, int num_out_backprop) {
const int in_rows = args.in_rows;
const int in_cols = args.in_cols;
const int in_depth = args.in_depth;
const int filter_rows = args.filter_rows;
const int filter_cols = args.filter_cols;
const int depth_multiplier = args.depth_multiplier;
const int filter_rows =
kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight;
const int filter_cols =
kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth;
const int depth_multiplier =
kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier;
const int stride = args.stride;
const int pad_rows = args.pad_rows;
const int pad_cols = args.pad_cols;
@ -610,28 +660,44 @@ __global__ void DepthwiseConv2dBackpropFilterGPUKernelNCHW(
}
}
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
int kKnownDepthMultiplier>
void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,
const DepthwiseArgs args,
const T* out_backprop,
const T* input, T* filter_backprop,
TensorFormat data_format) {
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropFilterGPUKernelNHWC<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropFilterGPUKernelNCHW<
T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else {
assert(false);
}
}
// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args,
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
// In this kernel, each thread is computing the gradients for one element in
// the out_backprop.
const int num_out_backprop =
args.batch * args.out_rows * args.out_cols * args.out_depth;
CudaLaunchConfig config = GetCudaLaunchConfig(num_out_backprop, d);
if (data_format == FORMAT_NHWC) {
DepthwiseConv2dBackpropFilterGPUKernelNHWC<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
} else if (data_format == FORMAT_NCHW) {
DepthwiseConv2dBackpropFilterGPUKernelNCHW<T>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
args, out_backprop, input, filter_backprop, num_out_backprop);
if (args.filter_rows == 3 && args.filter_cols == 3 &&
args.depth_multiplier == 1) {
LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3, 1>(
d, args, out_backprop, input, filter_backprop, data_format);
} else {
assert(false);
LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1, -1>(
d, args, out_backprop, input, filter_backprop, data_format);
}
}
};

View File

@ -113,10 +113,9 @@ class DepthwiseConv2DTest(test.TestCase):
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
# Initializes the input and filter tensor with numbers incrementing from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [1.0 for f in range(1, total_size_2 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1.set_shape(tensor_in_sizes)
@ -147,8 +146,9 @@ class DepthwiseConv2DTest(test.TestCase):
native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface)
print("diff matrix:",
np.amax(np.ravel(native_result) - np.ravel(interface_result)))
print("depthwise conv_2d: ", tensor_in_sizes, "*", filter_in_sizes,
", stride:", stride, ", padding: ", padding, ", max diff: ",
np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), 1e-5)
self.assertShapeEqual(native_result, conv_native)