diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 809fb1c956b..bd8aa4bacbf 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/strcat.h" +#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -259,6 +260,9 @@ cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmType algorithm) { case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING: #if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD: +#endif +#if CUDNN_VERSION >= 5100 + case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED: #endif return algo; default: @@ -277,6 +281,9 @@ cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo( case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING: #if CUDNN_VERSION >= 5000 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD: +#endif +#if CUDNN_VERSION >= 5100 + case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED: #endif return algo; default: @@ -295,6 +302,11 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT: case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3: +#if CUDNN_VERSION >= 5100 + // Based on cudnn.h, the following is not implemented. + // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD: + case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED: +#endif return algo; default: LOG(FATAL) @@ -1952,6 +1964,33 @@ bool CudnnSupport::DoConvolveImpl( return true; } +// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms. +// Doing so by default makes a few TensorFlow test cases fail.
Users can +// opt in by setting env-var TF_ENABLE_WINOGRAD_NONFUSED to anything but "0". +// https://github.com/tensorflow/tensorflow/pull/4901 +class WinogradNonfused { + public: + static bool IsEnabled() { + static bool is_enabled = IsEnabledImpl(); + return is_enabled; + } + + private: + static bool IsEnabledImpl() { + const char* tf_env_var_val = getenv("TF_ENABLE_WINOGRAD_NONFUSED"); + if (tf_env_var_val != nullptr) { + port::StringPiece tf_env_var_val_str(tf_env_var_val); + if (tf_env_var_val_str == "0") { + return false; + } + return true; + } + // TODO(zhengxq): turn the default to True when the test failure is + // resolved. + return false; + } +}; + bool CudnnSupport::GetConvolveAlgorithms( std::vector* out_algorithms) { out_algorithms->assign({ @@ -1967,6 +2006,11 @@ bool CudnnSupport::GetConvolveAlgorithms( #endif // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; } @@ -1983,6 +2027,12 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( #endif // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back( + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; } @@ -1996,6 +2046,14 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, // clang-format on }); +#if CUDNN_VERSION >= 5100 + if (WinogradNonfused::IsEnabled()) { + out_algorithms->push_back( + // Based on cudnn.h, the following is not implemented. + // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED); + } +#endif return true; }