Use CUDNN_FMA_MATH to disable TF32 for RNNs.
This commit is contained in:
parent
b03f57f650
commit
f3544e38c9
@ -1123,8 +1123,13 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
"Algo requests disallowed tensor op evaluation.");
|
"Algo requests disallowed tensor op evaluation.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if CUDNN_VERSION >= 8000
|
||||||
|
cudnnMathType_t math_type =
|
||||||
|
use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
|
||||||
|
#else
|
||||||
cudnnMathType_t math_type =
|
cudnnMathType_t math_type =
|
||||||
use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
|
use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
|
||||||
|
#endif
|
||||||
|
|
||||||
#if CUDNN_VERSION >= 8000
|
#if CUDNN_VERSION >= 8000
|
||||||
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
|
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user