Merge pull request #43716 from nluehr:cuda_11.1_fix
PiperOrigin-RevId: 335904788 Change-Id: If2f6ad73cdeedf97654ef8b5e343af19bcd7d1bf
This commit is contained in:
commit
68a6fe0d98
@ -31,6 +31,7 @@ namespace internal {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
string GetCudaVersion() { return TF_CUDA_VERSION; }
|
string GetCudaVersion() { return TF_CUDA_VERSION; }
|
||||||
|
string GetCudaRtVersion() { return TF_CUDART_VERSION; }
|
||||||
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
||||||
string GetCublasVersion() { return TF_CUBLAS_VERSION; }
|
string GetCublasVersion() { return TF_CUBLAS_VERSION; }
|
||||||
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
|
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
|
||||||
@ -77,7 +78,7 @@ port::StatusOr<void*> GetCudaDriverDsoHandle() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
|
port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
|
||||||
return GetDsoHandle("cudart", GetCudaVersion());
|
return GetDsoHandle("cudart", GetCudaRtVersion());
|
||||||
}
|
}
|
||||||
|
|
||||||
port::StatusOr<void*> GetCublasDsoHandle() {
|
port::StatusOr<void*> GetCublasDsoHandle() {
|
||||||
|
1
third_party/gpus/cuda/cuda_config.h.tpl
vendored
1
third_party/gpus/cuda/cuda_config.h.tpl
vendored
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define CUDA_CUDA_CONFIG_H_
|
#define CUDA_CUDA_CONFIG_H_
|
||||||
|
|
||||||
#define TF_CUDA_VERSION "%{cuda_version}"
|
#define TF_CUDA_VERSION "%{cuda_version}"
|
||||||
|
#define TF_CUDART_VERSION "%{cudart_version}"
|
||||||
#define TF_CUBLAS_VERSION "%{cublas_version}"
|
#define TF_CUBLAS_VERSION "%{cublas_version}"
|
||||||
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
|
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
|
||||||
#define TF_CURAND_VERSION "%{curand_version}"
|
#define TF_CURAND_VERSION "%{curand_version}"
|
||||||
|
15
third_party/gpus/cuda_configure.bzl
vendored
15
third_party/gpus/cuda_configure.bzl
vendored
@ -534,14 +534,14 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
|||||||
"cudart",
|
"cudart",
|
||||||
cpu_value,
|
cpu_value,
|
||||||
cuda_config.config["cuda_library_dir"],
|
cuda_config.config["cuda_library_dir"],
|
||||||
cuda_config.cuda_version,
|
cuda_config.cudart_version,
|
||||||
static = False,
|
static = False,
|
||||||
),
|
),
|
||||||
"cudart_static": _check_cuda_lib_params(
|
"cudart_static": _check_cuda_lib_params(
|
||||||
"cudart_static",
|
"cudart_static",
|
||||||
cpu_value,
|
cpu_value,
|
||||||
cuda_config.config["cuda_library_dir"],
|
cuda_config.config["cuda_library_dir"],
|
||||||
cuda_config.cuda_version,
|
cuda_config.cudart_version,
|
||||||
static = True,
|
static = True,
|
||||||
),
|
),
|
||||||
"cublas": _check_cuda_lib_params(
|
"cublas": _check_cuda_lib_params(
|
||||||
@ -651,6 +651,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||||||
cuda_toolkit_path: The CUDA toolkit installation directory.
|
cuda_toolkit_path: The CUDA toolkit installation directory.
|
||||||
cudnn_install_basedir: The cuDNN installation directory.
|
cudnn_install_basedir: The cuDNN installation directory.
|
||||||
cuda_version: The version of CUDA on the system.
|
cuda_version: The version of CUDA on the system.
|
||||||
|
cudart_version: The CUDA runtime version on the system.
|
||||||
cudnn_version: The version of cuDNN on the system.
|
cudnn_version: The version of cuDNN on the system.
|
||||||
compute_capabilities: A list of the system's CUDA compute capabilities.
|
compute_capabilities: A list of the system's CUDA compute capabilities.
|
||||||
cpu_value: The name of the host operating system.
|
cpu_value: The name of the host operating system.
|
||||||
@ -668,6 +669,11 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||||||
cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
|
cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
|
||||||
|
|
||||||
if int(cuda_major) >= 11:
|
if int(cuda_major) >= 11:
|
||||||
|
# The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability.
|
||||||
|
if int(cuda_major) == 11:
|
||||||
|
cudart_version = "64_110" if is_windows else "11.0"
|
||||||
|
else:
|
||||||
|
cudart_version = ("64_%s" if is_windows else "%s") % cuda_major
|
||||||
cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
|
cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
|
||||||
cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
|
cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
|
||||||
curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
|
curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
|
||||||
@ -677,12 +683,14 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||||||
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
|
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
|
||||||
# It changed from 'x.y' to just 'x' in CUDA 10.1.
|
# It changed from 'x.y' to just 'x' in CUDA 10.1.
|
||||||
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
|
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
|
||||||
|
cudart_version = cuda_version
|
||||||
cublas_version = cuda_lib_version
|
cublas_version = cuda_lib_version
|
||||||
cusolver_version = cuda_lib_version
|
cusolver_version = cuda_lib_version
|
||||||
curand_version = cuda_lib_version
|
curand_version = cuda_lib_version
|
||||||
cufft_version = cuda_lib_version
|
cufft_version = cuda_lib_version
|
||||||
cusparse_version = cuda_lib_version
|
cusparse_version = cuda_lib_version
|
||||||
else:
|
else:
|
||||||
|
cudart_version = cuda_version
|
||||||
cublas_version = cuda_version
|
cublas_version = cuda_version
|
||||||
cusolver_version = cuda_version
|
cusolver_version = cuda_version
|
||||||
curand_version = cuda_version
|
curand_version = cuda_version
|
||||||
@ -693,6 +701,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||||||
cuda_toolkit_path = toolkit_path,
|
cuda_toolkit_path = toolkit_path,
|
||||||
cuda_version = cuda_version,
|
cuda_version = cuda_version,
|
||||||
cuda_version_major = cuda_major,
|
cuda_version_major = cuda_major,
|
||||||
|
cudart_version = cudart_version,
|
||||||
cublas_version = cublas_version,
|
cublas_version = cublas_version,
|
||||||
cusolver_version = cusolver_version,
|
cusolver_version = cusolver_version,
|
||||||
curand_version = curand_version,
|
curand_version = curand_version,
|
||||||
@ -816,6 +825,7 @@ filegroup(name="cudnn-include")
|
|||||||
"cuda:cuda_config.h",
|
"cuda:cuda_config.h",
|
||||||
{
|
{
|
||||||
"%{cuda_version}": "",
|
"%{cuda_version}": "",
|
||||||
|
"%{cudart_version}": "",
|
||||||
"%{cublas_version}": "",
|
"%{cublas_version}": "",
|
||||||
"%{cusolver_version}": "",
|
"%{cusolver_version}": "",
|
||||||
"%{curand_version}": "",
|
"%{curand_version}": "",
|
||||||
@ -1281,6 +1291,7 @@ def _create_local_cuda_repository(repository_ctx):
|
|||||||
tpl_paths["cuda:cuda_config.h"],
|
tpl_paths["cuda:cuda_config.h"],
|
||||||
{
|
{
|
||||||
"%{cuda_version}": cuda_config.cuda_version,
|
"%{cuda_version}": cuda_config.cuda_version,
|
||||||
|
"%{cudart_version}": cuda_config.cudart_version,
|
||||||
"%{cublas_version}": cuda_config.cublas_version,
|
"%{cublas_version}": cuda_config.cublas_version,
|
||||||
"%{cusolver_version}": cuda_config.cusolver_version,
|
"%{cusolver_version}": cuda_config.cusolver_version,
|
||||||
"%{curand_version}": cuda_config.curand_version,
|
"%{curand_version}": cuda_config.curand_version,
|
||||||
|
Loading…
Reference in New Issue
Block a user