Use CUDNN_FMA_MATH to disable TF32 for RNNs.

This commit is contained in:
Nathan Luehr 2021-01-08 15:49:29 -06:00
parent b03f57f650
commit f3544e38c9

View File

@ -1123,8 +1123,13 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
"Algo requests disallowed tensor op evaluation.");
}
#if CUDNN_VERSION >= 8000
cudnnMathType_t math_type =
use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
#else
cudnnMathType_t math_type =
use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
#endif
#if CUDNN_VERSION >= 8000
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;