Fixed backpropagation gradient for grouped 3D conv

Author: Dan Ganea
Date:   2019-08-07 23:58:55 +02:00
Parent: f605ea0a14
Commit: 75c2f170c4
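The change below makes both 3-D convolution backprop kernels (input gradient and filter gradient) handle grouped convolutions on the GPU path: the GEMM fast paths are bypassed for grouped convolutions, the cuDNN filter descriptor is built from the filter tensor's own channel dimensions, and the convolution descriptor gets an explicit group count.

As a minimal sketch (not part of the commit; struct and function names are illustrative), this is the arithmetic the kernels rely on, assuming the Conv3D filter layout {planes, rows, cols, in_channels / groups, out_channels}:

    #include <cassert>
    #include <cstdint>

    struct FilterShape5D {
      int64_t planes, rows, cols, in_per_group, out_channels;
    };

    // A convolution is grouped exactly when the filter's fourth dimension
    // (input channels per group) differs from the input depth.
    bool IsGroupedConvolution(const FilterShape5D& f, int64_t in_depth) {
      return f.in_per_group != in_depth;
    }

    // Group count passed to the convolution descriptor in this commit.
    int64_t GroupCount(const FilterShape5D& f, int64_t in_depth) {
      assert(in_depth % f.in_per_group == 0);
      return in_depth / f.in_per_group;
    }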


@@ -1164,11 +1164,12 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 auto* stream = context->op_device_context()->stream();
 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
-dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
-dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
-dims.stride(1) == 1 && dims.stride(2) == 1 &&
-data_format_ == FORMAT_NHWC) {
+bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
+if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
+dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
+dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
+dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
 const uint64 m = dims.batch_size * dims.input_size(0) *
 dims.input_size(1) * dims.input_size(2);
 const uint64 k = dims.out_depth;
@@ -1194,7 +1195,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 ", n=", n, ", k=", k));
 }
 return;
-} else if (dims.filter_size(0) == dims.input_size(0) &&
+} else if (!is_grouped_convolution &&
+dims.filter_size(0) == dims.input_size(0) &&
 dims.filter_size(1) == dims.input_size(1) &&
 dims.filter_size(2) == dims.input_size(2) &&
 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
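Note (not part of the diff): both fast paths above compute the input gradient as a single dense matrix product, which assumes every output channel connects to every input channel. With groups the effective weight matrix is block-diagonal, so grouped convolutions now fall through to the cuDNN path instead. A rough sketch of the dimension bookkeeping, with m and k as in the lines above and n being the input depth (illustrative names only):

    #include <cstdint>

    struct GemmDims {
      int64_t m, n, k;
    };

    // Dense 1x1x1 fast path: one product of the out_backprop matrix (m x k)
    // with the filter viewed as a (k x n) matrix.
    GemmDims DenseBackpropInputGemm(int64_t batch, int64_t planes, int64_t rows,
                                    int64_t cols, int64_t in_depth,
                                    int64_t out_depth) {
      return {batch * planes * rows * cols, in_depth, out_depth};
    }

    // Grouped case: each group only connects in_depth / groups inputs to
    // out_depth / groups outputs, so one GEMM per group would be needed
    // instead of a single dense GEMM over the whole filter.
    GemmDims PerGroupBackpropInputGemm(int64_t batch, int64_t planes,
                                       int64_t rows, int64_t cols,
                                       int64_t in_depth, int64_t out_depth,
                                       int64_t groups) {
      return {batch * planes * rows * cols, in_depth / groups,
              out_depth / groups};
    }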
@@ -1269,8 +1271,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
-.set_input_feature_map_count(dims.in_depth)
-.set_output_feature_map_count(dims.out_depth);
+.set_input_feature_map_count(filter_shape.dim_size(3))
+.set_output_feature_map_count(filter_shape.dim_size(4));
 se::dnn::ConvolutionDescriptor conv_desc(3);
 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1280,17 +1282,18 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 .set_filter_stride(DimIndex::Z, dims.stride(0))
 .set_zero_padding(DimIndex::X, padding_cols / 2)
 .set_zero_padding(DimIndex::Y, padding_rows / 2)
-.set_zero_padding(DimIndex::Z, padding_planes / 2);
+.set_zero_padding(DimIndex::Z, padding_planes / 2)
+.set_group_count(dims.in_depth / filter_shape.dim_size(3));
 // Shape: out, in, z, y, x.
 Tensor transformed_filter;
 OP_REQUIRES_OK(
-context,
-context->allocate_temp(
-DataTypeToEnum<T>::value,
-TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
-dims.filter_size(1), dims.filter_size(2)}),
-&transformed_filter));
+context, context->allocate_temp(
+DataTypeToEnum<T>::value,
+TensorShape({filter_shape.dim_size(4),
+filter_shape.dim_size(3), dims.filter_size(0),
+dims.filter_size(1), dims.filter_size(2)}),
+&transformed_filter));
 functor::TransformFilter<GPUDevice, T, int, 5>()(
 context->eigen_device<GPUDevice>(), FORMAT_OIHW,
 To32Bit(filter.tensor<T, 5>()),
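Note (not part of the diff): the temporary filter is transformed into out-channels-major order ("Shape: out, in, z, y, x" above), so its channel dimensions must come from the filter tensor itself; with dims.in_depth the temporary would be group-count times larger than the filter actually being transformed. A sketch of the size bookkeeping, names invented for illustration:

    #include <cassert>
    #include <cstdint>

    // Element count of the HWIO-style filter: {z, y, x, in / groups, out}.
    int64_t FilterElements(int64_t z, int64_t y, int64_t x,
                           int64_t in_per_group, int64_t out_channels) {
      return z * y * x * in_per_group * out_channels;
    }

    // Element count of an OIHW-style temporary: {out, in, z, y, x}.
    int64_t TempElements(int64_t out_channels, int64_t in_channels, int64_t z,
                         int64_t y, int64_t x) {
      return out_channels * in_channels * z * y * x;
    }

    void CheckTempShape(int64_t z, int64_t y, int64_t x, int64_t in_depth,
                        int64_t out_depth, int64_t groups) {
      const int64_t in_per_group = in_depth / groups;
      // New shape {out, in / groups, z, y, x}: always matches the filter.
      assert(TempElements(out_depth, in_per_group, z, y, x) ==
             FilterElements(z, y, x, in_per_group, out_depth));
      // Old shape {out_depth, in_depth, z, y, x} is `groups` times larger
      // whenever groups > 1, so the transform would not line up.
      assert(TempElements(out_depth, in_depth, z, y, x) ==
             groups * FilterElements(z, y, x, in_per_group, out_depth));
    }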
@@ -1566,11 +1569,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 auto* stream = context->op_device_context()->stream();
 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
-dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
-dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
-dims.stride(1) == 1 && dims.stride(0) == 1 &&
-data_format_ == FORMAT_NHWC) {
+bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
+if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
+dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
+dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
+dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
+dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
 const uint64 m = dims.in_depth;
 const uint64 k = dims.batch_size * dims.input_size(1) *
 dims.input_size(2) * dims.input_size(0);
@@ -1605,7 +1609,8 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 ", n=", n, ", k=", k));
 }
 return;
-} else if (dims.filter_size(0) == dims.input_size(0) &&
+} else if (!is_grouped_convolution &&
+dims.filter_size(0) == dims.input_size(0) &&
 dims.filter_size(1) == dims.input_size(1) &&
 dims.filter_size(2) == dims.input_size(2) &&
 padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
@@ -1685,8 +1690,8 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
 .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
 .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
-.set_input_feature_map_count(dims.in_depth)
-.set_output_feature_map_count(dims.out_depth);
+.set_input_feature_map_count(filter_shape.dim_size(3))
+.set_output_feature_map_count(filter_shape.dim_size(4));
 se::dnn::ConvolutionDescriptor conv_desc(3);
 conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
 .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1696,16 +1701,16 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 .set_filter_stride(DimIndex::Z, dims.stride(0))
 .set_zero_padding(DimIndex::X, padding_cols / 2)
 .set_zero_padding(DimIndex::Y, padding_rows / 2)
-.set_zero_padding(DimIndex::Z, padding_planes / 2);
+.set_zero_padding(DimIndex::Z, padding_planes / 2)
+.set_group_count(dims.in_depth / filter_shape.dim_size(3));
 Tensor pre_transformed_filter_backprop;
 OP_REQUIRES_OK(
-context,
-context->allocate_temp(
-DataTypeToEnum<T>::value,
-TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
-dims.filter_size(1), dims.filter_size(2)}),
-&pre_transformed_filter_backprop));
+context, context->allocate_temp(
+DataTypeToEnum<T>::value,
+TensorShape({filter_shape.dim_size(4),
+filter_shape.dim_size(3), dims.filter_size(0),
+dims.filter_size(1), dims.filter_size(2)}),
+&pre_transformed_filter_backprop));
 Tensor transformed_out_backprop;
 if (data_format_ == FORMAT_NHWC) {
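The filter-gradient kernel above mirrors the input-gradient changes: skip the GEMM fast paths for grouped convolutions, describe the filter by its own channel dimensions, set the group count on the convolution descriptor, and allocate the out-channels-major temporary with the per-group input channel count. As a closing sketch (illustrative names, not TensorFlow or StreamExecutor API), the consistency conditions the descriptors in both kernels now satisfy:

    #include <cassert>
    #include <cstdint>

    // Descriptor fields as configured in both kernels after this change.
    struct GroupedConv3DDescriptors {
      int64_t filter_in_per_group;  // filter_shape.dim_size(3)
      int64_t filter_out_channels;  // filter_shape.dim_size(4)
      int64_t group_count;          // dims.in_depth / filter_shape.dim_size(3)
    };

    // The configuration is consistent when the groups tile the input depth
    // exactly and the output channels divide evenly across groups.
    void CheckDescriptors(const GroupedConv3DDescriptors& d, int64_t in_depth,
                          int64_t out_depth) {
      assert(d.group_count * d.filter_in_per_group == in_depth);
      assert(d.filter_out_channels == out_depth);
      assert(out_depth % d.group_count == 0);
    }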