Merge pull request #29147 from houtoms:pr_fix_rnn_tensor_core_use
PiperOrigin-RevId: 250732517
This commit is contained in:
commit
40275ad3aa
@ -1087,15 +1087,23 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
// We can only reasonably expect the user to handle the subsequent failure
|
||||
// in profile mode, which is run with algorithms returned from
|
||||
// GetRnnAlgorithms() (which are non-default and explicitly set whether to
|
||||
// use tensor ops).
|
||||
if (RnnTensorOpMathEnabled() && algorithm_config.algorithm().has_value()) {
|
||||
cudnnMathType_t math_type =
|
||||
algorithm_config.algorithm()->tensor_ops_enabled()
|
||||
? CUDNN_TENSOR_OP_MATH
|
||||
: CUDNN_DEFAULT_MATH;
|
||||
// use tensor ops). CuDNN 7.2.1 fixed this issue.
|
||||
if (RnnTensorOpMathEnabled()) {
|
||||
cudnnMathType_t math_type;
|
||||
if (algorithm_config.algorithm().has_value()) {
|
||||
math_type = algorithm_config.algorithm()->tensor_ops_enabled()
|
||||
? CUDNN_TENSOR_OP_MATH
|
||||
: CUDNN_DEFAULT_MATH;
|
||||
} else {
|
||||
#if CUDNN_VERSION >= 7201
|
||||
math_type = CUDNN_TENSOR_OP_MATH;
|
||||
#else
|
||||
math_type = CUDNN_DEFAULT_MATH;
|
||||
#endif // CUDNN_VERSION >= 7201
|
||||
}
|
||||
CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
|
||||
}
|
||||
#endif
|
||||
#endif // CUDNN_VERSION >= 7000
|
||||
|
||||
return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
|
||||
num_layers, hidden_size, input_size, batch_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user