Added group_count parameter to autotune descriptor
Some cuDNN algorithms only support certain group counts. However, if group_count is not taken into account when autotuning for the best cuDNN algorithm, a grouped convolution might receive a "cached" algorithm that does not support the number of groups it needs.
This commit is contained in:
parent
75c2f170c4
commit
65849ef4e0
@ -965,6 +965,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
|
|||||||
common_padding_cols}}, // padding_cols
|
common_padding_cols}}, // padding_cols
|
||||||
dtype, // tensor datatype
|
dtype, // tensor datatype
|
||||||
device_id, // device_id
|
device_id, // device_id
|
||||||
|
conv_desc.group_count() // group_count
|
||||||
};
|
};
|
||||||
AlgorithmConfig algorithm_config;
|
AlgorithmConfig algorithm_config;
|
||||||
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
|
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
|
||||||
|
@ -1091,6 +1091,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
|||||||
common_padding_cols}}, // padding_cols
|
common_padding_cols}}, // padding_cols
|
||||||
dtype, // tensor data type
|
dtype, // tensor data type
|
||||||
device_id, // device_id
|
device_id, // device_id
|
||||||
|
conv_desc.group_count() // group_count
|
||||||
};
|
};
|
||||||
AlgorithmConfig algorithm_config;
|
AlgorithmConfig algorithm_config;
|
||||||
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
|
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
|
||||||
|
@ -1352,6 +1352,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
|
|||||||
{{padding_planes, padding_rows, padding_cols}},
|
{{padding_planes, padding_rows, padding_cols}},
|
||||||
dtype,
|
dtype,
|
||||||
device_id,
|
device_id,
|
||||||
|
conv_desc.group_count()
|
||||||
};
|
};
|
||||||
|
|
||||||
using se::dnn::AlgorithmConfig;
|
using se::dnn::AlgorithmConfig;
|
||||||
@ -1777,6 +1778,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
|
|||||||
{{padding_planes, padding_rows, padding_cols}},
|
{{padding_planes, padding_rows, padding_cols}},
|
||||||
dtype,
|
dtype,
|
||||||
device_id,
|
device_id,
|
||||||
|
conv_desc.group_count()
|
||||||
};
|
};
|
||||||
|
|
||||||
using se::dnn::AlgorithmConfig;
|
using se::dnn::AlgorithmConfig;
|
||||||
|
@ -984,6 +984,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
common_padding_cols}}, // padding_cols
|
common_padding_cols}}, // padding_cols
|
||||||
dtype, // tensor datatype
|
dtype, // tensor datatype
|
||||||
device_id, // device_id
|
device_id, // device_id
|
||||||
|
conv_desc.group_count()
|
||||||
};
|
};
|
||||||
AlgorithmConfig algorithm_config;
|
AlgorithmConfig algorithm_config;
|
||||||
if (cudnn_use_autotune &&
|
if (cudnn_use_autotune &&
|
||||||
|
@ -425,6 +425,7 @@ struct LaunchConvOp<GPUDevice, T> {
|
|||||||
{{pad_planes, pad_rows, pad_cols}},
|
{{pad_planes, pad_rows, pad_cols}},
|
||||||
dtype,
|
dtype,
|
||||||
device_id,
|
device_id,
|
||||||
|
conv_desc.group_count()
|
||||||
};
|
};
|
||||||
|
|
||||||
using se::dnn::AlgorithmConfig;
|
using se::dnn::AlgorithmConfig;
|
||||||
|
@ -581,6 +581,7 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
|
|||||||
common_padding_cols}}, // padding_cols
|
common_padding_cols}}, // padding_cols
|
||||||
dtype, // tensor datatype
|
dtype, // tensor datatype
|
||||||
device_id, // device_id
|
device_id, // device_id
|
||||||
|
conv_desc.group_count()
|
||||||
},
|
},
|
||||||
dnn_activation_mode // activation_mode
|
dnn_activation_mode // activation_mode
|
||||||
};
|
};
|
||||||
|
@ -97,7 +97,7 @@ class ConvParameters {
|
|||||||
TensorFormat data_format, int64 out_depths,
|
TensorFormat data_format, int64 out_depths,
|
||||||
const SpatialArray& filter, const SpatialArray& dilation,
|
const SpatialArray& filter, const SpatialArray& dilation,
|
||||||
const SpatialArray& stride, const SpatialArray& padding,
|
const SpatialArray& stride, const SpatialArray& padding,
|
||||||
DataType dtype, int device_id)
|
DataType dtype, int device_id, int group_count = 1)
|
||||||
: batch_(batch),
|
: batch_(batch),
|
||||||
in_depths_(in_depths),
|
in_depths_(in_depths),
|
||||||
out_depths_(out_depths),
|
out_depths_(out_depths),
|
||||||
@ -108,7 +108,8 @@ class ConvParameters {
|
|||||||
stride_(CheckSpatialArraySize(stride)),
|
stride_(CheckSpatialArraySize(stride)),
|
||||||
padding_(CheckSpatialArraySize(padding)),
|
padding_(CheckSpatialArraySize(padding)),
|
||||||
dtype_(dtype),
|
dtype_(dtype),
|
||||||
device_id_(device_id) {
|
device_id_(device_id),
|
||||||
|
group_count_(group_count) {
|
||||||
hash_code_ = batch;
|
hash_code_ = batch;
|
||||||
hash_code_ = Hash64Combine(hash_code_, in_depths);
|
hash_code_ = Hash64Combine(hash_code_, in_depths);
|
||||||
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
|
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
|
||||||
@ -120,7 +121,9 @@ class ConvParameters {
|
|||||||
for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
|
for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
|
||||||
hash_code_ = Hash64Combine(hash_code_, dtype);
|
hash_code_ = Hash64Combine(hash_code_, dtype);
|
||||||
hash_code_ = Hash64Combine(hash_code_, device_id);
|
hash_code_ = Hash64Combine(hash_code_, device_id);
|
||||||
|
hash_code_ = Hash64Combine(hash_code_, group_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const ConvParameters& other) const {
|
bool operator==(const ConvParameters& other) const {
|
||||||
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
||||||
}
|
}
|
||||||
@ -142,7 +145,8 @@ class ConvParameters {
|
|||||||
"(", str_util::Join(stride_, ", "), "), ",
|
"(", str_util::Join(stride_, ", "), "), ",
|
||||||
"(", str_util::Join(padding_, ", "), "), ",
|
"(", str_util::Join(padding_, ", "), "), ",
|
||||||
dtype_, ", ",
|
dtype_, ", ",
|
||||||
device_id_);
|
device_id_,
|
||||||
|
group_count_);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -166,12 +170,12 @@ class ConvParameters {
|
|||||||
protected:
|
protected:
|
||||||
using ParameterDataType =
|
using ParameterDataType =
|
||||||
std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray,
|
std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray,
|
||||||
SpatialArray, SpatialArray, SpatialArray, DataType, int>;
|
SpatialArray, SpatialArray, SpatialArray, DataType, int, int>;
|
||||||
|
|
||||||
ParameterDataType get_data_as_tuple() const {
|
ParameterDataType get_data_as_tuple() const {
|
||||||
return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_,
|
return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_,
|
||||||
filter_, dilation_, stride_, padding_, dtype_,
|
filter_, dilation_, stride_, padding_, dtype_,
|
||||||
device_id_);
|
device_id_, group_count_);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64 hash_code_;
|
uint64 hash_code_;
|
||||||
@ -208,6 +212,7 @@ class ConvParameters {
|
|||||||
SpatialArray padding_;
|
SpatialArray padding_;
|
||||||
DataType dtype_;
|
DataType dtype_;
|
||||||
int device_id_;
|
int device_id_;
|
||||||
|
int group_count_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user