Enable build with CUDA 11
This commit is contained in:
parent
8e8c67c337
commit
28feb4df0d
@ -273,6 +273,7 @@ cc_library(
|
||||
textual_hdrs = glob(["cufft_*.inc"]),
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cufft_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
@ -371,6 +372,7 @@ cc_library(
|
||||
textual_hdrs = ["curand_10_0.inc"],
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:curand_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
@ -430,6 +432,7 @@ cc_library(
|
||||
# LINT.IfChange
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
|
||||
"@local_config_cuda//cuda:cusolver_headers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
@ -451,6 +454,7 @@ cc_library(
|
||||
textual_hdrs = glob(["cusparse_*.inc"]),
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cusparse_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
|
4038
tensorflow/stream_executor/cuda/cublas_11_0.inc
Normal file
4038
tensorflow/stream_executor/cuda/cublas_11_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#else
|
||||
#include "third_party/gpus/cuda/include/cublas.h"
|
||||
#endif
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
@ -65,7 +69,7 @@ typedef enum {} cublasMath_t;
|
||||
#include "tensorflow/stream_executor/cuda/cublas_10_1.inc"
|
||||
#elif CUDA_VERSION == 10020
|
||||
#include "tensorflow/stream_executor/cuda/cublas_10_2.inc"
|
||||
#elif CUDA_VERSION == 11000
|
||||
#elif CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 0
|
||||
#include "tensorflow/stream_executor/cuda/cublas_11_0.inc"
|
||||
#else
|
||||
#error "We have no wrapper for this version."
|
||||
|
2036
tensorflow/stream_executor/cuda/cuda_11_0.inc
Normal file
2036
tensorflow/stream_executor/cuda/cuda_11_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
1501
tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc
Normal file
1501
tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
5953
tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc
Normal file
5953
tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
7942
tensorflow/stream_executor/cuda/cusparse_11_0.inc
Normal file
7942
tensorflow/stream_executor/cuda/cusparse_11_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
@ -59,7 +60,7 @@ cusparseStatus_t GetSymbolNotFoundError() {
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_10_1.inc"
|
||||
#elif CUDA_VERSION == 10020
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_10_2.inc"
|
||||
#elif CUDA_VERSION == 11000
|
||||
#elif CUSPARSE_VER_MAJOR == 11 && CUSPARSE_VER_MINOR == 0
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_11_0.inc"
|
||||
#else
|
||||
#error "We don't have a wrapper for this version."
|
||||
|
@ -31,8 +31,12 @@ namespace internal {
|
||||
|
||||
namespace {
|
||||
string GetCudaVersion() { return TF_CUDA_VERSION; }
|
||||
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
|
||||
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
||||
string GetCublasVersion() { return TF_CUBLAS_VERSION; }
|
||||
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
|
||||
string GetCurandVersion() { return TF_CURAND_VERSION; }
|
||||
string GetCufftVersion() { return TF_CUFFT_VERSION; }
|
||||
string GetCusparseVersion() { return TF_CUSPARSE_VERSION; }
|
||||
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
|
||||
|
||||
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
||||
@ -77,23 +81,23 @@ port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCublasDsoHandle() {
|
||||
return GetDsoHandle("cublas", GetCudaLibVersion());
|
||||
return GetDsoHandle("cublas", GetCublasVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCufftDsoHandle() {
|
||||
return GetDsoHandle("cufft", GetCudaLibVersion());
|
||||
return GetDsoHandle("cufft", GetCufftVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusolverDsoHandle() {
|
||||
return GetDsoHandle("cusolver", GetCudaLibVersion());
|
||||
return GetDsoHandle("cusolver", GetCusolverVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusparseDsoHandle() {
|
||||
return GetDsoHandle("cusparse", GetCudaLibVersion());
|
||||
return GetDsoHandle("cusparse", GetCusparseVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCurandDsoHandle() {
|
||||
return GetDsoHandle("curand", GetCudaLibVersion());
|
||||
return GetDsoHandle("curand", GetCurandVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCuptiDsoHandle() {
|
||||
|
36
third_party/gpus/cuda/BUILD.tpl
vendored
36
third_party/gpus/cuda/BUILD.tpl
vendored
@ -84,6 +84,42 @@ cuda_header_library(
|
||||
includes = ["cublas/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusolver_headers",
|
||||
hdrs = [":cusolver-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cusolver/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cusolver/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cufft_headers",
|
||||
hdrs = [":cufft-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cufft/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cufft/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusparse_headers",
|
||||
hdrs = [":cusparse-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cusparse/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cusparse/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "curand_headers",
|
||||
hdrs = [":curand-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "curand/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["curand/include"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas",
|
||||
srcs = ["cuda/lib/%{cublas_lib}"],
|
||||
|
36
third_party/gpus/cuda/BUILD.windows.tpl
vendored
36
third_party/gpus/cuda/BUILD.windows.tpl
vendored
@ -85,6 +85,42 @@ cuda_header_library(
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusolver_headers",
|
||||
hdrs = [":cusolver-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cusolver/include"],
|
||||
strip_include_prefix = "cusolver/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cufft_headers",
|
||||
hdrs = [":cufft-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cufft/include"],
|
||||
strip_include_prefix = "cufft/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusparse_headers",
|
||||
hdrs = [":cusparse-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cusparse/include"],
|
||||
strip_include_prefix = "cusparse/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "curand_headers",
|
||||
hdrs = [":curand-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["curand/include"],
|
||||
strip_include_prefix = "curand/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cublas",
|
||||
interface_library = "cuda/lib/%{cublas_lib}",
|
||||
|
6
third_party/gpus/cuda/cuda_config.h.tpl
vendored
6
third_party/gpus/cuda/cuda_config.h.tpl
vendored
@ -17,7 +17,11 @@ limitations under the License.
|
||||
#define CUDA_CUDA_CONFIG_H_
|
||||
|
||||
#define TF_CUDA_VERSION "%{cuda_version}"
|
||||
#define TF_CUDA_LIB_VERSION "%{cuda_lib_version}"
|
||||
#define TF_CUBLAS_VERSION "%{cublas_version}"
|
||||
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
|
||||
#define TF_CURAND_VERSION "%{curand_version}"
|
||||
#define TF_CUFFT_VERSION "%{cufft_version}"
|
||||
#define TF_CUSPARSE_VERSION "%{cusparse_version}"
|
||||
#define TF_CUDNN_VERSION "%{cudnn_version}"
|
||||
|
||||
#define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}"
|
||||
|
113
third_party/gpus/cuda_configure.bzl
vendored
113
third_party/gpus/cuda_configure.bzl
vendored
@ -527,28 +527,28 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
||||
"cublas",
|
||||
cpu_value,
|
||||
cuda_config.config["cublas_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.cublas_version,
|
||||
static = False,
|
||||
),
|
||||
"cusolver": _check_cuda_lib_params(
|
||||
"cusolver",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cusolver_library_dir"],
|
||||
cuda_config.cusolver_version,
|
||||
static = False,
|
||||
),
|
||||
"curand": _check_cuda_lib_params(
|
||||
"curand",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["curand_library_dir"],
|
||||
cuda_config.curand_version,
|
||||
static = False,
|
||||
),
|
||||
"cufft": _check_cuda_lib_params(
|
||||
"cufft",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cufft_library_dir"],
|
||||
cuda_config.cufft_version,
|
||||
static = False,
|
||||
),
|
||||
"cudnn": _check_cuda_lib_params(
|
||||
@ -568,8 +568,8 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
||||
"cusparse": _check_cuda_lib_params(
|
||||
"cusparse",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cusparse_library_dir"],
|
||||
cuda_config.cusparse_version,
|
||||
static = False,
|
||||
),
|
||||
}
|
||||
@ -646,18 +646,37 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
||||
cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
|
||||
cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
|
||||
|
||||
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
|
||||
# It changed from 'x.y' to just 'x' in CUDA 10.1.
|
||||
if (int(cuda_major), int(cuda_minor)) >= (10, 1):
|
||||
if int(cuda_major) >= 11:
|
||||
cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
|
||||
cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
|
||||
curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
|
||||
cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
|
||||
cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
|
||||
elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
|
||||
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
|
||||
# It changed from 'x.y' to just 'x' in CUDA 10.1.
|
||||
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
|
||||
cublas_version = cuda_lib_version
|
||||
cusolver_version = cuda_lib_version
|
||||
curand_version = cuda_lib_version
|
||||
cufft_version = cuda_lib_version
|
||||
cusparse_version = cuda_lib_version
|
||||
else:
|
||||
cuda_lib_version = cuda_version
|
||||
cublas_version = cuda_version
|
||||
cusolver_version = cuda_version
|
||||
curand_version = cuda_version
|
||||
cufft_version = cuda_version
|
||||
cusparse_version = cuda_version
|
||||
|
||||
return struct(
|
||||
cuda_toolkit_path = toolkit_path,
|
||||
cuda_version = cuda_version,
|
||||
cublas_version = cublas_version,
|
||||
cusolver_version = cusolver_version,
|
||||
curand_version = curand_version,
|
||||
cufft_version = cufft_version,
|
||||
cusparse_version = cusparse_version,
|
||||
cudnn_version = cudnn_version,
|
||||
cuda_lib_version = cuda_lib_version,
|
||||
compute_capabilities = compute_capabilities(repository_ctx),
|
||||
cpu_value = cpu_value,
|
||||
config = config,
|
||||
@ -739,6 +758,10 @@ def _create_dummy_repository(repository_ctx):
|
||||
"%{copy_rules}": """
|
||||
filegroup(name="cuda-include")
|
||||
filegroup(name="cublas-include")
|
||||
filegroup(name="cusolver-include")
|
||||
filegroup(name="cufft-include")
|
||||
filegroup(name="cusparse-include")
|
||||
filegroup(name="curand-include")
|
||||
filegroup(name="cudnn-include")
|
||||
""",
|
||||
},
|
||||
@ -770,7 +793,11 @@ filegroup(name="cudnn-include")
|
||||
"cuda:cuda_config.h",
|
||||
{
|
||||
"%{cuda_version}": "",
|
||||
"%{cuda_lib_version}": "",
|
||||
"%{cublas_version}": "",
|
||||
"%{cusolver_version}": "",
|
||||
"%{curand_version}": "",
|
||||
"%{cufft_version}": "",
|
||||
"%{cusparse_version}": "",
|
||||
"%{cudnn_version}": "",
|
||||
"%{cuda_toolkit_path}": "",
|
||||
},
|
||||
@ -935,6 +962,56 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
],
|
||||
))
|
||||
|
||||
cusolver_include_path = cuda_config.config["cusolver_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cusolver-include",
|
||||
srcs = [
|
||||
cusolver_include_path + "/cusolver_common.h",
|
||||
cusolver_include_path + "/cusolverDn.h",
|
||||
],
|
||||
outs = [
|
||||
"cusolver/include/cusolver_common.h",
|
||||
"cusolver/include/cusolverDn.h",
|
||||
],
|
||||
))
|
||||
|
||||
cufft_include_path = cuda_config.config["cufft_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cufft-include",
|
||||
srcs = [
|
||||
cufft_include_path + "/cufft.h",
|
||||
],
|
||||
outs = [
|
||||
"cufft/include/cufft.h",
|
||||
],
|
||||
))
|
||||
|
||||
cusparse_include_path = cuda_config.config["cusparse_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cusparse-include",
|
||||
srcs = [
|
||||
cusparse_include_path + "/cusparse.h",
|
||||
],
|
||||
outs = [
|
||||
"cusparse/include/cusparse.h",
|
||||
],
|
||||
))
|
||||
|
||||
curand_include_path = cuda_config.config["curand_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "curand-include",
|
||||
srcs = [
|
||||
curand_include_path + "/curand.h",
|
||||
],
|
||||
outs = [
|
||||
"curand/include/curand.h",
|
||||
],
|
||||
))
|
||||
|
||||
check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
|
||||
cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
|
||||
cuda_lib_srcs = []
|
||||
@ -1143,7 +1220,11 @@ def _create_local_cuda_repository(repository_ctx):
|
||||
tpl_paths["cuda:cuda_config.h"],
|
||||
{
|
||||
"%{cuda_version}": cuda_config.cuda_version,
|
||||
"%{cuda_lib_version}": cuda_config.cuda_lib_version,
|
||||
"%{cublas_version}": cuda_config.cublas_version,
|
||||
"%{cusolver_version}": cuda_config.cusolver_version,
|
||||
"%{curand_version}": cuda_config.curand_version,
|
||||
"%{cufft_version}": cuda_config.cufft_version,
|
||||
"%{cusparse_version}": cuda_config.cusparse_version,
|
||||
"%{cudnn_version}": cuda_config.cudnn_version,
|
||||
"%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
|
||||
},
|
||||
|
151
third_party/gpus/find_cuda_config.py
vendored
151
third_party/gpus/find_cuda_config.py
vendored
@ -318,12 +318,9 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
|
||||
# cuBLAS uses the major version only.
|
||||
cublas_version = header_version.split(".")[0]
|
||||
|
||||
if not _matches_version(cuda_version, cublas_version):
|
||||
raise ConfigError("cuBLAS version %s does not match CUDA version %s" %
|
||||
(cublas_version, cuda_version))
|
||||
|
||||
else:
|
||||
# There is no version info available before CUDA 10.1, just find the file.
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cublas_api.h")
|
||||
# cuBLAS version is the same as CUDA version (x.y).
|
||||
cublas_version = required_version
|
||||
@ -331,10 +328,98 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
|
||||
library_path = _find_library(base_paths, "cublas", cublas_version)
|
||||
|
||||
return {
|
||||
"cublas_version": header_version,
|
||||
"cublas_include_dir": os.path.dirname(header_path),
|
||||
"cublas_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_cusolver_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUSOLVER_VER_MAJOR", "CUSOLVER_VER_MINOR",
|
||||
"CUSOLVER_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cusolver_common.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cusolver_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cusolver_common.h")
|
||||
cusolver_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cusolver", cusolver_version)
|
||||
|
||||
return {
|
||||
"cusolver_version": header_version,
|
||||
"cusolver_include_dir": os.path.dirname(header_path),
|
||||
"cusolver_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_curand_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CURAND_VER_MAJOR", "CURAND_VER_MINOR",
|
||||
"CURAND_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "curand.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
curand_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "curand.h")
|
||||
curand_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "curand", curand_version)
|
||||
|
||||
return {
|
||||
"curand_version": header_version,
|
||||
"curand_include_dir": os.path.dirname(header_path),
|
||||
"curand_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_cufft_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUFFT_VER_MAJOR", "CUFFT_VER_MINOR",
|
||||
"CUFFT_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cufft.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cufft_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cufft.h")
|
||||
cufft_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cufft", cufft_version)
|
||||
|
||||
return {
|
||||
"cufft_version": header_version,
|
||||
"cufft_include_dir": os.path.dirname(header_path),
|
||||
"cufft_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
|
||||
def _find_cudnn_config(base_paths, required_version):
|
||||
|
||||
@ -358,6 +443,36 @@ def _find_cudnn_config(base_paths, required_version):
|
||||
}
|
||||
|
||||
|
||||
def _find_cusparse_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUSPARSE_VER_MAJOR", "CUSPARSE_VER_MINOR",
|
||||
"CUSPARSE_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cusparse.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cusparse_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cusparse.h")
|
||||
cusparse_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cusparse", cusparse_version)
|
||||
|
||||
return {
|
||||
"cusparse_version": header_version,
|
||||
"cusparse_include_dir": os.path.dirname(header_path),
|
||||
"cusparse_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
|
||||
def _find_nccl_config(base_paths, required_version):
|
||||
|
||||
def get_header_version(path):
|
||||
@ -465,6 +580,34 @@ def find_cuda_config():
|
||||
result.update(
|
||||
_find_cublas_config(cublas_paths, cublas_version, cuda_version))
|
||||
|
||||
cusolver_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cusolver_paths = cuda_paths
|
||||
cusolver_version = os.environ.get("TF_CUSOLVER_VERSION", "")
|
||||
result.update(
|
||||
_find_cusolver_config(cusolver_paths, cusolver_version, cuda_version))
|
||||
|
||||
curand_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
curand_paths = cuda_paths
|
||||
curand_version = os.environ.get("TF_CURAND_VERSION", "")
|
||||
result.update(
|
||||
_find_curand_config(curand_paths, curand_version, cuda_version))
|
||||
|
||||
cufft_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cufft_paths = cuda_paths
|
||||
cufft_version = os.environ.get("TF_CUFFT_VERSION", "")
|
||||
result.update(
|
||||
_find_cufft_config(cufft_paths, cufft_version, cuda_version))
|
||||
|
||||
cusparse_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cusparse_paths = cuda_paths
|
||||
cusparse_version = os.environ.get("TF_CUSPARSE_VERSION", "")
|
||||
result.update(
|
||||
_find_cusparse_config(cusparse_paths, cusparse_version, cuda_version))
|
||||
|
||||
if "cudnn" in libraries:
|
||||
cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
|
||||
cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user