During CUDA configuraton, detect path in legacy environment variables (CUDNN_INSTALL_PATH, NCCL_INSTALL_PATH, TENSORRT_INSTALL_PATH) and special-case them.
PiperOrigin-RevId: 247058031
This commit is contained in:
parent
9d19b1a1c6
commit
78d843fd8c
21
third_party/gpus/find_cuda_config.py
vendored
21
third_party/gpus/find_cuda_config.py
vendored
@ -406,6 +406,20 @@ def _list_from_env(env_name, default=[]):
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _get_legacy_path(env_name, default=[]):
|
||||||
|
"""Returns a path specified by a legacy environment variable.
|
||||||
|
|
||||||
|
CUDNN_INSTALL_PATH, NCCL_INSTALL_PATH, TENSORRT_INSTALL_PATH set to
|
||||||
|
'/usr/lib/x86_64-linux-gnu' would previously find both library and header
|
||||||
|
paths. Detect those and return '/usr', otherwise forward to _list_from_env().
|
||||||
|
"""
|
||||||
|
if env_name in os.environ:
|
||||||
|
match = re.match("^(/[^/ ]*)+/lib/\w+-linux-gnu/?$", os.environ[env_name])
|
||||||
|
if match:
|
||||||
|
return [match.group(1)]
|
||||||
|
return _list_from_env(env_name, default)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_path(path):
|
def _normalize_path(path):
|
||||||
"""Returns normalized path, with forward slashes on Windows."""
|
"""Returns normalized path, with forward slashes on Windows."""
|
||||||
path = os.path.normpath(path)
|
path = os.path.normpath(path)
|
||||||
@ -436,18 +450,17 @@ def find_cuda_config():
|
|||||||
_find_cublas_config(cublas_paths, cublas_version, cuda_version))
|
_find_cublas_config(cublas_paths, cublas_version, cuda_version))
|
||||||
|
|
||||||
if "cudnn" in libraries:
|
if "cudnn" in libraries:
|
||||||
cudnn_paths = _list_from_env("CUDNN_INSTALL_PATH", base_paths)
|
cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
|
||||||
cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
|
cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
|
||||||
result.update(_find_cudnn_config(cudnn_paths, cudnn_version))
|
result.update(_find_cudnn_config(cudnn_paths, cudnn_version))
|
||||||
|
|
||||||
if "nccl" in libraries:
|
if "nccl" in libraries:
|
||||||
nccl_paths = _list_from_env("NCCL_INSTALL_PATH",
|
nccl_paths = _get_legacy_path("NCCL_INSTALL_PATH", base_paths)
|
||||||
base_paths) + _list_from_env("NCCL_HDR_PATH")
|
|
||||||
nccl_version = os.environ.get("TF_NCCL_VERSION", "")
|
nccl_version = os.environ.get("TF_NCCL_VERSION", "")
|
||||||
result.update(_find_nccl_config(nccl_paths, nccl_version))
|
result.update(_find_nccl_config(nccl_paths, nccl_version))
|
||||||
|
|
||||||
if "tensorrt" in libraries:
|
if "tensorrt" in libraries:
|
||||||
tensorrt_paths = _list_from_env("TENSORRT_INSTALL_PATH", base_paths)
|
tensorrt_paths = _get_legacy_path("TENSORRT_INSTALL_PATH", base_paths)
|
||||||
tensorrt_version = os.environ.get("TF_TENSORRT_VERSION", "")
|
tensorrt_version = os.environ.get("TF_TENSORRT_VERSION", "")
|
||||||
result.update(_find_tensorrt_config(tensorrt_paths, tensorrt_version))
|
result.update(_find_tensorrt_config(tensorrt_paths, tensorrt_version))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user