From b00a7808a7b29a78762b54e29aac87a77254b4b6 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 22 Jun 2020 03:58:58 -0700 Subject: [PATCH] Add extra header file for cuDNN 8. PiperOrigin-RevId: 317626279 Change-Id: I99b969a73555932b25081f37b64f71ac6de662d6 --- third_party/gpus/cuda_configure.bzl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 3374965f415..a192c022a47 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1081,17 +1081,16 @@ def _create_local_cuda_repository(repository_ctx): )) # Select the headers based on the cuDNN version (strip '64_' for Windows). - if cuda_config.cudnn_version.rsplit("_", 1)[0] < "8": - cudnn_headers = ["cudnn.h"] - else: - cudnn_headers = [ + cudnn_headers = ["cudnn.h"] + if cuda_config.cudnn_version.rsplit("_", 1)[0] >= "8": + cudnn_headers += [ + "cudnn_backend.h", "cudnn_adv_infer.h", "cudnn_adv_train.h", "cudnn_cnn_infer.h", "cudnn_cnn_train.h", "cudnn_ops_infer.h", "cudnn_ops_train.h", - "cudnn.h", "cudnn_version.h", ]