Add extra header file for cuDNN 8.

PiperOrigin-RevId: 317626279
Change-Id: I99b969a73555932b25081f37b64f71ac6de662d6
This commit is contained in:
Christian Sigg 2020-06-22 03:58:58 -07:00 committed by TensorFlower Gardener
parent 9b1e89b77c
commit b00a7808a7
1 changed files with 4 additions and 5 deletions

View File

@ -1081,17 +1081,16 @@ def _create_local_cuda_repository(repository_ctx):
)) ))
# Select the headers based on the cuDNN version (strip '64_' for Windows). # Select the headers based on the cuDNN version (strip '64_' for Windows).
if cuda_config.cudnn_version.rsplit("_", 1)[0] < "8": cudnn_headers = ["cudnn.h"]
cudnn_headers = ["cudnn.h"] if cuda_config.cudnn_version.rsplit("_", 1)[0] >= "8":
else: cudnn_headers += [
cudnn_headers = [ "cudnn_backend.h",
"cudnn_adv_infer.h", "cudnn_adv_infer.h",
"cudnn_adv_train.h", "cudnn_adv_train.h",
"cudnn_cnn_infer.h", "cudnn_cnn_infer.h",
"cudnn_cnn_train.h", "cudnn_cnn_train.h",
"cudnn_ops_infer.h", "cudnn_ops_infer.h",
"cudnn_ops_train.h", "cudnn_ops_train.h",
"cudnn.h",
"cudnn_version.h", "cudnn_version.h",
] ]