Eager loading of CUDNN sub-libraries.
This commit is contained in:
parent
8564160d1f
commit
a6de9b525e
@ -318,6 +318,16 @@ port::Status CudnnSupport::Init() {
|
||||
return port::Status(port::error::INTERNAL, error);
|
||||
}
|
||||
|
||||
// Preload sub libs for cudnn 8.0.4+
|
||||
#if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
|
||||
cudnnOpsInferVersionCheck();
|
||||
cudnnOpsTrainVersionCheck();
|
||||
cudnnCnnInferVersionCheck();
|
||||
cudnnCnnTrainVersionCheck();
|
||||
cudnnAdvInferVersionCheck();
|
||||
cudnnAdvTrainVersionCheck();
|
||||
#endif
|
||||
|
||||
cudnn_.reset(new CudnnAccess(cudnn_handle));
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
@ -55,6 +55,23 @@ cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) {
|
||||
return func_ptr(handle);
|
||||
}
|
||||
|
||||
|
||||
#if CUDNN_MAJOR>=8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
|
||||
cudnnStatus_t CUDNNWINAPI cudnnCnnInferVersionCheck(void) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCnnInferVersionCheck");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr();
|
||||
}
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnCnnTrainVersionCheck(void) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)();
|
||||
static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCnnTrainVersionCheck");
|
||||
if (!func_ptr) return GetSymbolNotFoundError();
|
||||
return func_ptr();
|
||||
}
|
||||
#endif
|
||||
|
||||
cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle,
|
||||
cudaStream_t streamId) {
|
||||
using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t);
|
||||
|
Loading…
Reference in New Issue
Block a user