Use TF_CUDA_PATHS to find cuBLAS if CUDA version >= 10.1, otherwise CUDA_TOOLKIT_PATH.

PiperOrigin-RevId: 247743023
This commit is contained in:
A. Unique TensorFlower 2019-05-11 03:13:26 -07:00 committed by TensorFlower Gardener
parent 3261198712
commit f46e445c78

View File

@ -442,10 +442,11 @@ def find_cuda_config():
cuda_paths = _list_from_env("CUDA_TOOLKIT_PATH", base_paths)
result.update(_find_cuda_config(cuda_paths, cuda_version))
cublas_paths = _list_from_env("CUBLAS_INSTALL_PATH", base_paths)
# Add cuda paths in case CuBLAS is installed under CUDA_TOOLKIT_PATH.
cublas_paths += list(set(cuda_paths) - set(cublas_paths))
cuda_version = result["cuda_version"]
cublas_paths = base_paths
if cuda_version.split(".") < (10, 1):
# Before CUDA 10.1, cuBLAS was in the same directory as the toolkit.
cublas_paths = cuda_paths
cublas_version = os.environ.get("TF_CUBLAS_VERSION", "")
result.update(
_find_cublas_config(cublas_paths, cublas_version, cuda_version))