From 78d843fd8cb45b97edd40479700f7c8d68ef4b82 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 May 2019 11:30:35 -0700 Subject: [PATCH] During CUDA configuraton, detect path in legacy environment variables (CUDNN_INSTALL_PATH, NCCL_INSTALL_PATH, TENSORRT_INSTALL_PATH) and special-case them. PiperOrigin-RevId: 247058031 --- third_party/gpus/find_cuda_config.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 7662e9e46ae..576456e3646 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -406,6 +406,20 @@ def _list_from_env(env_name, default=[]): return default +def _get_legacy_path(env_name, default=[]): + """Returns a path specified by a legacy environment variable. + + CUDNN_INSTALL_PATH, NCCL_INSTALL_PATH, TENSORRT_INSTALL_PATH set to + '/usr/lib/x86_64-linux-gnu' would previously find both library and header + paths. Detect those and return '/usr', otherwise forward to _list_from_env(). + """ + if env_name in os.environ: + match = re.match("^(/[^/ ]*)+/lib/\w+-linux-gnu/?$", os.environ[env_name]) + if match: + return [match.group(1)] + return _list_from_env(env_name, default) + + def _normalize_path(path): """Returns normalized path, with forward slashes on Windows.""" path = os.path.normpath(path) @@ -436,18 +450,17 @@ def find_cuda_config(): _find_cublas_config(cublas_paths, cublas_version, cuda_version)) if "cudnn" in libraries: - cudnn_paths = _list_from_env("CUDNN_INSTALL_PATH", base_paths) + cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths) cudnn_version = os.environ.get("TF_CUDNN_VERSION", "") result.update(_find_cudnn_config(cudnn_paths, cudnn_version)) if "nccl" in libraries: - nccl_paths = _list_from_env("NCCL_INSTALL_PATH", - base_paths) + _list_from_env("NCCL_HDR_PATH") + nccl_paths = _get_legacy_path("NCCL_INSTALL_PATH", base_paths) nccl_version = os.environ.get("TF_NCCL_VERSION", "") result.update(_find_nccl_config(nccl_paths, nccl_version)) if "tensorrt" in libraries: - tensorrt_paths = _list_from_env("TENSORRT_INSTALL_PATH", base_paths) + tensorrt_paths = _get_legacy_path("TENSORRT_INSTALL_PATH", base_paths) tensorrt_version = os.environ.get("TF_TENSORRT_VERSION", "") result.update(_find_tensorrt_config(tensorrt_paths, tensorrt_version))