diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 3374965f415..a192c022a47 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1081,17 +1081,16 @@ def _create_local_cuda_repository(repository_ctx): )) # Select the headers based on the cuDNN version (strip '64_' for Windows). - if cuda_config.cudnn_version.rsplit("_", 1)[0] < "8": - cudnn_headers = ["cudnn.h"] - else: - cudnn_headers = [ + cudnn_headers = ["cudnn.h"] + if cuda_config.cudnn_version.rsplit("_", 1)[0] >= "8": + cudnn_headers += [ + "cudnn_backend.h", "cudnn_adv_infer.h", "cudnn_adv_train.h", "cudnn_cnn_infer.h", "cudnn_cnn_train.h", "cudnn_ops_infer.h", "cudnn_ops_train.h", - "cudnn.h", "cudnn_version.h", ]