Eager loading of CUDNN sub-libraries.

This commit is contained in:
Nathan Luehr 2020-10-01 16:57:08 -05:00
parent 8564160d1f
commit a6de9b525e
2 changed files with 27 additions and 0 deletions

View File

@ -318,6 +318,16 @@ port::Status CudnnSupport::Init() {
return port::Status(port::error::INTERNAL, error);
}
// Preload sub libs for cudnn 8.0.4+
#if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
cudnnOpsInferVersionCheck();
cudnnOpsTrainVersionCheck();
cudnnCnnInferVersionCheck();
cudnnCnnTrainVersionCheck();
cudnnAdvInferVersionCheck();
cudnnAdvTrainVersionCheck();
#endif
cudnn_.reset(new CudnnAccess(cudnn_handle));
return port::Status::OK();
}

View File

@ -55,6 +55,23 @@ cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) {
return func_ptr(handle);
}
#if CUDNN_MAJOR>=8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
cudnnStatus_t CUDNNWINAPI cudnnCnnInferVersionCheck(void) {
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCnnInferVersionCheck");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr();
}
cudnnStatus_t CUDNNWINAPI cudnnCnnTrainVersionCheck(void) {
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)();
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCnnTrainVersionCheck");
if (!func_ptr) return GetSymbolNotFoundError();
return func_ptr();
}
#endif
cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle,
cudaStream_t streamId) {
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t);