Adding a depthwise convolution kernel op (with label 'cudnn_grouped_convolution') which forwards to cuDNN grouped convolutions.

PiperOrigin-RevId: 194780352
This commit is contained in:
A. Unique TensorFlower 2018-04-30 06:59:23 -07:00 committed by TensorFlower Gardener
parent 914796d5e9
commit a5a51ad3a1
11 changed files with 632 additions and 233 deletions

View File

@ -3299,7 +3299,10 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
] + if_cuda(["@cub_archive//:cub"]),
] + if_cuda([
"@cub_archive//:cub",
"@local_config_cuda//cuda:cudnn",
]),
)
tf_kernel_library(
@ -3310,12 +3313,15 @@ tf_kernel_library(
prefix = "depthwise_conv_grad_op",
deps = [
":bounds_check",
":conv_ops",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
],
] + if_cuda([
"@local_config_cuda//cuda:cudnn",
]),
)
cc_library(

View File

@ -96,7 +96,8 @@ template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
int row_stride, int col_stride, const Padding& padding,
int row_dilation, int col_dilation, int row_stride,
int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
@ -275,7 +276,8 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
#endif
LaunchConv2DBackpropFilterOp<Device, T>()(
context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
context, false, false, out_backprop, input,
/*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
}
@ -523,6 +525,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
@ -690,10 +697,15 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
return;
}
// If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
// input depth, it's a depthwise convolution. More generally, if the filter
// in-depth divides but is smaller than the input depth, it is a grouped
// convolution.
bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
if (!cudnn_disable_conv_1x1_optimization_ &&
dims.spatial_dims[0].filter_size == 1 &&
dims.spatial_dims[1].filter_size == 1 &&
dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
const uint64 m = dims.in_depth;
@ -734,9 +746,10 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
padding == VALID && data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
!is_grouped_convolution && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, and we are not
// using grouped convolution, so call cublas directly.
const uint64 m = dims.spatial_dims[0].input_size *
dims.spatial_dims[1].input_size * dims.in_depth;
const uint64 k = dims.batch_size;
@ -802,15 +815,16 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
.set_input_feature_map_count(filter_shape.dim_size(2))
.set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
.set_zero_padding_width(padding_cols / 2)
.set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(zhengxq):
// cuDNN only supports the following layouts :
@ -891,21 +905,22 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
ConvParameters conv_parameters = {
dims.batch_size, // batch
dims.in_depth, // in_depths
{{input_desc.height(), // in_rows
input_desc.width()}}, // in_cols
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
{{dims.spatial_dims[0].dilation, // dilation_rows
dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
padding_cols}}, // padding_cols
dtype, // tensor datatype
device_id, // device_id
dims.batch_size, // batch
dims.in_depth, // in_depths
{{input_desc.height(), // in_rows
input_desc.width()}}, // in_cols
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size, // filter_cols
filter_shape.dim_size(2)}}, // filter_depth
{{dims.spatial_dims[0].dilation, // dilation_rows
dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
padding_cols}}, // padding_cols
dtype, // tensor datatype
device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
@ -1019,9 +1034,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@ -1040,6 +1055,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
.TypeConstraint<Eigen::half>("T")
.HostMemory("filter_sizes"),
Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -101,8 +101,9 @@ template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
int row_stride, int col_stride, const Padding& padding,
Tensor* in_backprop, TensorFormat data_format) {
int row_dilation, int col_dilation, int row_stride,
int col_stride, const Padding& padding, Tensor* in_backprop,
TensorFormat data_format) {
const CPUDevice& d = ctx->eigen_device<CPUDevice>();
functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
@ -280,8 +281,8 @@ class Conv2DFastBackpropInputOp : public OpKernel {
LaunchConv2DBackpropInputOp<Device, T>()(
context, false, false, out_backprop, filter,
dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
in_backprop, data_format_);
/*row_dilation=*/1, /*col_dilation=*/1, dims.spatial_dims[0].stride,
dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
}
private:
@ -595,6 +596,11 @@ TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)
@ -761,8 +767,13 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
return;
}
// If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than the
// input depth, it's a depthwise convolution. More generally, if the filter
// in-depth divides but is smaller than the input depth, it is a grouped
// convolution.
bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
if (dims.spatial_dims[0].filter_size == 1 &&
dims.spatial_dims[1].filter_size == 1 &&
dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
@ -795,9 +806,10 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
dims.spatial_dims[0].input_size &&
dims.spatial_dims[1].filter_size ==
dims.spatial_dims[1].input_size &&
padding == VALID && data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
!is_grouped_convolution && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, and we are not
// using grouped convolution, so call cublas directly.
const uint64 m = dims.batch_size;
const uint64 k = dims.out_depth;
const uint64 n = dims.spatial_dims[0].input_size *
@ -856,15 +868,16 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
.set_input_feature_map_count(filter_shape.dim_size(2))
.set_output_feature_map_count(filter_shape.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
.set_zero_padding_width(padding_cols / 2)
.set_group_count(dims.in_depth / filter_shape.dim_size(2));
// NOTE(keveman):
// cuDNN only supports the following layouts :
@ -940,21 +953,22 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
int device_id = stream->parent()->device_ordinal();
DataType dtype = out_backprop.dtype();
ConvParameters conv_parameters = {
dims.batch_size, // batch
dims.in_depth, // in_depths
{{input_desc.height(), // in_rows
input_desc.width()}}, // in_cols
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
{{dims.spatial_dims[0].dilation, // dilation_rows
dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
padding_cols}}, // padding_cols
dtype, // tensor data type
device_id, // device_id
dims.batch_size, // batch
dims.in_depth, // in_depths
{{input_desc.height(), // in_rows
input_desc.width()}}, // in_cols
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size, // filter_cols
filter_shape.dim_size(2)}}, // filter_depths
{{dims.spatial_dims[0].dilation, // dilation_rows
dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
padding_cols}}, // padding_cols
dtype, // tensor data type
device_id, // device_id
};
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
@ -1092,9 +1106,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@ -1113,6 +1127,12 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
.TypeConstraint<Eigen::half>("T")
.HostMemory("input_sizes"),
Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -127,16 +127,17 @@ Status ConvBackpropComputeDimensionsV2(
dims->in_depth = input_shape.dim_size(feature_dim);
// The input and output feature dimensions are the second last and last
// dimensions of the filter Tensor.
if (dims->in_depth != filter_shape.dim_size(num_dims - 2)) {
VLOG(2) << "input vs filter_in depth " << dims->in_depth << " "
<< filter_shape.dim_size(num_dims - 2);
if (dims->in_depth % filter_shape.dim_size(num_dims - 2)) {
return errors::InvalidArgument(
label, ": input and filter must have the same depth");
label, ": input depth must be evenly divisible by filter depth");
}
dims->out_depth = filter_shape.dim_size(num_dims - 1);
if (dims->out_depth != out_backprop_shape.dim_size(feature_dim)) {
return errors::InvalidArgument(
label, ": filter and out_backprop must have the same out_depth");
}
dims->spatial_dims.resize(num_spatial_dims);
for (int i = 0; i < num_spatial_dims; ++i) {
int image_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);

View File

@ -18,10 +18,16 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops.h"
#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@ -32,9 +38,6 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/deep_conv2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
@ -45,6 +48,10 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
@ -123,6 +130,10 @@ struct LaunchConv2DOp<CPUDevice, T> {
"NHWC tensor format for now."));
return;
}
const int64 in_depth = GetTensorDim(input, data_format, 'C');
OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
errors::Unimplemented("Generic conv implementation does not "
"support grouped convolutions for now."));
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
row_dilation, col_dilation, padding, output,
data_format);
@ -324,12 +335,13 @@ class Conv2DOp : public BinaryOp<T> {
}
// The last dimension for input is in_depth. It must be the same as the
// filter's in_depth.
// filter's in_depth or be evenly divisible by filter's in_depth.
const int64 in_depth = GetTensorDim(input, data_format_, 'C');
OP_REQUIRES(context, in_depth == filter.dim_size(2),
const int64 patch_depth = filter.dim_size(2);
OP_REQUIRES(context, in_depth % patch_depth == 0,
errors::InvalidArgument(
"input and filter must have the same depth: ", in_depth,
" vs ", filter.dim_size(2)));
"input depth must be evenly divisible by filter depth: ",
in_depth, " vs ", patch_depth));
// The last dimension for filter is out_depth.
const int out_depth = static_cast<int>(filter.dim_size(3));
@ -386,6 +398,7 @@ class Conv2DOp : public BinaryOp<T> {
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
VLOG(2) << "Conv2D: in_depth = " << in_depth
<< ", patch_depth = " << patch_depth
<< ", input_cols = " << input_cols
<< ", filter_cols = " << filter_cols
<< ", input_rows = " << input_rows
@ -450,7 +463,9 @@ TF_CALL_double(REGISTER_CPU);
#endif // USE_GEMM_FOR_CONV
// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
@ -498,13 +513,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
}
Tensor input = input_param;
const int64 in_batch = GetTensorDim(input, data_format, 'N');
int64 in_rows = GetTensorDim(input, data_format, 'H');
int64 in_cols = GetTensorDim(input, data_format, 'W');
const int64 in_depths = GetTensorDim(input, data_format, 'C');
const int64 patch_rows = filter.dim_size(0);
const int64 patch_cols = filter.dim_size(1);
const int64 patch_depths = filter.dim_size(2);
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
data_format == FORMAT_NHWC) {
// If the filter in-depth (patch_depths) is 1 and smaller than the input
// depth, it's a depthwise convolution. More generally, if the filter in-depth
// divides but is smaller than the input depth, it is a grouped convolution.
bool is_grouped_convolution = patch_depths != in_depths;
if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
col_stride == 1 && data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
const uint64 k = filter.dim_size(2);
const uint64 m = in_batch * in_rows * in_cols;
const uint64 k = patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@ -525,15 +551,14 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
", n=", n, ", k=", k));
}
return;
} else if (filter.dim_size(0) == input.dim_size(1) &&
filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
} else if (patch_rows == in_rows && patch_cols == in_cols &&
!is_grouped_convolution && row_dilation == 1 &&
col_dilation == 1 && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
const uint64 m = input.dim_size(0);
const uint64 k =
filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
const uint64 m = in_batch;
const uint64 k = patch_rows * patch_cols * patch_depths;
const uint64 n = filter.dim_size(3);
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
@ -558,16 +583,10 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
int padding_rows = 0;
int padding_cols = 0;
const int64 in_batch = GetTensorDim(input, data_format, 'N');
int64 in_rows = GetTensorDim(input, data_format, 'H');
int64 in_cols = GetTensorDim(input, data_format, 'W');
const int64 in_depths = GetTensorDim(input, data_format, 'C');
const int64 out_batch = GetTensorDim(*output, data_format, 'N');
const int64 out_rows = GetTensorDim(*output, data_format, 'H');
const int64 out_cols = GetTensorDim(*output, data_format, 'W');
const int64 out_depths = GetTensorDim(*output, data_format, 'C');
const int64 patch_rows = filter.dim_size(0);
const int64 patch_cols = filter.dim_size(1);
if (padding == SAME) {
// Total padding on rows and cols is
// Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
@ -642,9 +661,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_feature_map_count(out_depths)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(filter.dim_size(0))
.set_input_filter_width(filter.dim_size(1))
.set_input_feature_map_count(filter.dim_size(2))
filter_desc.set_input_filter_height(patch_rows)
.set_input_filter_width(patch_cols)
.set_input_feature_map_count(patch_depths)
.set_output_feature_map_count(filter.dim_size(3));
se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(row_dilation)
@ -652,7 +671,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
.set_zero_padding_width(padding_cols / 2)
.set_group_count(in_depths / patch_depths);
Tensor transformed_filter;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@ -695,7 +715,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
in_cols}}, // in_cols
out_depths, // out_depths
{{patch_rows, // filter_rows
patch_cols}}, // filter_cols
patch_cols, // filter_cols
patch_depths}}, // filter_depths
{{row_dilation, // dilation_rows
col_dilation}}, // dilation_cols
{{row_stride, // stride_rows
@ -812,9 +833,9 @@ namespace functor {
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
} // namespace functor
@ -830,7 +851,9 @@ REGISTER_KERNEL_BUILDER(
Conv2DOp<GPUDevice, double>);
// To be used inside depthwise_conv_op.cc.
template class LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DOp<GPUDevice, double>;
#endif // GOOGLE_CUDA

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
@ -33,9 +34,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@ -509,8 +512,19 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
}
}
// Extern template instantiated in conv_grad_input_ops.cc.
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;
#if GOOGLE_CUDA
// Extern template instantiated in conv_grad_input_ops.cc.
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;
// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
@ -548,6 +562,12 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
use_cudnn_grouped_conv_ = false;
dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
@ -560,6 +580,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
input_sizes.dims()));
TensorShape input_shape;
const int32* in_sizes_data = input_sizes.template flat<int32>().data();
for (int i = 0; i < input_sizes.NumElements(); ++i) {
OP_REQUIRES(context, in_sizes_data[i] >= 0,
errors::InvalidArgument("Dimension ", i,
@ -568,27 +589,77 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
}
const TensorShape& filter_shape = filter.shape();
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
Tensor* in_backprop = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, input_shape, &in_backprop));
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto filter_ptr = filter.template flat<T>().data();
auto in_backprop_ptr = in_backprop->template flat<T>().data();
// If there is nothing to compute, return.
if (input_shape.num_elements() == 0) {
return;
}
// If in_depth==1, this operation is just a standard convolution.
// Depthwise convolution is a special case of cuDNN's grouped convolution.
bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
<< " Input: [" << batch << ", " << input_rows << ", " << input_cols
<< ", " << in_depth << "]; Filter: [" << filter_rows << ", "
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
// Reshape from TF depthwise filter to cuDNN grouped convolution filter:
//
// | TensorFlow | cuDNN
// --------------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier * group_count
// filter_in_depth | in_depth | in_depth / group_count
//
// For depthwise convolution, we have group_count == in_depth.
int32 filter_in_depth = 1;
TensorShape shape =
TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
Tensor reshaped_filter(/*type=*/dtype_);
OP_REQUIRES(
context, reshaped_filter.CopyFrom(filter, shape),
errors::Internal(
"Failed to reshape filter tensor for grouped convolution."));
// TODO(yangzihao): Send in arbitrary dilation rates after the dilated
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
stride_, stride_, padding_, in_backprop, data_format_);
return;
}
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto filter_ptr = filter.template flat<T>().data();
auto in_backprop_ptr = in_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropInputOp<Device, T>()(
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
data_format_);
}
protected:
bool use_cudnn_grouped_conv_;
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
// For in_depth == 1 and grouped convolutions.
LaunchConv2DBackpropInputOp<Device, T> launcher_;
bool use_cudnn_;
bool cudnn_use_autotune_;
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
};
@ -597,23 +668,52 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.HostMemory("input_sizes"),
DepthwiseConv2dNativeBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("DepthwiseConv2dNativeBackpropInput")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T")
.HostMemory("input_sizes"),
DepthwiseConv2dNativeBackpropInputOp<GPUDevice, double>);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("input_sizes"), \
DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvBackpropInputOp
: public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
public:
DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
: DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
this->use_cudnn_grouped_conv_ = true;
}
};
#define REGISTER_GROUPED_CONV_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("input_sizes") \
.Label("cudnn_grouped_convolution"), \
DepthwiseConv2dGroupedConvBackpropInputOp<T>)
TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#undef REGISTER_GROUPED_CONV_KERNEL
#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
// Kernels to compute the gradients of the filters for depthwise convolution.
@ -885,8 +985,19 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
}
}
// Extern template instantiated in conv_grad_filter_ops.cc.
extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
#if GOOGLE_CUDA
// Extern template instantiated in conv_grad_filter_ops.cc.
extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
@ -924,6 +1035,21 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
use_cudnn_grouped_conv_ = false;
if (std::is_same<T, Eigen::half>::value) {
dtype_ = DT_HALF;
} else if (std::is_same<T, float>::value) {
dtype_ = DT_FLOAT;
} else if (std::is_same<T, double>::value) {
dtype_ = DT_DOUBLE;
} else {
LOG(ERROR) << "Only half, float, and double are supported.";
}
}
void Compute(OpKernelContext* context) override {
@ -949,24 +1075,73 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{1}, 0, filter_shape, &filter_backprop));
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto input_ptr = input.template flat<T>().data();
auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
// If there is nothing to compute, return.
if (filter_shape.num_elements() == 0) {
return;
}
// If in_depth==1, this operation is just a standard convolution.
// Depthwise convolution is a special case of cuDNN's grouped convolution.
bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
<< " Input: [" << batch << ", " << input_rows << ", " << input_cols
<< ", " << in_depth << "]; Filter: [" << filter_rows << ", "
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
// Reshape from TF depthwise filter to cuDNN grouped convolution filter:
//
// | TensorFlow | cuDNN
// --------------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier * group_count
// filter_in_depth | in_depth | in_depth / group_count
//
// For depthwise convolution, we have group_count == in_depth.
int32 filter_in_depth = 1;
TensorShape shape =
TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
Tensor reshaped_filter(/*type=*/dtype_);
OP_REQUIRES(
context, reshaped_filter.CopyFrom(*filter_backprop, shape),
errors::Internal(
"Failed to reshape filter tensor for grouped convolution."));
// TODO(yangzihao): Send in arbitrary dilation rates after the dilated
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
/*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
padding_, &reshaped_filter, data_format_);
return;
}
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto input_ptr = input.template flat<T>().data();
auto filter_backprop_ptr = filter_backprop->template flat<T>().data();
LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
data_format_);
}
protected:
bool use_cudnn_grouped_conv_;
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
int64 stride_;
// For in_depth == 1 and grouped convolutions.
LaunchConv2DBackpropFilterOp<Device, T> launcher_;
bool use_cudnn_;
bool cudnn_use_autotune_;
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
};
@ -976,24 +1151,50 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif
#undef REGISTER_CPU_KERNEL
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("DepthwiseConv2dNativeBackpropFilter")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.HostMemory("filter_sizes"),
DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, float>);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("filter_sizes"), \
DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)
REGISTER_KERNEL_BUILDER(
Name("DepthwiseConv2dNativeBackpropFilter")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T")
.HostMemory("filter_sizes"),
DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, double>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvBackpropFilterOp
: public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
public:
DepthwiseConv2dGroupedConvBackpropFilterOp(OpKernelConstruction* context)
: DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
this->use_cudnn_grouped_conv_ = true;
}
};
#define REGISTER_GROUPED_CONV_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.HostMemory("filter_sizes") \
.Label("cudnn_grouped_convolution"), \
DepthwiseConv2dGroupedConvBackpropFilterOp<T>)
TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#undef REGISTER_GROUPED_CONV_KERNEL
#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "cuda/include/cudnn.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@ -241,18 +242,22 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
};
// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
extern template struct LaunchConv2DOp<CPUDevice, double>;
#if GOOGLE_CUDA
// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<GPUDevice, float>;
extern template struct LaunchConv2DOp<GPUDevice, double>;
// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, float>;
#endif
template <typename Device, typename T>
@ -284,9 +289,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
// For special case when in_depth == 1.
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
use_cudnn_grouped_conv_ = false;
dtype_ = DataTypeToEnum<T>::value;
}
void Compute(OpKernelContext* context) override {
@ -357,27 +364,47 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
VLOG(2) << "DepthwiseConv2dNative: "
<< " Input: [" << batch << ", " << input_rows << ", " << input_cols
<< ", " << in_depth << "]; Filter: [" << filter_rows << ", "
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; stride = " << stride_ << ", pad_rows = " << pad_rows
<< ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
<< out_rows << ", " << out_cols << ", " << out_depth << "]";
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
return;
}
// If in_depth==1, this operation is just a standard convolution, so
// invoke that op.
if (std::is_same<T, float>::value && in_depth == 1) {
// TODO(csigg): Have autotune decide if native is faster than cuDNN.
// If in_depth==1, this operation is just a standard convolution.
// Depthwise convolution is a special case of cuDNN's grouped convolution.
bool use_cudnn = use_cudnn_ && (in_depth == 1 || use_cudnn_grouped_conv_);
VLOG(2) << "DepthwiseConv2dNative: "
<< " Input: [" << batch << ", " << input_rows << ", " << input_cols
<< ", " << in_depth << "]; Filter: [" << filter_rows << ", "
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
// Reshape from TF depthwise filter to cuDNN grouped convolution filter:
//
// | TensorFlow | cuDNN
// --------------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier * group_count
// filter_in_depth | in_depth | in_depth / group_count
//
// For depthwise convolution, we have group_count == in_depth.
int32 filter_in_depth = 1;
TensorShape shape =
TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
Tensor reshaped_filter(/*type=*/dtype_);
OP_REQUIRES(
context, reshaped_filter.CopyFrom(filter, shape),
errors::Internal(
"Failed to reshape filter tensor for grouped convolution."));
// TODO(yangzihao): Send in arbitrary dilation rates after the dilated
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
/*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
padding_, output, data_format_);
launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
stride_, stride_, padding_, output, data_format_);
return;
}
@ -403,6 +430,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
output_ptr, data_format_);
}
protected:
bool use_cudnn_grouped_conv_;
private:
std::vector<int32> strides_;
Padding padding_;
@ -410,10 +440,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
int64 stride_; // in height/width dimension.
// For the case in_depth == 1.
// For in_depth == 1 and grouped convolutions.
LaunchConv2DOp<Device, T> launcher_;
bool use_cudnn_;
bool cudnn_use_autotune_;
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};
@ -421,7 +452,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DepthwiseConv2dNativeOp<CPUDevice, T>);
DepthwiseConv2dNativeOp<CPUDevice, T>)
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
@ -430,19 +461,38 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
#endif
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T"),
DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
DepthwiseConv2dNativeOp<GPUDevice, float>);
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
DepthwiseConv2dNativeOp<GPUDevice, T>)
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T"),
DepthwiseConv2dNativeOp<GPUDevice, double>);
#endif
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvOp
: public DepthwiseConv2dNativeOp<GPUDevice, T> {
public:
DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
: DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
this->use_cudnn_grouped_conv_ = true;
}
};
#define REGISTER_GROUPED_CONV_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.Label("cudnn_grouped_convolution"), \
DepthwiseConv2dGroupedConvOp<T>)
TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#endif // CUDNN_VERSION
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -22,12 +22,15 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def ConfigsToTest():
@ -98,6 +101,7 @@ class DepthwiseConv2DTest(test.TestCase):
padding,
data_type,
use_gpu,
grouped_conv=False,
data_format="NHWC"):
"""Verifies the output values of the convolution function.
@ -110,25 +114,26 @@ class DepthwiseConv2DTest(test.TestCase):
padding: Padding type.
data_type: The data type to use.
use_gpu: Whether to use GPU.
grouped_conv: Whether to use cuDNN 7's grouped convolution.
data_format: The data_format of the input. "NHWC" or "NCHW".
"""
total_size_1 = 1
total_size_2 = 1
input_size = 1
filter_size = 1
for s in tensor_in_sizes:
total_size_1 *= s
input_size *= s
for s in filter_in_sizes:
total_size_2 *= s
filter_size *= s
# Initializes the input and filter tensor with numbers incrementing from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
if data_type == dtypes.float16:
tolerance = 1e-5
elif data_type == dtypes.float32:
tolerance = 1e-5
else:
self.assertEqual(data_type, dtypes.float64)
tolerance = 1e-8
x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
dtypes.float32: 1e-8,
dtypes.float64: 1e-13,
}[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
t1.set_shape(tensor_in_sizes)
@ -142,25 +147,39 @@ class DepthwiseConv2DTest(test.TestCase):
native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
conv_native = nn_ops.depthwise_conv2d_native(
native_t1,
t2,
strides=strides,
data_format=data_format,
padding=padding)
with sess.graph._kernel_label_map({
"DepthwiseConv2dNative": "cudnn_grouped_convolution"
} if grouped_conv else {}):
conv_native = nn_ops.depthwise_conv2d_native(
native_t1,
t2,
strides=strides,
data_format=data_format,
padding=padding)
if data_format == "NCHW":
# Transpose back from NCHW to NHWC
conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
try:
native_result = sess.run(conv_native)
except errors.InvalidArgumentError as e:
# Grouped convolution kernel is only registered for cuDNN 7. Silently
# return when we are running on an earlier version or without GPU.
if e.message.startswith(
"No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
tf_logging.warn("Skipping grouped convolution test")
return
raise e
conv_interface = nn_impl.depthwise_conv2d(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
native_result = sess.run(conv_native)
interface_result = sess.run(conv_interface)
print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
np.amax(np.absolute(native_result - interface_result)))
tf_logging.info(
"data_type: %r, use_gpu: %r, grouped_conv: %r, max diff = %f",
data_type, use_gpu, grouped_conv,
np.amax(np.absolute(native_result - interface_result)))
self.assertArrayNear(
np.ravel(native_result), np.ravel(interface_result), tolerance)
self.assertShapeEqual(native_result, conv_native)
@ -169,11 +188,22 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2D(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
filter_size, "stride:", stride, "padding:", padding)
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
tf_logging.info("Testing with grouped_conv")
self._VerifyValues(
input_size,
filter_size,
stride,
padding,
data_type,
use_gpu=True,
grouped_conv=True)
def testDepthwiseConv2DFormat(self):
if not test.is_gpu_available():
@ -181,8 +211,9 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
"*", filter_size, "stride:", stride, "padding:", padding)
tf_logging.info(
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
@ -226,7 +257,7 @@ class DepthwiseConv2DTest(test.TestCase):
conv = nn_ops.depthwise_conv2d_native(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
value = sess.run(conv)
print("value = ", value)
tf_logging.info("value = %r", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@ -296,7 +327,7 @@ class DepthwiseConv2DTest(test.TestCase):
expected=expected_output,
use_gpu=True)
# Gradient checkers.This tests depthwise gradient computations for both
# Gradient checkers. This tests depthwise gradient computations for both
# BackpropFilter and BackpropInput by comparing gradients computed by the
# depthwise gradient ops with the gradients computed numerically (details can
# be found in the compute_gradient_error().
@ -310,6 +341,7 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input,
use_gpu,
grouped_conv=False,
data_format="NHWC"):
input_size = 1
for x in input_shape:
@ -319,14 +351,14 @@ class DepthwiseConv2DTest(test.TestCase):
filter_size *= x
input_data = [x * 1.0 / input_size for x in range(0, input_size)]
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
with self.test_session(use_gpu=use_gpu):
if data_type == dtypes.float16:
tolerance = 0.002
elif data_type == dtypes.float32:
tolerance = 0.002
else:
self.assertEqual(data_type, dtypes.float64)
tolerance = 1e-8
ops.reset_default_graph()
graph = ops.get_default_graph()
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 2e-0,
dtypes.float32: 5e-4,
dtypes.float64: 1e-12,
}[data_type]
input_tensor = constant_op.constant(
input_data, shape=input_shape, dtype=data_type, name="input")
@ -347,35 +379,49 @@ class DepthwiseConv2DTest(test.TestCase):
]
strides = [1, 1, stride, stride]
depthwise_conv2d = nn_ops.depthwise_conv2d_native(
native_input,
filter_tensor,
strides,
padding,
data_format=data_format,
name="depthwise_conv2d")
with sess.graph._kernel_label_map({
"DepthwiseConv2dNative": "cudnn_grouped_convolution",
"DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
"DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
} if grouped_conv else {}):
depthwise_conv2d = nn_ops.depthwise_conv2d_native(
native_input,
filter_tensor,
strides,
padding,
data_format=data_format,
name="depthwise_conv2d")
self.assertEqual(output_shape, depthwise_conv2d.get_shape())
if test_input:
err = gradient_checker.compute_gradient_error(
native_input, input_shape, depthwise_conv2d, output_shape)
else:
err = gradient_checker.compute_gradient_error(filter_tensor,
filter_shape,
depthwise_conv2d,
output_shape)
print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
try:
if test_input:
err = gradient_checker.compute_gradient_error(
native_input, input_shape, depthwise_conv2d, output_shape)
else:
err = gradient_checker.compute_gradient_error(
filter_tensor, filter_shape, depthwise_conv2d, output_shape)
except errors.InvalidArgumentError as e:
# Grouped convolution kernel is only registered for cuDNN 7. Silently
# return when we are running on an earlier version or without GPU.
if grouped_conv and e.message.startswith(
"No OpKernel was registered to support Op 'DepthwiseConv2dNative'"):
tf_logging.warn("Skipping grouped convolution test")
return
raise e
tf_logging.info(
"data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
use_gpu, grouped_conv, err)
self.assertLess(err, tolerance)
def testDepthwiseConv2DInputGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
# Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
# calculations are not very precise.
for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info(
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@ -385,6 +431,16 @@ class DepthwiseConv2DTest(test.TestCase):
data_type,
test_input=True,
use_gpu=True)
self._ConstructAndTestGradient(
input_size,
filter_size,
output_size,
stride,
padding,
data_type,
test_input=True,
use_gpu=True,
grouped_conv=True)
def testDepthwiseConv2DInputGradFormat(self):
if not test.is_gpu_available():
@ -392,12 +448,11 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
# Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
# calculations are not very precise.
for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info(
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@ -412,12 +467,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DFilterGrad(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
# Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
# calculations are not very precise.
for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info(
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
"%d, padding: %s", index, input_size, filter_size, stride, padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@ -434,12 +487,11 @@ class DepthwiseConv2DTest(test.TestCase):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(CheckGradConfigsToTest()):
print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
# Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
# calculations are not very precise.
for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info(
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@ -494,9 +546,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
tf_logging.info(
"Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropInputFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropInputDouble(input_size, filter_size, output_size,
@ -545,9 +598,10 @@ class DepthwiseConv2DTest(test.TestCase):
def testDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
tf_logging.info(
"Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
stride, padding)
self._CompareBackpropFilterDouble(input_size, filter_size, output_size,

View File

@ -337,7 +337,9 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
#if CUDNN_VERSION >= 7000
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
__macro(cudnnSetConvolutionMathType) \
__macro(cudnnSetRNNMatrixMathType)
__macro(cudnnSetRNNMatrixMathType) \
__macro(cudnnSetConvolutionGroupCount) \
__macro(cudnnGetConvolutionGroupCount)
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
@ -779,6 +781,20 @@ class ScopedConvolutionDescriptor {
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
#if CUDNN_MAJOR >= 7
VLOG(2) << "Requesting grouped convolution: "
<< convolution_descriptor.group_count();
status = wrap::cudnnSetConvolutionGroupCount(
parent_, handle_, convolution_descriptor.group_count());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "could not set cudnn convolution group count: "
<< ToString(status);
}
#else
CHECK_EQ(convolution_descriptor.group_count(), 1)
<< "Requested grouped convolution for cuDNN version < 7";
#endif
}
void set_use_tensor_op_math(bool use_tensor_op_math) {

View File

@ -434,6 +434,7 @@ ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
filter_strides_(ndims, 1),
dilation_rates_(ndims, 1),
pad_alignment_(PadAlignment::kDefault),
group_count_(1),
ndims_(ndims) {}
ConvolutionDescriptor::ConvolutionDescriptor()

View File

@ -543,6 +543,10 @@ class ConvolutionDescriptor {
pad_alignment_ = pad_alignment;
return *this;
}
ConvolutionDescriptor& set_group_count(int group_count) {
group_count_ = group_count;
return *this;
}
int64 zero_padding_height() const {
return GetDim(zero_padding_, DimIndex::Y);
}
@ -566,6 +570,7 @@ class ConvolutionDescriptor {
int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
PadAlignment pad_alignment() const { return pad_alignment_; }
int group_count() const { return group_count_; }
int ndims() const { return ndims_; }
std::vector<int64> strides() const { return filter_strides_; }
@ -578,6 +583,7 @@ class ConvolutionDescriptor {
std::vector<int64> filter_strides_;
std::vector<int64> dilation_rates_;
PadAlignment pad_alignment_;
int group_count_;
int ndims_;
// TODO(leary) cudnn provides these fields, but need to characterize what
// their effect is -- they may be boolean rather than integral.