Adding support for non-fused Winograd algorithm from Cudnn 5.1.
Enabling this by default makes some unit tests to fail. So adding an env-var "TF_ENABLE_WINOGRAD_NONFUSED" so users can explicitly choose to enable. Change: 146763809
This commit is contained in:
parent
e62ef90649
commit
764507ea71
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/error.h"
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
#include "tensorflow/stream_executor/lib/strcat.h"
|
||||
#include "tensorflow/stream_executor/lib/stringpiece.h"
|
||||
#include "tensorflow/stream_executor/lib/threadpool.h"
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||
@ -259,6 +260,9 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) {
|
||||
case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
|
||||
#if CUDNN_VERSION >= 5000
|
||||
case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
|
||||
#endif
|
||||
#if CUDNN_VERSION >= 5100
|
||||
case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
|
||||
#endif
|
||||
return algo;
|
||||
default:
|
||||
@ -277,6 +281,9 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
|
||||
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
|
||||
#if CUDNN_VERSION >= 5000
|
||||
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
|
||||
#endif
|
||||
#if CUDNN_VERSION >= 5100
|
||||
case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
|
||||
#endif
|
||||
return algo;
|
||||
default:
|
||||
@ -295,6 +302,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
|
||||
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
|
||||
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
|
||||
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
|
||||
#if CUDNN_VERSION >= 5100
|
||||
// Based on cudnn.h, the following is not implemented.
|
||||
// case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
|
||||
case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
|
||||
#endif
|
||||
return algo;
|
||||
default:
|
||||
LOG(FATAL)
|
||||
@ -1952,6 +1964,33 @@ bool CudnnSupport::DoConvolveImpl(
|
||||
return true;
|
||||
}
|
||||
|
||||
// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
|
||||
// Doing so by default make a few TensorFlow test cases to fail. Users can
|
||||
// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1".
|
||||
// https://github.com/tensorflow/tensorflow/pull/4901
|
||||
class WinogradNonfused {
|
||||
public:
|
||||
static bool IsEnabled() {
|
||||
static bool is_enabled = IsEnabledImpl();
|
||||
return is_enabled;
|
||||
}
|
||||
|
||||
private:
|
||||
static bool IsEnabledImpl() {
|
||||
const char* tf_env_var_val = getenv("TF_ENABLE_WINOGRAD_NONFUSED");
|
||||
if (tf_env_var_val != nullptr) {
|
||||
port::StringPiece tf_env_var_val_str(tf_env_var_val);
|
||||
if (tf_env_var_val_str == "0") {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
// TODO(zhengxq): turn the default to True when the test failure is
|
||||
// resolved.
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
bool CudnnSupport::GetConvolveAlgorithms(
|
||||
std::vector<dnn::AlgorithmType>* out_algorithms) {
|
||||
out_algorithms->assign({
|
||||
@ -1967,6 +2006,11 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
||||
#endif
|
||||
// clang-format on
|
||||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
if (WinogradNonfused::IsEnabled()) {
|
||||
out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1983,6 +2027,12 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
||||
#endif
|
||||
// clang-format on
|
||||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
if (WinogradNonfused::IsEnabled()) {
|
||||
out_algorithms->push_back(
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1996,6 +2046,14 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||
// clang-format on
|
||||
});
|
||||
#if CUDNN_VERSION >= 5100
|
||||
if (WinogradNonfused::IsEnabled()) {
|
||||
out_algorithms->push_back(
|
||||
// Based on cudnn.h, the following is not implemented.
|
||||
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user