Use TF_CUDA_PATHS to find cuBLAS if CUDA version >= 10.1, otherwise CUDA_TOOLKIT_PATH.
PiperOrigin-RevId: 247743023
This commit is contained in:
parent
3261198712
commit
f46e445c78
7
third_party/gpus/find_cuda_config.py
vendored
7
third_party/gpus/find_cuda_config.py
vendored
@ -442,10 +442,11 @@ def find_cuda_config():
|
||||
cuda_paths = _list_from_env("CUDA_TOOLKIT_PATH", base_paths)
|
||||
result.update(_find_cuda_config(cuda_paths, cuda_version))
|
||||
|
||||
cublas_paths = _list_from_env("CUBLAS_INSTALL_PATH", base_paths)
|
||||
# Add cuda paths in case CuBLAS is installed under CUDA_TOOLKIT_PATH.
|
||||
cublas_paths += list(set(cuda_paths) - set(cublas_paths))
|
||||
cuda_version = result["cuda_version"]
|
||||
cublas_paths = base_paths
|
||||
if cuda_version.split(".") < (10, 1):
|
||||
# Before CUDA 10.1, cuBLAS was in the same directory as the toolkit.
|
||||
cublas_paths = cuda_paths
|
||||
cublas_version = os.environ.get("TF_CUBLAS_VERSION", "")
|
||||
result.update(
|
||||
_find_cublas_config(cublas_paths, cublas_version, cuda_version))
|
||||
|
Loading…
x
Reference in New Issue
Block a user