Temporarily set cudnn Rnn math precision to fp32.
Problem: When calling cudnnGetRNNLinLayerMatrixParams(), return error CUDNN_STATUS_BAD_PARAM if: * RNN descriptor set math precision = CUDNN_DATA_FLOAT * input descriptor dataType = CUDNN_DATA_HALF * weight descriptor dataType= CUDNN_DATA_HALF If updating Rnn descriptor math precision to CUDNN_DATA_HALF, then no error. cudnn 7.1.4 will fix the problem. PiperOrigin-RevId: 193696566
This commit is contained in:
parent
570d90b9c7
commit
5fbb1feecd
@ -2529,12 +2529,20 @@ cudnnDataType_t GetConvComputeType<double>() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A helper struct to decide whether to use FP32 as the internal compute type
|
// A helper struct to decide whether to use FP32 as the internal compute type
|
||||||
// for rnn when the input data type is FP16. By default it is turned on,
|
// for rnn when the input data type is FP16. At present it is turned off,
|
||||||
// users can explicitly disable them (choose to use FP16 as the internal compute
|
// users can explicitly control them through an env-var
|
||||||
// type) through an env-var "TF_FP16_RNN_USE_FP32_COMPUTE=0".
|
// TF_FP16_RNN_USE_FP32_COMPUTE.
|
||||||
|
// After the TODO below is fixed, users should almost always use fp32 compute
|
||||||
|
// type for training. Using fp16 might suffer suboptimal accuracy due to loss
|
||||||
|
// in precision.
|
||||||
struct RnnDoFP32ComputationFP16Input {
|
struct RnnDoFP32ComputationFP16Input {
|
||||||
static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
|
static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
|
||||||
static constexpr bool kDefaultFlag = true;
|
// TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
|
||||||
|
// Before cudnn 7.1.4 RNN are always done in fp32, no matter what math
|
||||||
|
// precision is set.
|
||||||
|
// Set it temporary to false s.t. no error is raised when using fp16 inputs,
|
||||||
|
// fp32 math precision.
|
||||||
|
static constexpr bool kDefaultFlag = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A helper function to return the internal compute type for
|
// A helper function to return the internal compute type for
|
||||||
|
Loading…
Reference in New Issue
Block a user