Temporarily set cudnn Rnn math precision to fp32.

Problem:
When calling cudnnGetRNNLinLayerMatrixParams(), return error CUDNN_STATUS_BAD_PARAM if:

* RNN descriptor set math precision = CUDNN_DATA_FLOAT
* input descriptor dataType = CUDNN_DATA_HALF
* weight descriptor dataType= CUDNN_DATA_HALF

If updating Rnn descriptor math precision to CUDNN_DATA_HALF, then no error.

cudnn 7.1.4 will fix the problem.

PiperOrigin-RevId: 193696566
This commit is contained in:
James Qin 2018-04-20 11:23:29 -07:00 committed by TensorFlower Gardener
parent 570d90b9c7
commit 5fbb1feecd

View File

@ -2529,12 +2529,20 @@ cudnnDataType_t GetConvComputeType<double>() {
} }
// A helper struct to decide whether to use FP32 as the internal compute type // A helper struct to decide whether to use FP32 as the internal compute type
// for rnn when the input data type is FP16. By default it is turned on, // for rnn when the input data type is FP16. At present it is turned off,
// users can explicitly disable them (choose to use FP16 as the internal compute // users can explicitly control them through an env-var
// type) through an env-var "TF_FP16_RNN_USE_FP32_COMPUTE=0". // TF_FP16_RNN_USE_FP32_COMPUTE.
// After the TODO below is fixed, users should almost always use fp32 compute
// type for training. Using fp16 might suffer suboptimal accuracy due to loss
// in precision.
struct RnnDoFP32ComputationFP16Input { struct RnnDoFP32ComputationFP16Input {
static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE"; static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
static constexpr bool kDefaultFlag = true; // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
// Before cudnn 7.1.4 RNN are always done in fp32, no matter what math
// precision is set.
// Set it temporary to false s.t. no error is raised when using fp16 inputs,
// fp32 math precision.
static constexpr bool kDefaultFlag = false;
}; };
// A helper function to return the internal compute type for // A helper function to return the internal compute type for