From 5fbb1feecd77a70b32d333b56bd13b1798b9a766 Mon Sep 17 00:00:00 2001
From: James Qin
Date: Fri, 20 Apr 2018 11:23:29 -0700
Subject: [PATCH] Temporarily set cudnn Rnn math precision to fp32.

Problem:
cudnnGetRNNLinLayerMatrixParams() returns CUDNN_STATUS_BAD_PARAM if:
* RNN descriptor math precision = CUDNN_DATA_FLOAT
* input descriptor dataType = CUDNN_DATA_HALF
* weight descriptor dataType = CUDNN_DATA_HALF

If the RNN descriptor math precision is updated to CUDNN_DATA_HALF, no error
is raised.

cudnn 7.1.4 will fix the problem.

PiperOrigin-RevId: 193696566
---
 tensorflow/stream_executor/cuda/cuda_dnn.cc | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index d673e19007d..640f270323c 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -2529,12 +2529,20 @@ cudnnDataType_t GetConvComputeType() {
 }
 
 // A helper struct to decide whether to use FP32 as the internal compute type
-// for rnn when the input data type is FP16. By default it is turned on,
-// users can explicitly disable them (choose to use FP16 as the internal compute
-// type) through an env-var "TF_FP16_RNN_USE_FP32_COMPUTE=0".
+// for rnn when the input data type is FP16. At present it is turned off,
+// users can explicitly control them through an env-var
+// TF_FP16_RNN_USE_FP32_COMPUTE.
+// After the TODO below is fixed, users should almost always use fp32 compute
+// type for training. Using fp16 might suffer suboptimal accuracy due to loss
+// in precision.
 struct RnnDoFP32ComputationFP16Input {
   static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
-  static constexpr bool kDefaultFlag = true;
+  // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
+  // Before cudnn 7.1.4 RNN are always done in fp32, no matter what math
+  // precision is set.
+  // Set it temporary to false s.t. no error is raised when using fp16 inputs,
+  // fp32 math precision.
+  static constexpr bool kDefaultFlag = false;
 };
 
 // A helper function to return the internal compute type for
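
For context, below is a minimal, self-contained C++17 sketch of how a struct like
RnnDoFP32ComputationFP16Input can pair a compiled-in default with an env-var
override. This is not the actual TensorFlow helper (the hunk's trailing
"A helper function to return the internal compute type for" comment is
truncated here); ReadBoolFlagOrDefault() is a hypothetical name, and only the
kName/kDefaultFlag layout and the TF_FP16_RNN_USE_FP32_COMPUTE variable come
from the patch above.

// Sketch only: env-var override of a compile-time default flag.
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct RnnDoFP32ComputationFP16Input {
  static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
  // false for now: before cudnn 7.1.4, requesting fp32 math precision with
  // fp16 inputs triggers CUDNN_STATUS_BAD_PARAM (see commit message above).
  static constexpr bool kDefaultFlag = false;
};

// Returns the flag's value from the environment if set, otherwise the
// compiled-in default. Any value other than "0" is treated as true.
template <typename Flag>
bool ReadBoolFlagOrDefault() {
  const char* env = std::getenv(Flag::kName);
  if (env == nullptr) return Flag::kDefaultFlag;
  return std::strcmp(env, "0") != 0;
}

int main() {
  const bool use_fp32_compute =
      ReadBoolFlagOrDefault<RnnDoFP32ComputationFP16Input>();
  std::printf("fp16 RNN input uses fp32 compute: %s\n",
              use_fp32_compute ? "yes" : "no");
  return 0;
}

Under this pattern, running the binary with TF_FP16_RNN_USE_FP32_COMPUTE=1 in
the environment would request fp32 compute, while leaving the variable unset
falls back to the patched default of false (fp16 math precision) until the
cudnn 7.1.4 fix lands.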