Add dilation rates support for ConvolutionDescriptor...
...in stream executor. In preparation for the support of native cudnn dilated convolution. PiperOrigin-RevId: 171165137
This commit is contained in:
parent
8dc5e3718b
commit
b0e751a73d
@ -583,6 +583,7 @@ class ScopedConvolutionDescriptor {
|
||||
}
|
||||
const auto& strides64 = convolution_descriptor.strides();
|
||||
const auto& padding64 = convolution_descriptor.padding();
|
||||
const auto& dilations64 = convolution_descriptor.dilations();
|
||||
if (convolution_descriptor.pad_alignment() ==
|
||||
dnn::PadAlignment::kTensorFlowPadding) {
|
||||
LOG(ERROR) << "TensorFlow padding alignment is not supported.";
|
||||
@ -591,15 +592,19 @@ class ScopedConvolutionDescriptor {
|
||||
// cuDNN requires arrays of ints.
|
||||
std::vector<int> strides(convolution_descriptor.ndims());
|
||||
std::vector<int> padding(convolution_descriptor.ndims());
|
||||
std::vector<int> dilations(convolution_descriptor.ndims());
|
||||
std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
|
||||
&CheckedNarrowing<int64, int>);
|
||||
std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
|
||||
&CheckedNarrowing<int64, int>);
|
||||
std::vector<int> upscale(convolution_descriptor.ndims(), 1);
|
||||
// TODO(yangzihao): Test with negative dilation to make sure that cudnn
|
||||
// doesn't crash.
|
||||
std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
|
||||
&CheckedNarrowing<int64, int>);
|
||||
|
||||
status = wrap::cudnnSetConvolutionNdDescriptor(
|
||||
parent_, handle_, convolution_descriptor.ndims(), padding.data(),
|
||||
strides.data(), upscale.data(),
|
||||
strides.data(), dilations.data(),
|
||||
// NOTE(keveman): cuDNN supports convolution and cross correlation.
|
||||
// However, almost all the use cases do cross correlation, so just
|
||||
// hard coding it here.
|
||||
@ -2982,7 +2987,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
|
||||
if (memory_limit_bytes < 0) {
|
||||
memory_limit_bytes = 0;
|
||||
}
|
||||
|
||||
cudnnConvolutionBwdDataAlgo_t algo_to_use;
|
||||
cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardDataAlgorithm(
|
||||
parent_, ToHandle(dnn_handle_),
|
||||
@ -2995,7 +2999,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
|
||||
/*algo=*/&algo_to_use);
|
||||
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
|
||||
"algorithm for doing backward "
|
||||
"filter convolution";
|
||||
"data convolution";
|
||||
return algo_to_use;
|
||||
};
|
||||
|
||||
|
@ -424,6 +424,7 @@ int64 FilterDescriptor::ComputeWeightCount() const {
|
||||
ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
|
||||
: zero_padding_(ndims, 0),
|
||||
filter_strides_(ndims, 1),
|
||||
dilation_rates_(ndims, 1),
|
||||
pad_alignment_(PadAlignment::kDefault),
|
||||
ndims_(ndims) {}
|
||||
|
||||
@ -435,15 +436,18 @@ ConvolutionDescriptor::~ConvolutionDescriptor() {}
|
||||
string ConvolutionDescriptor::ToString() const {
|
||||
string padding;
|
||||
string strides;
|
||||
string dilations;
|
||||
for (int i = 0; i < ndims_; i++) {
|
||||
port::Appendf(&padding, "%lld ", zero_padding_[i]);
|
||||
port::Appendf(&strides, "%lld ", filter_strides_[i]);
|
||||
port::Appendf(&dilations, "%lld ", dilation_rates_[i]);
|
||||
}
|
||||
|
||||
return port::Printf("{zero_padding: %s pad_alignment: %s filter_strides: %s}",
|
||||
padding.c_str(),
|
||||
PadAlignmentString(pad_alignment_).c_str(),
|
||||
strides.c_str());
|
||||
return port::Printf(
|
||||
"{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
|
||||
"%s}",
|
||||
padding.c_str(), PadAlignmentString(pad_alignment_).c_str(),
|
||||
strides.c_str(), dilations.c_str());
|
||||
}
|
||||
|
||||
string ConvolutionDescriptor::ToShortString() const {
|
||||
@ -455,6 +459,9 @@ string ConvolutionDescriptor::ToShortString() const {
|
||||
for (int i = 0; i < ndims_; i++) {
|
||||
port::Appendf(&desc, "_s%d:%lld", i, filter_strides_[i]);
|
||||
}
|
||||
for (int i = 0; i < ndims_; i++) {
|
||||
port::Appendf(&desc, "_d%d:%lld", i, dilation_rates_[i]);
|
||||
}
|
||||
return desc;
|
||||
}
|
||||
|
||||
|
@ -487,6 +487,10 @@ string PadAlignmentString(PadAlignment alignment);
|
||||
// window is moved in the "y dimension" according to this stride value.
|
||||
// - horizontal_filter_stride: analogous to the vertical stride above, but in
|
||||
// the "x dimension".
|
||||
// - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
|
||||
// cells between each filter element in the "y dimension".
|
||||
// - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
|
||||
// skipped cells between each filter element in the "x dimension".
|
||||
class ConvolutionDescriptor {
|
||||
public:
|
||||
// By default construction, there is no zero-padding and the filter stride is
|
||||
@ -523,6 +527,18 @@ class ConvolutionDescriptor {
|
||||
SetDim(&filter_strides_, dim, value);
|
||||
return *this;
|
||||
}
|
||||
ConvolutionDescriptor& set_vertical_dilation_rate(int64 value) {
|
||||
SetDim(&dilation_rates_, DimIndex::Y, value);
|
||||
return *this;
|
||||
}
|
||||
ConvolutionDescriptor& set_horizontal_dilation_rate(int64 value) {
|
||||
SetDim(&dilation_rates_, DimIndex::X, value);
|
||||
return *this;
|
||||
}
|
||||
ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64 value) {
|
||||
SetDim(&dilation_rates_, dim, value);
|
||||
return *this;
|
||||
}
|
||||
ConvolutionDescriptor& set_pad_alignment(PadAlignment pad_alignment) {
|
||||
pad_alignment_ = pad_alignment;
|
||||
return *this;
|
||||
@ -539,19 +555,28 @@ class ConvolutionDescriptor {
|
||||
int64 horizontal_filter_stride() const {
|
||||
return GetDim(filter_strides_, DimIndex::X);
|
||||
}
|
||||
int64 vertical_dilation_rate() const {
|
||||
return GetDim(dilation_rates_, DimIndex::Y);
|
||||
}
|
||||
int64 horizontal_dilation_rate() const {
|
||||
return GetDim(dilation_rates_, DimIndex::X);
|
||||
}
|
||||
|
||||
int zero_padding(DimIndex dim) const { return GetDim(zero_padding_, dim); }
|
||||
int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
|
||||
int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
|
||||
PadAlignment pad_alignment() const { return pad_alignment_; }
|
||||
int ndims() const { return ndims_; }
|
||||
|
||||
std::vector<int64> strides() const { return filter_strides_; }
|
||||
std::vector<int64> dilations() const { return dilation_rates_; }
|
||||
std::vector<int64> padding() const { return zero_padding_; }
|
||||
|
||||
private:
|
||||
// Stored as: .. y, x.
|
||||
std::vector<int64> zero_padding_;
|
||||
std::vector<int64> filter_strides_;
|
||||
std::vector<int64> dilation_rates_;
|
||||
PadAlignment pad_alignment_;
|
||||
int ndims_;
|
||||
// TODO(leary) cudnn provides these fields, but need to characterize what
|
||||
|
Loading…
Reference in New Issue
Block a user