Add extra header file for cuDNN 8.
PiperOrigin-RevId: 317626279 Change-Id: I99b969a73555932b25081f37b64f71ac6de662d6
This commit is contained in:
parent
9b1e89b77c
commit
b00a7808a7
|
@ -1081,17 +1081,16 @@ def _create_local_cuda_repository(repository_ctx):
|
||||||
))
|
))
|
||||||
|
|
||||||
# Select the headers based on the cuDNN version (strip '64_' for Windows).
|
# Select the headers based on the cuDNN version (strip '64_' for Windows).
|
||||||
if cuda_config.cudnn_version.rsplit("_", 1)[0] < "8":
|
|
||||||
cudnn_headers = ["cudnn.h"]
|
cudnn_headers = ["cudnn.h"]
|
||||||
else:
|
if cuda_config.cudnn_version.rsplit("_", 1)[0] >= "8":
|
||||||
cudnn_headers = [
|
cudnn_headers += [
|
||||||
|
"cudnn_backend.h",
|
||||||
"cudnn_adv_infer.h",
|
"cudnn_adv_infer.h",
|
||||||
"cudnn_adv_train.h",
|
"cudnn_adv_train.h",
|
||||||
"cudnn_cnn_infer.h",
|
"cudnn_cnn_infer.h",
|
||||||
"cudnn_cnn_train.h",
|
"cudnn_cnn_train.h",
|
||||||
"cudnn_ops_infer.h",
|
"cudnn_ops_infer.h",
|
||||||
"cudnn_ops_train.h",
|
"cudnn_ops_train.h",
|
||||||
"cudnn.h",
|
|
||||||
"cudnn_version.h",
|
"cudnn_version.h",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue