Enable build with CUDA 11
This commit is contained in:
parent
8e8c67c337
commit
28feb4df0d
|
@ -273,6 +273,7 @@ cc_library(
|
|||
textual_hdrs = glob(["cufft_*.inc"]),
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cufft_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
|
@ -371,6 +372,7 @@ cc_library(
|
|||
textual_hdrs = ["curand_10_0.inc"],
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:curand_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
|
@ -430,6 +432,7 @@ cc_library(
|
|||
# LINT.IfChange
|
||||
"@local_config_cuda//cuda:cublas_headers",
|
||||
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
|
||||
"@local_config_cuda//cuda:cusolver_headers",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
|
@ -451,6 +454,7 @@ cc_library(
|
|||
textual_hdrs = glob(["cusparse_*.inc"]),
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cusparse_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
#include "third_party/gpus/cuda/include/cublas_v2.h"
|
||||
#else
|
||||
#include "third_party/gpus/cuda/include/cublas.h"
|
||||
#endif
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
@ -65,7 +69,7 @@ typedef enum {} cublasMath_t;
|
|||
#include "tensorflow/stream_executor/cuda/cublas_10_1.inc"
|
||||
#elif CUDA_VERSION == 10020
|
||||
#include "tensorflow/stream_executor/cuda/cublas_10_2.inc"
|
||||
#elif CUDA_VERSION == 11000
|
||||
#elif CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 0
|
||||
#include "tensorflow/stream_executor/cuda/cublas_11_0.inc"
|
||||
#else
|
||||
#error "We have no wrapper for this version."
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
|
@ -59,7 +60,7 @@ cusparseStatus_t GetSymbolNotFoundError() {
|
|||
#include "tensorflow/stream_executor/cuda/cusparse_10_1.inc"
|
||||
#elif CUDA_VERSION == 10020
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_10_2.inc"
|
||||
#elif CUDA_VERSION == 11000
|
||||
#elif CUSPARSE_VER_MAJOR == 11 && CUSPARSE_VER_MINOR == 0
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_11_0.inc"
|
||||
#else
|
||||
#error "We don't have a wrapper for this version."
|
||||
|
|
|
@ -31,8 +31,12 @@ namespace internal {
|
|||
|
||||
namespace {
|
||||
string GetCudaVersion() { return TF_CUDA_VERSION; }
|
||||
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
|
||||
string GetCudnnVersion() { return TF_CUDNN_VERSION; }
|
||||
string GetCublasVersion() { return TF_CUBLAS_VERSION; }
|
||||
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
|
||||
string GetCurandVersion() { return TF_CURAND_VERSION; }
|
||||
string GetCufftVersion() { return TF_CUFFT_VERSION; }
|
||||
string GetCusparseVersion() { return TF_CUSPARSE_VERSION; }
|
||||
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
|
||||
|
||||
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
||||
|
@ -77,23 +81,23 @@ port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
|
|||
}
|
||||
|
||||
port::StatusOr<void*> GetCublasDsoHandle() {
|
||||
return GetDsoHandle("cublas", GetCudaLibVersion());
|
||||
return GetDsoHandle("cublas", GetCublasVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCufftDsoHandle() {
|
||||
return GetDsoHandle("cufft", GetCudaLibVersion());
|
||||
return GetDsoHandle("cufft", GetCufftVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusolverDsoHandle() {
|
||||
return GetDsoHandle("cusolver", GetCudaLibVersion());
|
||||
return GetDsoHandle("cusolver", GetCusolverVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusparseDsoHandle() {
|
||||
return GetDsoHandle("cusparse", GetCudaLibVersion());
|
||||
return GetDsoHandle("cusparse", GetCusparseVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCurandDsoHandle() {
|
||||
return GetDsoHandle("curand", GetCudaLibVersion());
|
||||
return GetDsoHandle("curand", GetCurandVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCuptiDsoHandle() {
|
||||
|
|
|
@ -84,6 +84,42 @@ cuda_header_library(
|
|||
includes = ["cublas/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusolver_headers",
|
||||
hdrs = [":cusolver-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cusolver/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cusolver/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cufft_headers",
|
||||
hdrs = [":cufft-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cufft/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cufft/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusparse_headers",
|
||||
hdrs = [":cusparse-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "cusparse/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["cusparse/include"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "curand_headers",
|
||||
hdrs = [":curand-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
strip_include_prefix = "curand/include",
|
||||
deps = [":cuda_headers"],
|
||||
includes = ["curand/include"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cublas",
|
||||
srcs = ["cuda/lib/%{cublas_lib}"],
|
||||
|
|
|
@ -85,6 +85,42 @@ cuda_header_library(
|
|||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusolver_headers",
|
||||
hdrs = [":cusolver-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cusolver/include"],
|
||||
strip_include_prefix = "cusolver/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cufft_headers",
|
||||
hdrs = [":cufft-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cufft/include"],
|
||||
strip_include_prefix = "cufft/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "cusparse_headers",
|
||||
hdrs = [":cusparse-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["cusparse/include"],
|
||||
strip_include_prefix = "cusparse/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cuda_header_library(
|
||||
name = "curand_headers",
|
||||
hdrs = [":curand-include"],
|
||||
include_prefix = "third_party/gpus/cuda/include",
|
||||
includes = ["curand/include"],
|
||||
strip_include_prefix = "curand/include",
|
||||
deps = [":cuda_headers"],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
name = "cublas",
|
||||
interface_library = "cuda/lib/%{cublas_lib}",
|
||||
|
|
|
@ -17,7 +17,11 @@ limitations under the License.
|
|||
#define CUDA_CUDA_CONFIG_H_
|
||||
|
||||
#define TF_CUDA_VERSION "%{cuda_version}"
|
||||
#define TF_CUDA_LIB_VERSION "%{cuda_lib_version}"
|
||||
#define TF_CUBLAS_VERSION "%{cublas_version}"
|
||||
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
|
||||
#define TF_CURAND_VERSION "%{curand_version}"
|
||||
#define TF_CUFFT_VERSION "%{cufft_version}"
|
||||
#define TF_CUSPARSE_VERSION "%{cusparse_version}"
|
||||
#define TF_CUDNN_VERSION "%{cudnn_version}"
|
||||
|
||||
#define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}"
|
||||
|
|
|
@ -527,28 +527,28 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
|||
"cublas",
|
||||
cpu_value,
|
||||
cuda_config.config["cublas_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.cublas_version,
|
||||
static = False,
|
||||
),
|
||||
"cusolver": _check_cuda_lib_params(
|
||||
"cusolver",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cusolver_library_dir"],
|
||||
cuda_config.cusolver_version,
|
||||
static = False,
|
||||
),
|
||||
"curand": _check_cuda_lib_params(
|
||||
"curand",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["curand_library_dir"],
|
||||
cuda_config.curand_version,
|
||||
static = False,
|
||||
),
|
||||
"cufft": _check_cuda_lib_params(
|
||||
"cufft",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cufft_library_dir"],
|
||||
cuda_config.cufft_version,
|
||||
static = False,
|
||||
),
|
||||
"cudnn": _check_cuda_lib_params(
|
||||
|
@ -568,8 +568,8 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
|
|||
"cusparse": _check_cuda_lib_params(
|
||||
"cusparse",
|
||||
cpu_value,
|
||||
cuda_config.config["cuda_library_dir"],
|
||||
cuda_config.cuda_lib_version,
|
||||
cuda_config.config["cusparse_library_dir"],
|
||||
cuda_config.cusparse_version,
|
||||
static = False,
|
||||
),
|
||||
}
|
||||
|
@ -646,18 +646,37 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||
cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
|
||||
cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
|
||||
|
||||
if int(cuda_major) >= 11:
|
||||
cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
|
||||
cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
|
||||
curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
|
||||
cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
|
||||
cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
|
||||
elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
|
||||
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
|
||||
# It changed from 'x.y' to just 'x' in CUDA 10.1.
|
||||
if (int(cuda_major), int(cuda_minor)) >= (10, 1):
|
||||
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
|
||||
cublas_version = cuda_lib_version
|
||||
cusolver_version = cuda_lib_version
|
||||
curand_version = cuda_lib_version
|
||||
cufft_version = cuda_lib_version
|
||||
cusparse_version = cuda_lib_version
|
||||
else:
|
||||
cuda_lib_version = cuda_version
|
||||
cublas_version = cuda_version
|
||||
cusolver_version = cuda_version
|
||||
curand_version = cuda_version
|
||||
cufft_version = cuda_version
|
||||
cusparse_version = cuda_version
|
||||
|
||||
return struct(
|
||||
cuda_toolkit_path = toolkit_path,
|
||||
cuda_version = cuda_version,
|
||||
cublas_version = cublas_version,
|
||||
cusolver_version = cusolver_version,
|
||||
curand_version = curand_version,
|
||||
cufft_version = cufft_version,
|
||||
cusparse_version = cusparse_version,
|
||||
cudnn_version = cudnn_version,
|
||||
cuda_lib_version = cuda_lib_version,
|
||||
compute_capabilities = compute_capabilities(repository_ctx),
|
||||
cpu_value = cpu_value,
|
||||
config = config,
|
||||
|
@ -739,6 +758,10 @@ def _create_dummy_repository(repository_ctx):
|
|||
"%{copy_rules}": """
|
||||
filegroup(name="cuda-include")
|
||||
filegroup(name="cublas-include")
|
||||
filegroup(name="cusolver-include")
|
||||
filegroup(name="cufft-include")
|
||||
filegroup(name="cusparse-include")
|
||||
filegroup(name="curand-include")
|
||||
filegroup(name="cudnn-include")
|
||||
""",
|
||||
},
|
||||
|
@ -770,7 +793,11 @@ filegroup(name="cudnn-include")
|
|||
"cuda:cuda_config.h",
|
||||
{
|
||||
"%{cuda_version}": "",
|
||||
"%{cuda_lib_version}": "",
|
||||
"%{cublas_version}": "",
|
||||
"%{cusolver_version}": "",
|
||||
"%{curand_version}": "",
|
||||
"%{cufft_version}": "",
|
||||
"%{cusparse_version}": "",
|
||||
"%{cudnn_version}": "",
|
||||
"%{cuda_toolkit_path}": "",
|
||||
},
|
||||
|
@ -935,6 +962,56 @@ def _create_local_cuda_repository(repository_ctx):
|
|||
],
|
||||
))
|
||||
|
||||
cusolver_include_path = cuda_config.config["cusolver_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cusolver-include",
|
||||
srcs = [
|
||||
cusolver_include_path + "/cusolver_common.h",
|
||||
cusolver_include_path + "/cusolverDn.h",
|
||||
],
|
||||
outs = [
|
||||
"cusolver/include/cusolver_common.h",
|
||||
"cusolver/include/cusolverDn.h",
|
||||
],
|
||||
))
|
||||
|
||||
cufft_include_path = cuda_config.config["cufft_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cufft-include",
|
||||
srcs = [
|
||||
cufft_include_path + "/cufft.h",
|
||||
],
|
||||
outs = [
|
||||
"cufft/include/cufft.h",
|
||||
],
|
||||
))
|
||||
|
||||
cusparse_include_path = cuda_config.config["cusparse_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "cusparse-include",
|
||||
srcs = [
|
||||
cusparse_include_path + "/cusparse.h",
|
||||
],
|
||||
outs = [
|
||||
"cusparse/include/cusparse.h",
|
||||
],
|
||||
))
|
||||
|
||||
curand_include_path = cuda_config.config["curand_include_dir"]
|
||||
copy_rules.append(make_copy_files_rule(
|
||||
repository_ctx,
|
||||
name = "curand-include",
|
||||
srcs = [
|
||||
curand_include_path + "/curand.h",
|
||||
],
|
||||
outs = [
|
||||
"curand/include/curand.h",
|
||||
],
|
||||
))
|
||||
|
||||
check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
|
||||
cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
|
||||
cuda_lib_srcs = []
|
||||
|
@ -1143,7 +1220,11 @@ def _create_local_cuda_repository(repository_ctx):
|
|||
tpl_paths["cuda:cuda_config.h"],
|
||||
{
|
||||
"%{cuda_version}": cuda_config.cuda_version,
|
||||
"%{cuda_lib_version}": cuda_config.cuda_lib_version,
|
||||
"%{cublas_version}": cuda_config.cublas_version,
|
||||
"%{cusolver_version}": cuda_config.cusolver_version,
|
||||
"%{curand_version}": cuda_config.curand_version,
|
||||
"%{cufft_version}": cuda_config.cufft_version,
|
||||
"%{cusparse_version}": cuda_config.cusparse_version,
|
||||
"%{cudnn_version}": cuda_config.cudnn_version,
|
||||
"%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
|
||||
},
|
||||
|
|
|
@ -318,12 +318,9 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
|
|||
# cuBLAS uses the major version only.
|
||||
cublas_version = header_version.split(".")[0]
|
||||
|
||||
if not _matches_version(cuda_version, cublas_version):
|
||||
raise ConfigError("cuBLAS version %s does not match CUDA version %s" %
|
||||
(cublas_version, cuda_version))
|
||||
|
||||
else:
|
||||
# There is no version info available before CUDA 10.1, just find the file.
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cublas_api.h")
|
||||
# cuBLAS version is the same as CUDA version (x.y).
|
||||
cublas_version = required_version
|
||||
|
@ -331,10 +328,98 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
|
|||
library_path = _find_library(base_paths, "cublas", cublas_version)
|
||||
|
||||
return {
|
||||
"cublas_version": header_version,
|
||||
"cublas_include_dir": os.path.dirname(header_path),
|
||||
"cublas_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_cusolver_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUSOLVER_VER_MAJOR", "CUSOLVER_VER_MINOR",
|
||||
"CUSOLVER_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cusolver_common.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cusolver_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cusolver_common.h")
|
||||
cusolver_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cusolver", cusolver_version)
|
||||
|
||||
return {
|
||||
"cusolver_version": header_version,
|
||||
"cusolver_include_dir": os.path.dirname(header_path),
|
||||
"cusolver_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_curand_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CURAND_VER_MAJOR", "CURAND_VER_MINOR",
|
||||
"CURAND_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "curand.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
curand_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "curand.h")
|
||||
curand_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "curand", curand_version)
|
||||
|
||||
return {
|
||||
"curand_version": header_version,
|
||||
"curand_include_dir": os.path.dirname(header_path),
|
||||
"curand_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
def _find_cufft_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUFFT_VER_MAJOR", "CUFFT_VER_MINOR",
|
||||
"CUFFT_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cufft.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cufft_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cufft.h")
|
||||
cufft_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cufft", cufft_version)
|
||||
|
||||
return {
|
||||
"cufft_version": header_version,
|
||||
"cufft_include_dir": os.path.dirname(header_path),
|
||||
"cufft_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
|
||||
def _find_cudnn_config(base_paths, required_version):
|
||||
|
||||
|
@ -358,6 +443,36 @@ def _find_cudnn_config(base_paths, required_version):
|
|||
}
|
||||
|
||||
|
||||
def _find_cusparse_config(base_paths, required_version, cuda_version):
|
||||
|
||||
if _at_least_version(cuda_version, "11.0"):
|
||||
|
||||
def get_header_version(path):
|
||||
version = (
|
||||
_get_header_version(path, name)
|
||||
for name in ("CUSPARSE_VER_MAJOR", "CUSPARSE_VER_MINOR",
|
||||
"CUSPARSE_VER_PATCH"))
|
||||
return ".".join(version)
|
||||
|
||||
header_path, header_version = _find_header(base_paths, "cusparse.h",
|
||||
required_version,
|
||||
get_header_version)
|
||||
cusparse_version = header_version.split(".")[0]
|
||||
|
||||
else:
|
||||
header_version = cuda_version
|
||||
header_path = _find_file(base_paths, _header_paths(), "cusparse.h")
|
||||
cusparse_version = required_version
|
||||
|
||||
library_path = _find_library(base_paths, "cusparse", cusparse_version)
|
||||
|
||||
return {
|
||||
"cusparse_version": header_version,
|
||||
"cusparse_include_dir": os.path.dirname(header_path),
|
||||
"cusparse_library_dir": os.path.dirname(library_path),
|
||||
}
|
||||
|
||||
|
||||
def _find_nccl_config(base_paths, required_version):
|
||||
|
||||
def get_header_version(path):
|
||||
|
@ -465,6 +580,34 @@ def find_cuda_config():
|
|||
result.update(
|
||||
_find_cublas_config(cublas_paths, cublas_version, cuda_version))
|
||||
|
||||
cusolver_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cusolver_paths = cuda_paths
|
||||
cusolver_version = os.environ.get("TF_CUSOLVER_VERSION", "")
|
||||
result.update(
|
||||
_find_cusolver_config(cusolver_paths, cusolver_version, cuda_version))
|
||||
|
||||
curand_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
curand_paths = cuda_paths
|
||||
curand_version = os.environ.get("TF_CURAND_VERSION", "")
|
||||
result.update(
|
||||
_find_curand_config(curand_paths, curand_version, cuda_version))
|
||||
|
||||
cufft_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cufft_paths = cuda_paths
|
||||
cufft_version = os.environ.get("TF_CUFFT_VERSION", "")
|
||||
result.update(
|
||||
_find_cufft_config(cufft_paths, cufft_version, cuda_version))
|
||||
|
||||
cusparse_paths = base_paths
|
||||
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
|
||||
cusparse_paths = cuda_paths
|
||||
cusparse_version = os.environ.get("TF_CUSPARSE_VERSION", "")
|
||||
result.update(
|
||||
_find_cusparse_config(cusparse_paths, cusparse_version, cuda_version))
|
||||
|
||||
if "cudnn" in libraries:
|
||||
cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
|
||||
cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue