diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 6ab51781f6d..3ba6a9a6f39 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -1164,11 +1164,12 @@ class Conv3DBackpropInputOp : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 && - dims.filter_size(2) == 1 && dims.dilation(0) == 1 && - dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 && - dims.stride(1) == 1 && dims.stride(2) == 1 && - data_format_ == FORMAT_NHWC) { + bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; + if (!is_grouped_convolution && dims.filter_size(0) == 1 && + dims.filter_size(1) == 1 && dims.filter_size(2) == 1 && + dims.dilation(0) == 1 && dims.dilation(1) == 1 && + dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 && + dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) { const uint64 m = dims.batch_size * dims.input_size(0) * dims.input_size(1) * dims.input_size(2); const uint64 k = dims.out_depth; @@ -1194,7 +1195,8 @@ class Conv3DBackpropInputOp : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (dims.filter_size(0) == dims.input_size(0) && + } else if (!is_grouped_convolution && + dims.filter_size(0) == dims.input_size(0) && dims.filter_size(1) == dims.input_size(1) && dims.filter_size(2) == dims.input_size(2) && padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { @@ -1269,8 +1271,8 @@ class Conv3DBackpropInputOp : public OpKernel { filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); + .set_input_feature_map_count(filter_shape.dim_size(3)) + .set_output_feature_map_count(filter_shape.dim_size(4)); se::dnn::ConvolutionDescriptor conv_desc(3); conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) .set_dilation_rate(DimIndex::Y, dims.dilation(1)) @@ -1280,17 +1282,18 @@ class Conv3DBackpropInputOp : public OpKernel { .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) - .set_zero_padding(DimIndex::Z, padding_planes / 2); + .set_zero_padding(DimIndex::Z, padding_planes / 2) + .set_group_count(dims.in_depth / filter_shape.dim_size(3)); // Shape: out, in, z, y, x. Tensor transformed_filter; OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum::value, - TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), - dims.filter_size(1), dims.filter_size(2)}), - &transformed_filter)); + context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({filter_shape.dim_size(4), + filter_shape.dim_size(3), dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &transformed_filter)); functor::TransformFilter()( context->eigen_device(), FORMAT_OIHW, To32Bit(filter.tensor()), @@ -1566,11 +1569,12 @@ class Conv3DBackpropFilterOp : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 && - dims.filter_size(0) == 1 && dims.dilation(2) == 1 && - dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 && - dims.stride(1) == 1 && dims.stride(0) == 1 && - data_format_ == FORMAT_NHWC) { + bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; + if (!is_grouped_convolution && dims.filter_size(1) == 1 && + dims.filter_size(2) == 1 && dims.filter_size(0) == 1 && + dims.dilation(2) == 1 && dims.dilation(1) == 1 && + dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 && + dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) { const uint64 m = dims.in_depth; const uint64 k = dims.batch_size * dims.input_size(1) * dims.input_size(2) * dims.input_size(0); @@ -1605,7 +1609,8 @@ class Conv3DBackpropFilterOp : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (dims.filter_size(0) == dims.input_size(0) && + } else if (!is_grouped_convolution && + dims.filter_size(0) == dims.input_size(0) && dims.filter_size(1) == dims.input_size(1) && dims.filter_size(2) == dims.input_size(2) && padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { @@ -1685,8 +1690,8 @@ class Conv3DBackpropFilterOp : public OpKernel { filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); + .set_input_feature_map_count(filter_shape.dim_size(3)) + .set_output_feature_map_count(filter_shape.dim_size(4)); se::dnn::ConvolutionDescriptor conv_desc(3); conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) .set_dilation_rate(DimIndex::Y, dims.dilation(1)) @@ -1696,16 +1701,16 @@ class Conv3DBackpropFilterOp : public OpKernel { .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) - .set_zero_padding(DimIndex::Z, padding_planes / 2); - + .set_zero_padding(DimIndex::Z, padding_planes / 2) + .set_group_count(dims.in_depth / filter_shape.dim_size(3)); Tensor pre_transformed_filter_backprop; OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum::value, - TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), - dims.filter_size(1), dims.filter_size(2)}), - &pre_transformed_filter_backprop)); + context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({filter_shape.dim_size(4), + filter_shape.dim_size(3), dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &pre_transformed_filter_backprop)); Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) {