Enable 3D Group Convolutions

Implemented 3D Group Convolutions by passing group_count to ConvolutionDescriptor.
This commit is contained in:
Dan Ganea 2019-08-04 22:59:51 +02:00
parent 65849ef4e0
commit b382c69cfa
3 changed files with 62 additions and 18 deletions

View File

@ -877,15 +877,35 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
// Check that the input tensor and the filter tensor agree on the channel
// count.
if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
int64 input_depth_value = c->Value(input_depth_dim),
filter_input_depth_value = c->Value(filter_input_depth_dim);
if (input_depth_value % filter_input_depth_value != 0)
return errors::InvalidArgument(
"Depth of input (", input_depth_value,
") is not a multiple of input depth of filter (",
filter_input_depth_value, ")");
if (input_depth_value != filter_input_depth_value) {
int64 num_groups = input_depth_value / filter_input_depth_value;
if (c->ValueKnown(output_depth_dim)) {
int64 output_depth_value = c->Value(output_depth_dim);
if (output_depth_value % num_groups != 0)
return errors::InvalidArgument(
"Depth of output (", output_depth_value,
") is not a multiple of the number of groups (", num_groups, ")");
}
}
}
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

View File

@ -963,9 +963,15 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) {
INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]");
INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]");
// input depths must match.
INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
"[1,2,2,2,10];[1,1,1,10000,20]");
// input depth must be multiple of filter depth for group convolutions
INFER_ERROR(
"Depth of input (10) is not a multiple of input depth of filter (6)", op,
"[1,2,2,2,10];[1,1,1,6,20]");
// Output dimensions must be multiple of group number
INFER_ERROR(
"Depth of output (1) is not a multiple of the number of groups (2)", op,
"[1,2,2,2,10];[1,1,1,5,1]");
// 2x2x2 filter
set_op({{1, 1, 1, 1, 1}}, "VALID");
@ -983,6 +989,17 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) {
set_op({{1, 1, 1, 1, 1}}, "SAME");
INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
// 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride
INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
// test output multiple of group size is ok
// 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride
INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
// Depthwise convolution first step
// 4x4 input of depth 10, 2x2 filter with depth 1, 1x1 stride
INFER_OK(op, "[1,4,4,4,10];[2,2,2,1,10]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
// with SAME, filter doesn't matter except for last dim.
set_op({{1, 1, 1, 1, 1}}, "SAME");
INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");

View File

@ -132,10 +132,13 @@ class Conv3DOp : public BinaryOp<T> {
const int64 in_depth = GetTensorDim(input, data_format_, 'C');
const int64 in_batch = GetTensorDim(input, data_format_, 'N');
const int64 filter_depth = filter.dim_size(3);
const int64 out_depth = filter.dim_size(4);
OP_REQUIRES(
context, in_depth == filter.dim_size(3),
errors::InvalidArgument("input and filter must have the same depth"));
OP_REQUIRES(context, in_depth % filter_depth == 0,
errors::InvalidArgument(
"Input depth must be evenly divisible by filter depth: ",
in_depth, " vs ", filter_depth));
// Dimension order for these arrays is: z, y, x.
std::array<int64, 3> input_size = {
@ -218,6 +221,7 @@ struct LaunchConvOp<GPUDevice, T> {
const int64 filter_planes = filter.dim_size(0);
const int64 filter_rows = filter.dim_size(1);
const int64 filter_cols = filter.dim_size(2);
const int64 filter_depth = filter.dim_size(3);
const int64 out_depth = filter.dim_size(4);
int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
@ -234,11 +238,13 @@ struct LaunchConvOp<GPUDevice, T> {
0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
}
bool is_grouped_convolution = filter_depth != in_depth;
// NOTE: This only works in NHWC.
if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1 &&
strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
data_format == FORMAT_NHWC) {
if (!is_grouped_convolution && filter_planes == 1 && filter_rows == 1 &&
filter_cols == 1 && dilations[0] == 1 && dilations[1] == 1 &&
dilations[2] == 1 && strides[0] == 1 && strides[1] == 1 &&
strides[2] == 1 && data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
const uint64 m = in_batch * in_planes * in_rows * in_cols;
const uint64 k = in_depth;
@ -262,9 +268,9 @@ struct LaunchConvOp<GPUDevice, T> {
", n=", n, ", k=", k));
}
return;
} else if (filter_planes == in_planes && filter_rows == in_rows &&
filter_cols == in_cols && padding == Padding::VALID &&
data_format == FORMAT_NHWC) {
} else if (!is_grouped_convolution && filter_planes == in_planes &&
filter_rows == in_rows && filter_cols == in_cols &&
padding == Padding::VALID && data_format == FORMAT_NHWC) {
// The input data and filter have the same planes/height/width, so call
// cublas directly.
const uint64 m = in_batch;
@ -365,7 +371,7 @@ struct LaunchConvOp<GPUDevice, T> {
filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
.set_spatial_dim(DimIndex::Y, filter_rows)
.set_spatial_dim(DimIndex::Z, filter_planes)
.set_input_feature_map_count(in_depth)
.set_input_feature_map_count(filter_depth)
.set_output_feature_map_count(out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
@ -376,7 +382,8 @@ struct LaunchConvOp<GPUDevice, T> {
.set_filter_stride(DimIndex::Z, strides[0])
.set_zero_padding(DimIndex::X, pad_cols / 2)
.set_zero_padding(DimIndex::Y, pad_rows / 2)
.set_zero_padding(DimIndex::Z, pad_planes / 2);
.set_zero_padding(DimIndex::Z, pad_planes / 2)
.set_group_count(in_depth / filter_depth);
Tensor transformed_filter;
OP_REQUIRES_OK(