diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 801fbb994d0..8586530b2df 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -877,15 +877,35 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { DimensionHandle in_planes_dim = c->Dim(input_shape, 1); DimensionHandle in_rows_dim = c->Dim(input_shape, 2); DimensionHandle in_cols_dim = c->Dim(input_shape, 3); + DimensionHandle input_depth_dim = c->Dim(input_shape, 4); DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); + DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3); DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); - DimensionHandle unused; - TF_RETURN_IF_ERROR( - c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); + // Check that the input tensor and the filter tensor agree on the channel + // count. + if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) { + int64 input_depth_value = c->Value(input_depth_dim), + filter_input_depth_value = c->Value(filter_input_depth_dim); + if (input_depth_value % filter_input_depth_value != 0) + return errors::InvalidArgument( + "Depth of input (", input_depth_value, + ") is not a multiple of input depth of filter (", + filter_input_depth_value, ")"); + if (input_depth_value != filter_input_depth_value) { + int64 num_groups = input_depth_value / filter_input_depth_value; + if (c->ValueKnown(output_depth_dim)) { + int64 output_depth_value = c->Value(output_depth_dim); + if (output_depth_value % num_groups != 0) + return errors::InvalidArgument( + "Depth of output (", output_depth_value, + ") is not a multiple of the number of groups (", num_groups, ")"); + } + } + } Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 40ca891f929..19642efe389 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -963,9 +963,15 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) { INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]"); INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]"); - // input depths must match. - INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op, - "[1,2,2,2,10];[1,1,1,10000,20]"); + // input depth must be multiple of filter depth for group convolutions + INFER_ERROR( + "Depth of input (10) is not a multiple of input depth of filter (6)", op, + "[1,2,2,2,10];[1,1,1,6,20]"); + + // Output dimensions must be multiple of group number + INFER_ERROR( + "Depth of output (1) is not a multiple of the number of groups (2)", op, + "[1,2,2,2,10];[1,1,1,5,1]"); // 2x2x2 filter set_op({{1, 1, 1, 1, 1}}, "VALID"); @@ -983,6 +989,17 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) { set_op({{1, 1, 1, 1, 1}}, "SAME"); INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + // 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride + INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + + // test output multiple of group size is ok + // 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride + INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + + // Depthwise convolution first step + // 4x4 input of depth 10, 2x2 filter with depth 1, 1x1 stride + INFER_OK(op, "[1,4,4,4,10];[2,2,2,1,10]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); + // with SAME, filter doesn't matter except for last dim. set_op({{1, 1, 1, 1, 1}}, "SAME"); INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]"); diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index c92dccd2f81..2b7a624be69 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -132,10 +132,13 @@ class Conv3DOp : public BinaryOp { const int64 in_depth = GetTensorDim(input, data_format_, 'C'); const int64 in_batch = GetTensorDim(input, data_format_, 'N'); + const int64 filter_depth = filter.dim_size(3); const int64 out_depth = filter.dim_size(4); - OP_REQUIRES( - context, in_depth == filter.dim_size(3), - errors::InvalidArgument("input and filter must have the same depth")); + + OP_REQUIRES(context, in_depth % filter_depth == 0, + errors::InvalidArgument( + "Input depth must be evenly divisible by filter depth: ", + in_depth, " vs ", filter_depth)); // Dimension order for these arrays is: z, y, x. std::array input_size = { @@ -218,6 +221,7 @@ struct LaunchConvOp { const int64 filter_planes = filter.dim_size(0); const int64 filter_rows = filter.dim_size(1); const int64 filter_cols = filter.dim_size(2); + const int64 filter_depth = filter.dim_size(3); const int64 out_depth = filter.dim_size(4); int64 pad_planes = 0, pad_rows = 0, pad_cols = 0; @@ -234,11 +238,13 @@ struct LaunchConvOp { 0, (out_cols - 1) * strides[2] + filter_cols - in_cols); } + bool is_grouped_convolution = filter_depth != in_depth; + // NOTE: This only works in NHWC. - if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 && - dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1 && - strides[0] == 1 && strides[1] == 1 && strides[2] == 1 && - data_format == FORMAT_NHWC) { + if (!is_grouped_convolution && filter_planes == 1 && filter_rows == 1 && + filter_cols == 1 && dilations[0] == 1 && dilations[1] == 1 && + dilations[2] == 1 && strides[0] == 1 && strides[1] == 1 && + strides[2] == 1 && data_format == FORMAT_NHWC) { // 1x1 filter, so call cublas directly. const uint64 m = in_batch * in_planes * in_rows * in_cols; const uint64 k = in_depth; @@ -262,9 +268,9 @@ struct LaunchConvOp { ", n=", n, ", k=", k)); } return; - } else if (filter_planes == in_planes && filter_rows == in_rows && - filter_cols == in_cols && padding == Padding::VALID && - data_format == FORMAT_NHWC) { + } else if (!is_grouped_convolution && filter_planes == in_planes && + filter_rows == in_rows && filter_cols == in_cols && + padding == Padding::VALID && data_format == FORMAT_NHWC) { // The input data and filter have the same planes/height/width, so call // cublas directly. const uint64 m = in_batch; @@ -365,7 +371,7 @@ struct LaunchConvOp { filter_desc.set_spatial_dim(DimIndex::X, filter_cols) .set_spatial_dim(DimIndex::Y, filter_rows) .set_spatial_dim(DimIndex::Z, filter_planes) - .set_input_feature_map_count(in_depth) + .set_input_feature_map_count(filter_depth) .set_output_feature_map_count(out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) @@ -376,7 +382,8 @@ struct LaunchConvOp { .set_filter_stride(DimIndex::Z, strides[0]) .set_zero_padding(DimIndex::X, pad_cols / 2) .set_zero_padding(DimIndex::Y, pad_rows / 2) - .set_zero_padding(DimIndex::Z, pad_planes / 2); + .set_zero_padding(DimIndex::Z, pad_planes / 2) + .set_group_count(in_depth / filter_depth); Tensor transformed_filter; OP_REQUIRES_OK(