diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 28902b65722..902a48da1a5 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1091,12 +1091,55 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { if (use_projection) { unified_size = cell_size; } + + // Require explicit algorithm config to enable tensor cores. Some configs + // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against + // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799). + // We can only reasonably expect the user to handle the subsequent failure + // in profile mode, which is run with algorithms returned from + // GetRnnAlgorithms() (which are non-default and explicitly set whether to + // use tensor ops). CuDNN 7.2.1 fixed this issue. + // TODO(csigg): Minimal support cuDNN version is 7.3, clean up. + bool allow_tensor_ops = + data_type != CUDNN_DATA_FLOAT || tensorflow::tf32_execution_allowed(); + bool use_tensor_ops = + algorithm_config.algorithm().has_value() + ? algorithm_config.algorithm()->tensor_ops_enabled() + : allow_tensor_ops; + if (use_tensor_ops && !allow_tensor_ops) { + return port::Status(port::error::INVALID_ARGUMENT, + "Algo requests disallowed tensor op evaluation."); + } + +#if CUDNN_VERSION >= 8000 + cudnnMathType_t math_type = + use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH; + cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS; + uint32_t aux_flags = 0; + if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED; + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8( + /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode, + /*biasMode=*/bias_mode, /*dirMode=*/direction_mode, + /*inputMode=*/input_mode, + /*dataType=*/data_type, /*mathPrec=*/compute_type, + /*mathType=*/math_type, + /*inputSize=*/input_size, + /*hiddenSize=*/hidden_size, /*projSize=*/cell_size, + /*numLayers=*/num_layers, + /*dropoutDesc=*/dropout_desc.handle(), + /*auxFlags=*/aux_flags)); +#else + cudnnMathType_t math_type = + use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/unified_size, /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo, /*dataType=*/compute_type)); + CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); +#endif + if (use_projection) { RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers( cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), @@ -1132,35 +1175,6 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { cudnn, input_size, data_type, rnn_desc.get(), rnn_mode, direction_mode, num_layers)); - // Require explicit algorithm config to enable tensor cores. Some configs - // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against - // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799). - // We can only reasonably expect the user to handle the subsequent failure - // in profile mode, which is run with algorithms returned from - // GetRnnAlgorithms() (which are non-default and explicitly set whether to - // use tensor ops). CuDNN 7.2.1 fixed this issue - bool allow_tensor_ops = - data_type != CUDNN_DATA_FLOAT || tensorflow::tf32_execution_allowed(); - bool use_tensor_ops; - if (algorithm_config.algorithm().has_value()) { - use_tensor_ops = algorithm_config.algorithm()->tensor_ops_enabled(); - } else { - use_tensor_ops = allow_tensor_ops; - } - - if (use_tensor_ops && !allow_tensor_ops) { - return port::Status(port::error::INVALID_ARGUMENT, - "Algo requests disallowed tensor op evaluation."); - } - - cudnnMathType_t math_type; -#if CUDNN_VERSION >= 8000 - math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH; -#else - math_type = use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; -#endif - CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); - return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, direction_mode, rnn_mode,