Added group_count parameter to autotune descriptor

Some cuDNN algorithms only support a certain number of groups. However,
if the group_count is not taken into account when trying to auto_tune
for the best cuDNN algorithm then a grouped convolution might receive a
"cached" algorithm which does not support the amount of groups it
needs.
This commit is contained in:
Dan Ganea 2019-08-09 20:20:58 +02:00
parent 75c2f170c4
commit 65849ef4e0
7 changed files with 17 additions and 5 deletions

View File

@ -965,6 +965,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
common_padding_cols}}, // padding_cols common_padding_cols}}, // padding_cols
dtype, // tensor datatype dtype, // tensor datatype
device_id, // device_id device_id, // device_id
conv_desc.group_count() // group_count
}; };
AlgorithmConfig algorithm_config; AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find( if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(

View File

@ -1091,6 +1091,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
common_padding_cols}}, // padding_cols common_padding_cols}}, // padding_cols
dtype, // tensor data type dtype, // tensor data type
device_id, // device_id device_id, // device_id
conv_desc.group_count() // group_count
}; };
AlgorithmConfig algorithm_config; AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find( if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(

View File

@ -1352,6 +1352,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
{{padding_planes, padding_rows, padding_cols}}, {{padding_planes, padding_rows, padding_cols}},
dtype, dtype,
device_id, device_id,
conv_desc.group_count()
}; };
using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmConfig;
@ -1777,6 +1778,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
{{padding_planes, padding_rows, padding_cols}}, {{padding_planes, padding_rows, padding_cols}},
dtype, dtype,
device_id, device_id,
conv_desc.group_count()
}; };
using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmConfig;

View File

@ -984,6 +984,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
common_padding_cols}}, // padding_cols common_padding_cols}}, // padding_cols
dtype, // tensor datatype dtype, // tensor datatype
device_id, // device_id device_id, // device_id
conv_desc.group_count()
}; };
AlgorithmConfig algorithm_config; AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && if (cudnn_use_autotune &&

View File

@ -425,6 +425,7 @@ struct LaunchConvOp<GPUDevice, T> {
{{pad_planes, pad_rows, pad_cols}}, {{pad_planes, pad_rows, pad_cols}},
dtype, dtype,
device_id, device_id,
conv_desc.group_count()
}; };
using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmConfig;

View File

@ -581,6 +581,7 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
common_padding_cols}}, // padding_cols common_padding_cols}}, // padding_cols
dtype, // tensor datatype dtype, // tensor datatype
device_id, // device_id device_id, // device_id
conv_desc.group_count()
}, },
dnn_activation_mode // activation_mode dnn_activation_mode // activation_mode
}; };

View File

@ -97,7 +97,7 @@ class ConvParameters {
TensorFormat data_format, int64 out_depths, TensorFormat data_format, int64 out_depths,
const SpatialArray& filter, const SpatialArray& dilation, const SpatialArray& filter, const SpatialArray& dilation,
const SpatialArray& stride, const SpatialArray& padding, const SpatialArray& stride, const SpatialArray& padding,
DataType dtype, int device_id) DataType dtype, int device_id, int group_count = 1)
: batch_(batch), : batch_(batch),
in_depths_(in_depths), in_depths_(in_depths),
out_depths_(out_depths), out_depths_(out_depths),
@ -108,7 +108,8 @@ class ConvParameters {
stride_(CheckSpatialArraySize(stride)), stride_(CheckSpatialArraySize(stride)),
padding_(CheckSpatialArraySize(padding)), padding_(CheckSpatialArraySize(padding)),
dtype_(dtype), dtype_(dtype),
device_id_(device_id) { device_id_(device_id),
group_count_(group_count) {
hash_code_ = batch; hash_code_ = batch;
hash_code_ = Hash64Combine(hash_code_, in_depths); hash_code_ = Hash64Combine(hash_code_, in_depths);
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val); for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
@ -120,7 +121,9 @@ class ConvParameters {
for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val); for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
hash_code_ = Hash64Combine(hash_code_, dtype); hash_code_ = Hash64Combine(hash_code_, dtype);
hash_code_ = Hash64Combine(hash_code_, device_id); hash_code_ = Hash64Combine(hash_code_, device_id);
hash_code_ = Hash64Combine(hash_code_, group_count);
} }
bool operator==(const ConvParameters& other) const { bool operator==(const ConvParameters& other) const {
return this->get_data_as_tuple() == other.get_data_as_tuple(); return this->get_data_as_tuple() == other.get_data_as_tuple();
} }
@ -142,7 +145,8 @@ class ConvParameters {
"(", str_util::Join(stride_, ", "), "), ", "(", str_util::Join(stride_, ", "), "), ",
"(", str_util::Join(padding_, ", "), "), ", "(", str_util::Join(padding_, ", "), "), ",
dtype_, ", ", dtype_, ", ",
device_id_); device_id_,
group_count_);
// clang-format on // clang-format on
} }
@ -166,12 +170,12 @@ class ConvParameters {
protected: protected:
using ParameterDataType = using ParameterDataType =
std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray, std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray,
SpatialArray, SpatialArray, SpatialArray, DataType, int>; SpatialArray, SpatialArray, SpatialArray, DataType, int, int>;
ParameterDataType get_data_as_tuple() const { ParameterDataType get_data_as_tuple() const {
return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_, return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_,
filter_, dilation_, stride_, padding_, dtype_, filter_, dilation_, stride_, padding_, dtype_,
device_id_); device_id_, group_count_);
} }
uint64 hash_code_; uint64 hash_code_;
@ -208,6 +212,7 @@ class ConvParameters {
SpatialArray padding_; SpatialArray padding_;
DataType dtype_; DataType dtype_;
int device_id_; int device_id_;
int group_count_;
}; };
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;