Merge pull request #31492 from danganea:3DGroupConvs
PiperOrigin-RevId: 262987099
Commit ce69ecacca
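In short: Conv3D and its GPU backprop kernels now accept grouped (including depthwise) 3-D convolutions. Shape inference lets the filter's input-depth dimension be any divisor of the input's channel count, the kernels derive the group count from those shapes and hand it to cuDNN via set_group_count, and the count is folded into the convolution autotune cache key. Below is a minimal, standalone sketch of the two divisibility rules the new Conv3DShape code enforces; it is illustrative only (the helper name is made up) and mirrors the test cases further down.

#include <cstdint>
#include <iostream>

// Returns the group count implied by the shapes, or -1 if one of the two
// checks added to Conv3DShape would reject them.
int64_t GroupCountOrError(int64_t input_depth, int64_t filter_input_depth,
                          int64_t output_depth) {
  // Input depth must be a multiple of the filter's input depth.
  if (input_depth % filter_input_depth != 0) return -1;
  const int64_t num_groups = input_depth / filter_input_depth;
  // Output depth must be a multiple of the number of groups.
  if (output_depth % num_groups != 0) return -1;
  return num_groups;
}

int main() {
  std::cout << GroupCountOrError(10, 5, 2) << "\n";   // 2  (two groups)
  std::cout << GroupCountOrError(10, 6, 20) << "\n";  // -1 (10 % 6 != 0)
  std::cout << GroupCountOrError(10, 5, 1) << "\n";   // -1 (1 is not a multiple of 2)
  std::cout << GroupCountOrError(10, 1, 10) << "\n";  // 10 (depthwise first step)
}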
@@ -877,15 +877,35 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+  DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
 
   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
+  DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
 
-  DimensionHandle unused;
-  TF_RETURN_IF_ERROR(
-      c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
+  // Check that the input tensor and the filter tensor agree on the channel
+  // count.
+  if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
+    int64 input_depth_value = c->Value(input_depth_dim),
+          filter_input_depth_value = c->Value(filter_input_depth_dim);
+    if (input_depth_value % filter_input_depth_value != 0)
+      return errors::InvalidArgument(
+          "Depth of input (", input_depth_value,
+          ") is not a multiple of input depth of filter (",
+          filter_input_depth_value, ")");
+    if (input_depth_value != filter_input_depth_value) {
+      int64 num_groups = input_depth_value / filter_input_depth_value;
+      if (c->ValueKnown(output_depth_dim)) {
+        int64 output_depth_value = c->Value(output_depth_dim);
+        if (output_depth_value % num_groups != 0)
+          return errors::InvalidArgument(
+              "Depth of output (", output_depth_value,
+              ") is not a multiple of the number of groups (", num_groups, ")");
+      }
+    }
+  }
 
   Padding padding;
   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@@ -963,9 +963,15 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) {
   INFER_OK(op, "[1,2,2,2,1];[1,1,1,?,1]", "[d0_0,2,2,2,d1_4]");
   INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,?]", "[d0_0,2,2,2,d1_4]");
 
-  // input depths must match.
-  INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
-              "[1,2,2,2,10];[1,1,1,10000,20]");
+  // input depth must be multiple of filter depth for group convolutions
+  INFER_ERROR(
+      "Depth of input (10) is not a multiple of input depth of filter (6)", op,
+      "[1,2,2,2,10];[1,1,1,6,20]");
+
+  // Output dimensions must be multiple of group number
+  INFER_ERROR(
+      "Depth of output (1) is not a multiple of the number of groups (2)", op,
+      "[1,2,2,2,10];[1,1,1,5,1]");
 
   // 2x2x2 filter
   set_op({{1, 1, 1, 1, 1}}, "VALID");
@@ -983,6 +989,17 @@ TEST(CommonShapeFnsTest, Conv3DShapeTest) {
   set_op({{1, 1, 1, 1, 1}}, "SAME");
   INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
 
+  // 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride
+  INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
+
+  // test output multiple of group size is ok
+  // 4x4 input of depth 10, 2x2 filter with depth 5, 1x1 stride
+  INFER_OK(op, "[1,4,4,4,10];[2,2,2,5,2]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
+
+  // Depthwise convolution first step
+  // 4x4 input of depth 10, 2x2 filter with depth 1, 1x1 stride
+  INFER_OK(op, "[1,4,4,4,10];[2,2,2,1,10]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
+
   // with SAME, filter doesn't matter except for last dim.
   set_op({{1, 1, 1, 1, 1}}, "SAME");
   INFER_OK(op, "[?,4,4,4,1];[2,2,2,1,1]", "[d0_0,d0_1,d0_2,d0_3,d1_4]");
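For orientation before the kernel changes: a grouped convolution splits the channels into independent groups, and the kernels below express that by passing a group count to cuDNN rather than materializing the groups themselves. A hedged sketch of the channel bookkeeping (illustrative helpers, assuming the usual contiguous-group channel ordering and reusing the 10-channel, filter-depth-5 shapes from the tests above):

#include <cstdint>
#include <iostream>

struct GroupSlice {
  int64_t group;        // which group an output channel belongs to
  int64_t input_begin;  // first input channel that group reads
  int64_t input_end;    // one past the last input channel it reads
};

// For a grouped convolution with `in_depth` input channels, a filter whose
// fourth dimension is `filter_depth` (= in_depth / groups) and `out_depth`
// output channels, report which input channels feed output channel `oc`.
GroupSlice SliceForOutputChannel(int64_t in_depth, int64_t filter_depth,
                                 int64_t out_depth, int64_t oc) {
  const int64_t groups = in_depth / filter_depth;
  const int64_t outputs_per_group = out_depth / groups;
  const int64_t g = oc / outputs_per_group;
  return {g, g * filter_depth, (g + 1) * filter_depth};
}

int main() {
  // 10 input channels, filter depth 5 -> 2 groups; 4 output channels.
  for (int64_t oc = 0; oc < 4; ++oc) {
    GroupSlice s = SliceForOutputChannel(10, 5, 4, oc);
    std::cout << "output " << oc << " -> group " << s.group << ", inputs ["
              << s.input_begin << ", " << s.input_end << ")\n";
  }
}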
@@ -970,6 +970,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
        common_padding_cols}},   // padding_cols
       dtype,                    // tensor datatype
       device_id,                // device_id
+      conv_desc.group_count()   // group_count
   };
   AlgorithmConfig algorithm_config;
   if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
@@ -1096,6 +1096,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
        common_padding_cols}},   // padding_cols
       dtype,                    // tensor data type
       device_id,                // device_id
+      conv_desc.group_count()   // group_count
   };
   AlgorithmConfig algorithm_config;
   if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
@@ -1171,11 +1171,12 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
-    if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
-        dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
-        dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
-        dims.stride(1) == 1 && dims.stride(2) == 1 &&
-        data_format_ == FORMAT_NHWC) {
+    bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
+    if (!is_grouped_convolution && dims.filter_size(0) == 1 &&
+        dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+        dims.dilation(0) == 1 && dims.dilation(1) == 1 &&
+        dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 &&
+        dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) {
       const uint64 m = dims.batch_size * dims.input_size(0) *
                        dims.input_size(1) * dims.input_size(2);
       const uint64 k = dims.out_depth;
@@ -1201,7 +1202,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
                                         ", n=", n, ", k=", k));
       }
       return;
-    } else if (dims.filter_size(0) == dims.input_size(0) &&
+    } else if (!is_grouped_convolution &&
+               dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
@@ -1276,8 +1278,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
-        .set_input_feature_map_count(dims.in_depth)
-        .set_output_feature_map_count(dims.out_depth);
+        .set_input_feature_map_count(filter_shape.dim_size(3))
+        .set_output_feature_map_count(filter_shape.dim_size(4));
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1287,17 +1289,18 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
         .set_filter_stride(DimIndex::Z, dims.stride(0))
         .set_zero_padding(DimIndex::X, padding_cols / 2)
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
-        .set_zero_padding(DimIndex::Z, padding_planes / 2);
+        .set_zero_padding(DimIndex::Z, padding_planes / 2)
+        .set_group_count(dims.in_depth / filter_shape.dim_size(3));
 
     // Shape: out, in, z, y, x.
     Tensor transformed_filter;
     OP_REQUIRES_OK(
-        context,
-        context->allocate_temp(
-            DataTypeToEnum<T>::value,
-            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
-                         dims.filter_size(1), dims.filter_size(2)}),
-            &transformed_filter));
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     TensorShape({filter_shape.dim_size(4),
+                                  filter_shape.dim_size(3), dims.filter_size(0),
+                                  dims.filter_size(1), dims.filter_size(2)}),
+                     &transformed_filter));
     functor::TransformFilter<GPUDevice, T, int, 5>()(
         context->eigen_device<GPUDevice>(), FORMAT_OIHW,
         To32Bit(filter.tensor<T, 5>()),
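A note on the allocation change above: the "out, in, z, y, x" scratch buffer for the transformed filter has to use the filter's own input-depth dimension (in_depth / group_count), which equals dims.in_depth only for ungrouped convolutions. A small illustrative sketch of the resulting shape (not TF code, names are made up):

#include <array>
#include <cstdint>
#include <iostream>

// Shape of the "out, in, z, y, x" buffer the filter is transformed into.
// For grouped convolutions the second entry is in_depth / group_count.
std::array<int64_t, 5> TransformedFilterShape(int64_t filter_out_depth,
                                              int64_t filter_in_depth,
                                              int64_t planes, int64_t rows,
                                              int64_t cols) {
  return {filter_out_depth, filter_in_depth, planes, rows, cols};
}

int main() {
  // 10 input channels in 2 groups, 4 output channels, 2x2x2 filter:
  // the buffer is [4, 5, 2, 2, 2] rather than [4, 10, 2, 2, 2].
  for (int64_t d : TransformedFilterShape(4, 5, 2, 2, 2)) std::cout << d << ' ';
  std::cout << "\n";
}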
@@ -1356,7 +1359,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
         device_id,
-    };
+        conv_desc.group_count()};
 
     using se::dnn::AlgorithmConfig;
     using se::dnn::AlgorithmDesc;
@@ -1608,11 +1611,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
-    if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
-        dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
-        dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
-        dims.stride(1) == 1 && dims.stride(0) == 1 &&
-        data_format_ == FORMAT_NHWC) {
+    bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth;
+    if (!is_grouped_convolution && dims.filter_size(1) == 1 &&
+        dims.filter_size(2) == 1 && dims.filter_size(0) == 1 &&
+        dims.dilation(2) == 1 && dims.dilation(1) == 1 &&
+        dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 &&
+        dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) {
       const uint64 m = dims.in_depth;
       const uint64 k = dims.batch_size * dims.input_size(1) *
                        dims.input_size(2) * dims.input_size(0);
@@ -1647,7 +1651,8 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
                                         ", n=", n, ", k=", k));
       }
       return;
-    } else if (dims.filter_size(0) == dims.input_size(0) &&
+    } else if (!is_grouped_convolution &&
+               dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
@@ -1727,8 +1732,8 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
     filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
         .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
         .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
-        .set_input_feature_map_count(dims.in_depth)
-        .set_output_feature_map_count(dims.out_depth);
+        .set_input_feature_map_count(filter_shape.dim_size(3))
+        .set_output_feature_map_count(filter_shape.dim_size(4));
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
         .set_dilation_rate(DimIndex::Y, dims.dilation(1))
@@ -1738,16 +1743,16 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         .set_filter_stride(DimIndex::Z, dims.stride(0))
         .set_zero_padding(DimIndex::X, padding_cols / 2)
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
-        .set_zero_padding(DimIndex::Z, padding_planes / 2);
-
+        .set_zero_padding(DimIndex::Z, padding_planes / 2)
+        .set_group_count(dims.in_depth / filter_shape.dim_size(3));
     Tensor pre_transformed_filter_backprop;
     OP_REQUIRES_OK(
-        context,
-        context->allocate_temp(
-            DataTypeToEnum<T>::value,
-            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
-                         dims.filter_size(1), dims.filter_size(2)}),
-            &pre_transformed_filter_backprop));
+        context, context->allocate_temp(
+                     DataTypeToEnum<T>::value,
+                     TensorShape({filter_shape.dim_size(4),
+                                  filter_shape.dim_size(3), dims.filter_size(0),
+                                  dims.filter_size(1), dims.filter_size(2)}),
+                     &pre_transformed_filter_backprop));
 
     Tensor transformed_out_backprop;
     if (data_format_ == FORMAT_NHWC) {
@@ -1814,7 +1819,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
         device_id,
-    };
+        conv_desc.group_count()};
 
     using se::dnn::AlgorithmConfig;
     using se::dnn::AlgorithmDesc;
@@ -920,25 +920,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
 
   int device_id = stream->parent()->device_ordinal();
   DataType dtype = input.dtype();
-  ConvParameters conv_parameters = {
-      in_batch,             // batch
-      in_depths,            // in_depths
-      {{in_rows,            // in_rows
-        in_cols}},          // in_cols
-      compute_data_format,  // compute_data_format
-      out_depths,           // out_depths
-      {{patch_rows,         // filter_rows
-        patch_cols,         // filter_cols
-        patch_depths}},     // filter_depths
-      {{row_dilation,       // dilation_rows
-        col_dilation}},     // dilation_cols
-      {{row_stride,         // stride_rows
-        col_stride}},       // stride_cols
-      {{common_padding_rows,    // padding_rows
-        common_padding_cols}},  // padding_cols
-      dtype,                    // tensor datatype
-      device_id,                // device_id
-  };
+  ConvParameters conv_parameters = {in_batch,             // batch
+                                    in_depths,            // in_depths
+                                    {{in_rows,            // in_rows
+                                      in_cols}},          // in_cols
+                                    compute_data_format,  // compute_data_format
+                                    out_depths,           // out_depths
+                                    {{patch_rows,         // filter_rows
+                                      patch_cols,         // filter_cols
+                                      patch_depths}},     // filter_depths
+                                    {{row_dilation,       // dilation_rows
+                                      col_dilation}},     // dilation_cols
+                                    {{row_stride,         // stride_rows
+                                      col_stride}},       // stride_cols
+                                    {{common_padding_rows,    // padding_rows
+                                      common_padding_cols}},  // padding_cols
+                                    dtype,                    // tensor datatype
+                                    device_id,                // device_id
+                                    conv_desc.group_count()};
   AlgorithmConfig algorithm_config;
   if (cudnn_use_autotune &&
       !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
@@ -137,10 +137,13 @@ class Conv3DOp : public BinaryOp<T> {
     const int64 in_depth = GetTensorDim(input, data_format_, 'C');
     const int64 in_batch = GetTensorDim(input, data_format_, 'N');
 
+    const int64 filter_depth = filter.dim_size(3);
     const int64 out_depth = filter.dim_size(4);
-    OP_REQUIRES(
-        context, in_depth == filter.dim_size(3),
-        errors::InvalidArgument("input and filter must have the same depth"));
+
+    OP_REQUIRES(context, in_depth % filter_depth == 0,
+                errors::InvalidArgument(
+                    "Input depth must be evenly divisible by filter depth: ",
+                    in_depth, " vs ", filter_depth));
 
     // Dimension order for these arrays is: z, y, x.
     std::array<int64, 3> input_size = {
@@ -223,6 +226,7 @@ struct LaunchConvOp<GPUDevice, T> {
     const int64 filter_planes = filter.dim_size(0);
     const int64 filter_rows = filter.dim_size(1);
     const int64 filter_cols = filter.dim_size(2);
+    const int64 filter_depth = filter.dim_size(3);
     const int64 out_depth = filter.dim_size(4);
 
     int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
@@ -239,11 +243,13 @@ struct LaunchConvOp<GPUDevice, T> {
           0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
     }
 
+    bool is_grouped_convolution = filter_depth != in_depth;
+
     // NOTE: This only works in NHWC.
-    if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 &&
-        dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1 &&
-        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
-        data_format == FORMAT_NHWC) {
+    if (!is_grouped_convolution && filter_planes == 1 && filter_rows == 1 &&
+        filter_cols == 1 && dilations[0] == 1 && dilations[1] == 1 &&
+        dilations[2] == 1 && strides[0] == 1 && strides[1] == 1 &&
+        strides[2] == 1 && data_format == FORMAT_NHWC) {
       // 1x1 filter, so call cublas directly.
       const uint64 m = in_batch * in_planes * in_rows * in_cols;
       const uint64 k = in_depth;
@@ -267,9 +273,9 @@ struct LaunchConvOp<GPUDevice, T> {
                                         ", n=", n, ", k=", k));
       }
       return;
-    } else if (filter_planes == in_planes && filter_rows == in_rows &&
-               filter_cols == in_cols && padding == Padding::VALID &&
-               data_format == FORMAT_NHWC) {
+    } else if (!is_grouped_convolution && filter_planes == in_planes &&
+               filter_rows == in_rows && filter_cols == in_cols &&
+               padding == Padding::VALID && data_format == FORMAT_NHWC) {
       // The input data and filter have the same planes/height/width, so call
       // cublas directly.
       const uint64 m = in_batch;
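Why the new !is_grouped_convolution guards: both fast paths above rewrite the convolution as a single cuBLAS GEMM, which is only valid when every output channel may read every input channel. A hedged sketch of the 1x1x1 case and its GEMM dimensions (illustrative names, mirroring the m, k, n computed in the kernel):

#include <cstdint>
#include <iostream>

// Dimensions of the single GEMM that implements a 1x1x1, stride-1, NDHWC
// convolution: [batch * planes * rows * cols, in_depth] x [in_depth, out_depth].
struct GemmDims {
  uint64_t m, k, n;
};

GemmDims OneByOneConvAsGemm(uint64_t batch, uint64_t planes, uint64_t rows,
                            uint64_t cols, uint64_t in_depth,
                            uint64_t out_depth) {
  return {batch * planes * rows * cols, in_depth, out_depth};
}

int main() {
  GemmDims d = OneByOneConvAsGemm(1, 4, 4, 4, 10, 4);
  std::cout << "m=" << d.m << " k=" << d.k << " n=" << d.n << "\n";
  // A grouped version would instead need one [m, in_depth/G] x [in_depth/G,
  // out_depth/G] GEMM per group (block-diagonal weights), so grouped
  // convolutions fall through to the cuDNN path instead.
}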
@@ -370,7 +376,7 @@ struct LaunchConvOp<GPUDevice, T> {
     filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
         .set_spatial_dim(DimIndex::Y, filter_rows)
         .set_spatial_dim(DimIndex::Z, filter_planes)
-        .set_input_feature_map_count(in_depth)
+        .set_input_feature_map_count(filter_depth)
         .set_output_feature_map_count(out_depth);
     se::dnn::ConvolutionDescriptor conv_desc(3);
     conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
@@ -381,7 +387,8 @@ struct LaunchConvOp<GPUDevice, T> {
         .set_filter_stride(DimIndex::Z, strides[0])
         .set_zero_padding(DimIndex::X, pad_cols / 2)
         .set_zero_padding(DimIndex::Y, pad_rows / 2)
-        .set_zero_padding(DimIndex::Z, pad_planes / 2);
+        .set_zero_padding(DimIndex::Z, pad_planes / 2)
+        .set_group_count(in_depth / filter_depth);
 
     Tensor transformed_filter;
     OP_REQUIRES_OK(
@@ -430,7 +437,7 @@ struct LaunchConvOp<GPUDevice, T> {
         {{pad_planes, pad_rows, pad_cols}},
         dtype,
         device_id,
-    };
+        conv_desc.group_count()};
 
     using se::dnn::AlgorithmConfig;
     using se::dnn::AlgorithmDesc;
@@ -586,25 +586,24 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
   int device_id = stream->parent()->device_ordinal();
   DataType dtype = input.dtype();
   FusedConvParameters conv_parameters = {
-      {
-          in_batch,                      // batch
-          in_depths,                     // in_depths
-          {{in_rows,                     // in_rows
-            in_cols}},                   // in_cols
-          FORMAT_NCHW,                   // compute_data_format
-          out_depths,                    // out_depths
-          {{patch_rows,                  // filter_rows
-            patch_cols,                  // filter_cols
-            patch_depths}},              // filter_depths
-          {{dimensions.dilation_rows,    // dilation_rows
-            dimensions.dilation_cols}},  // dilation_cols
-          {{dimensions.stride_rows,      // stride_rows
-            dimensions.stride_cols}},    // stride_cols
-          {{common_padding_rows,         // padding_rows
-            common_padding_cols}},       // padding_cols
-          dtype,                         // tensor datatype
-          device_id,                     // device_id
-      },
+      {in_batch,                      // batch
+       in_depths,                     // in_depths
+       {{in_rows,                     // in_rows
+         in_cols}},                   // in_cols
+       FORMAT_NCHW,                   // compute_data_format
+       out_depths,                    // out_depths
+       {{patch_rows,                  // filter_rows
+         patch_cols,                  // filter_cols
+         patch_depths}},              // filter_depths
+       {{dimensions.dilation_rows,    // dilation_rows
+         dimensions.dilation_cols}},  // dilation_cols
+       {{dimensions.stride_rows,      // stride_rows
+         dimensions.stride_cols}},    // stride_cols
+       {{common_padding_rows,         // padding_rows
+         common_padding_cols}},       // padding_cols
+       dtype,                         // tensor datatype
+       device_id,                     // device_id
+       conv_desc.group_count()},
       dnn_activation_mode  // activation_mode
   };
 
@@ -96,7 +96,7 @@ class ConvParameters {
                  TensorFormat data_format, int64 out_depths,
                  const SpatialArray& filter, const SpatialArray& dilation,
                  const SpatialArray& stride, const SpatialArray& padding,
-                 DataType dtype, int device_id)
+                 DataType dtype, int device_id, int group_count = 1)
       : batch_(batch),
         in_depths_(in_depths),
         out_depths_(out_depths),
@@ -107,7 +107,8 @@ class ConvParameters {
         stride_(CheckSpatialArraySize(stride)),
         padding_(CheckSpatialArraySize(padding)),
         dtype_(dtype),
-        device_id_(device_id) {
+        device_id_(device_id),
+        group_count_(group_count) {
     hash_code_ = batch;
     hash_code_ = Hash64Combine(hash_code_, in_depths);
     for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
@@ -119,7 +120,9 @@ class ConvParameters {
     for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
     hash_code_ = Hash64Combine(hash_code_, dtype);
     hash_code_ = Hash64Combine(hash_code_, device_id);
+    hash_code_ = Hash64Combine(hash_code_, group_count);
   }
 
   bool operator==(const ConvParameters& other) const {
     return this->get_data_as_tuple() == other.get_data_as_tuple();
   }
@@ -141,7 +144,8 @@ class ConvParameters {
         "(", str_util::Join(stride_, ", "), "), ",
         "(", str_util::Join(padding_, ", "), "), ",
         dtype_, ", ",
-        device_id_);
+        device_id_,
+        group_count_);
     // clang-format on
   }
 
@@ -165,12 +169,12 @@ class ConvParameters {
  protected:
   using ParameterDataType =
       std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray,
-                 SpatialArray, SpatialArray, SpatialArray, DataType, int>;
+                 SpatialArray, SpatialArray, SpatialArray, DataType, int, int>;
 
   ParameterDataType get_data_as_tuple() const {
     return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_,
                            filter_, dilation_, stride_, padding_, dtype_,
-                           device_id_);
+                           device_id_, group_count_);
   }
 
   uint64 hash_code_;
@@ -207,6 +211,7 @@ class ConvParameters {
   SpatialArray padding_;
   DataType dtype_;
   int device_id_;
+  int group_count_;
 };
 
 typedef Eigen::GpuDevice GPUDevice;
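Finally, on the ConvParameters changes: group_count is now part of both the hash and the comparison tuple, so a grouped convolution can never reuse an autotuned algorithm cached for an otherwise identical ungrouped one. A simplified, self-contained sketch of that cache-key pattern (illustrative only, not the real ConvParameters class):

#include <cstdint>
#include <iostream>
#include <tuple>
#include <unordered_map>

// Simplified stand-in for the ConvParameters autotune key: the group count is
// part of the tuple used for equality and of the hash.
struct MiniConvKey {
  int64_t batch, in_depth, out_depth;
  int group_count;

  auto AsTuple() const {
    return std::make_tuple(batch, in_depth, out_depth, group_count);
  }
  bool operator==(const MiniConvKey& o) const { return AsTuple() == o.AsTuple(); }
};

struct MiniConvKeyHash {
  size_t operator()(const MiniConvKey& k) const {
    size_t h = std::hash<int64_t>()(k.batch);
    h = h * 1000003 ^ std::hash<int64_t>()(k.in_depth);
    h = h * 1000003 ^ std::hash<int64_t>()(k.out_depth);
    h = h * 1000003 ^ std::hash<int>()(k.group_count);  // mirrors the new Hash64Combine call
    return h;
  }
};

int main() {
  std::unordered_map<MiniConvKey, int, MiniConvKeyHash> best_algorithm;
  best_algorithm[{1, 10, 4, /*group_count=*/1}] = 7;  // plain convolution
  best_algorithm[{1, 10, 4, /*group_count=*/2}] = 3;  // grouped convolution
  std::cout << best_algorithm.size() << " distinct cache entries\n";  // prints 2
}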