Enable build with CUDA 11

This commit is contained in:
Nathan Luehr 2020-04-15 13:15:45 -07:00
parent 8e8c67c337
commit 28feb4df0d
15 changed files with 21813 additions and 30 deletions

View File

@ -273,6 +273,7 @@ cc_library(
textual_hdrs = glob(["cufft_*.inc"]), textual_hdrs = glob(["cufft_*.inc"]),
deps = if_cuda_is_configured([ deps = if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cufft_headers",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader", "//tensorflow/stream_executor/platform:dso_loader",
]), ]),
@ -371,6 +372,7 @@ cc_library(
textual_hdrs = ["curand_10_0.inc"], textual_hdrs = ["curand_10_0.inc"],
deps = if_cuda_is_configured([ deps = if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:curand_headers",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader", "//tensorflow/stream_executor/platform:dso_loader",
]), ]),
@ -430,6 +432,7 @@ cc_library(
# LINT.IfChange # LINT.IfChange
"@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cublas_headers",
# LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers) # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers)
"@local_config_cuda//cuda:cusolver_headers",
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader", "//tensorflow/stream_executor/platform:dso_loader",
@ -451,6 +454,7 @@ cc_library(
textual_hdrs = glob(["cusparse_*.inc"]), textual_hdrs = glob(["cusparse_*.inc"]),
deps = if_cuda_is_configured([ deps = if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cusparse_headers",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader", "//tensorflow/stream_executor/platform:dso_loader",
]), ]),

File diff suppressed because it is too large Load Diff

View File

@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#if CUBLAS_VER_MAJOR >= 11
#include "third_party/gpus/cuda/include/cublas_v2.h"
#else
#include "third_party/gpus/cuda/include/cublas.h" #include "third_party/gpus/cuda/include/cublas.h"
#endif
#include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/dso_loader.h"
@ -65,7 +69,7 @@ typedef enum {} cublasMath_t;
#include "tensorflow/stream_executor/cuda/cublas_10_1.inc" #include "tensorflow/stream_executor/cuda/cublas_10_1.inc"
#elif CUDA_VERSION == 10020 #elif CUDA_VERSION == 10020
#include "tensorflow/stream_executor/cuda/cublas_10_2.inc" #include "tensorflow/stream_executor/cuda/cublas_10_2.inc"
#elif CUDA_VERSION == 11000 #elif CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 0
#include "tensorflow/stream_executor/cuda/cublas_11_0.inc" #include "tensorflow/stream_executor/cuda/cublas_11_0.inc"
#else #else
#error "We have no wrapper for this version." #error "We have no wrapper for this version."

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h" #include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/dso_loader.h"
@ -59,7 +60,7 @@ cusparseStatus_t GetSymbolNotFoundError() {
#include "tensorflow/stream_executor/cuda/cusparse_10_1.inc" #include "tensorflow/stream_executor/cuda/cusparse_10_1.inc"
#elif CUDA_VERSION == 10020 #elif CUDA_VERSION == 10020
#include "tensorflow/stream_executor/cuda/cusparse_10_2.inc" #include "tensorflow/stream_executor/cuda/cusparse_10_2.inc"
#elif CUDA_VERSION == 11000 #elif CUSPARSE_VER_MAJOR == 11 && CUSPARSE_VER_MINOR == 0
#include "tensorflow/stream_executor/cuda/cusparse_11_0.inc" #include "tensorflow/stream_executor/cuda/cusparse_11_0.inc"
#else #else
#error "We don't have a wrapper for this version." #error "We don't have a wrapper for this version."

View File

@ -31,8 +31,12 @@ namespace internal {
namespace { namespace {
string GetCudaVersion() { return TF_CUDA_VERSION; } string GetCudaVersion() { return TF_CUDA_VERSION; }
string GetCudaLibVersion() { return TF_CUDA_LIB_VERSION; }
string GetCudnnVersion() { return TF_CUDNN_VERSION; } string GetCudnnVersion() { return TF_CUDNN_VERSION; }
string GetCublasVersion() { return TF_CUBLAS_VERSION; }
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
string GetCurandVersion() { return TF_CURAND_VERSION; }
string GetCufftVersion() { return TF_CUFFT_VERSION; }
string GetCusparseVersion() { return TF_CUSPARSE_VERSION; }
string GetTensorRTVersion() { return TF_TENSORRT_VERSION; } string GetTensorRTVersion() { return TF_TENSORRT_VERSION; }
port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) { port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
@ -77,23 +81,23 @@ port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
} }
port::StatusOr<void*> GetCublasDsoHandle() { port::StatusOr<void*> GetCublasDsoHandle() {
return GetDsoHandle("cublas", GetCudaLibVersion()); return GetDsoHandle("cublas", GetCublasVersion());
} }
port::StatusOr<void*> GetCufftDsoHandle() { port::StatusOr<void*> GetCufftDsoHandle() {
return GetDsoHandle("cufft", GetCudaLibVersion()); return GetDsoHandle("cufft", GetCufftVersion());
} }
port::StatusOr<void*> GetCusolverDsoHandle() { port::StatusOr<void*> GetCusolverDsoHandle() {
return GetDsoHandle("cusolver", GetCudaLibVersion()); return GetDsoHandle("cusolver", GetCusolverVersion());
} }
port::StatusOr<void*> GetCusparseDsoHandle() { port::StatusOr<void*> GetCusparseDsoHandle() {
return GetDsoHandle("cusparse", GetCudaLibVersion()); return GetDsoHandle("cusparse", GetCusparseVersion());
} }
port::StatusOr<void*> GetCurandDsoHandle() { port::StatusOr<void*> GetCurandDsoHandle() {
return GetDsoHandle("curand", GetCudaLibVersion()); return GetDsoHandle("curand", GetCurandVersion());
} }
port::StatusOr<void*> GetCuptiDsoHandle() { port::StatusOr<void*> GetCuptiDsoHandle() {

View File

@ -84,6 +84,42 @@ cuda_header_library(
includes = ["cublas/include"], includes = ["cublas/include"],
) )
cuda_header_library(
name = "cusolver_headers",
hdrs = [":cusolver-include"],
include_prefix = "third_party/gpus/cuda/include",
strip_include_prefix = "cusolver/include",
deps = [":cuda_headers"],
includes = ["cusolver/include"],
)
cuda_header_library(
name = "cufft_headers",
hdrs = [":cufft-include"],
include_prefix = "third_party/gpus/cuda/include",
strip_include_prefix = "cufft/include",
deps = [":cuda_headers"],
includes = ["cufft/include"],
)
cuda_header_library(
name = "cusparse_headers",
hdrs = [":cusparse-include"],
include_prefix = "third_party/gpus/cuda/include",
strip_include_prefix = "cusparse/include",
deps = [":cuda_headers"],
includes = ["cusparse/include"],
)
cuda_header_library(
name = "curand_headers",
hdrs = [":curand-include"],
include_prefix = "third_party/gpus/cuda/include",
strip_include_prefix = "curand/include",
deps = [":cuda_headers"],
includes = ["curand/include"],
)
cc_library( cc_library(
name = "cublas", name = "cublas",
srcs = ["cuda/lib/%{cublas_lib}"], srcs = ["cuda/lib/%{cublas_lib}"],

View File

@ -85,6 +85,42 @@ cuda_header_library(
deps = [":cuda_headers"], deps = [":cuda_headers"],
) )
cuda_header_library(
name = "cusolver_headers",
hdrs = [":cusolver-include"],
include_prefix = "third_party/gpus/cuda/include",
includes = ["cusolver/include"],
strip_include_prefix = "cusolver/include",
deps = [":cuda_headers"],
)
cuda_header_library(
name = "cufft_headers",
hdrs = [":cufft-include"],
include_prefix = "third_party/gpus/cuda/include",
includes = ["cufft/include"],
strip_include_prefix = "cufft/include",
deps = [":cuda_headers"],
)
cuda_header_library(
name = "cusparse_headers",
hdrs = [":cusparse-include"],
include_prefix = "third_party/gpus/cuda/include",
includes = ["cusparse/include"],
strip_include_prefix = "cusparse/include",
deps = [":cuda_headers"],
)
cuda_header_library(
name = "curand_headers",
hdrs = [":curand-include"],
include_prefix = "third_party/gpus/cuda/include",
includes = ["curand/include"],
strip_include_prefix = "curand/include",
deps = [":cuda_headers"],
)
cc_import( cc_import(
name = "cublas", name = "cublas",
interface_library = "cuda/lib/%{cublas_lib}", interface_library = "cuda/lib/%{cublas_lib}",

View File

@ -17,7 +17,11 @@ limitations under the License.
#define CUDA_CUDA_CONFIG_H_ #define CUDA_CUDA_CONFIG_H_
#define TF_CUDA_VERSION "%{cuda_version}" #define TF_CUDA_VERSION "%{cuda_version}"
#define TF_CUDA_LIB_VERSION "%{cuda_lib_version}" #define TF_CUBLAS_VERSION "%{cublas_version}"
#define TF_CUSOLVER_VERSION "%{cusolver_version}"
#define TF_CURAND_VERSION "%{curand_version}"
#define TF_CUFFT_VERSION "%{cufft_version}"
#define TF_CUSPARSE_VERSION "%{cusparse_version}"
#define TF_CUDNN_VERSION "%{cudnn_version}" #define TF_CUDNN_VERSION "%{cudnn_version}"
#define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}" #define TF_CUDA_TOOLKIT_PATH "%{cuda_toolkit_path}"

View File

@ -527,28 +527,28 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
"cublas", "cublas",
cpu_value, cpu_value,
cuda_config.config["cublas_library_dir"], cuda_config.config["cublas_library_dir"],
cuda_config.cuda_lib_version, cuda_config.cublas_version,
static = False, static = False,
), ),
"cusolver": _check_cuda_lib_params( "cusolver": _check_cuda_lib_params(
"cusolver", "cusolver",
cpu_value, cpu_value,
cuda_config.config["cuda_library_dir"], cuda_config.config["cusolver_library_dir"],
cuda_config.cuda_lib_version, cuda_config.cusolver_version,
static = False, static = False,
), ),
"curand": _check_cuda_lib_params( "curand": _check_cuda_lib_params(
"curand", "curand",
cpu_value, cpu_value,
cuda_config.config["cuda_library_dir"], cuda_config.config["curand_library_dir"],
cuda_config.cuda_lib_version, cuda_config.curand_version,
static = False, static = False,
), ),
"cufft": _check_cuda_lib_params( "cufft": _check_cuda_lib_params(
"cufft", "cufft",
cpu_value, cpu_value,
cuda_config.config["cuda_library_dir"], cuda_config.config["cufft_library_dir"],
cuda_config.cuda_lib_version, cuda_config.cufft_version,
static = False, static = False,
), ),
"cudnn": _check_cuda_lib_params( "cudnn": _check_cuda_lib_params(
@ -568,8 +568,8 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
"cusparse": _check_cuda_lib_params( "cusparse": _check_cuda_lib_params(
"cusparse", "cusparse",
cpu_value, cpu_value,
cuda_config.config["cuda_library_dir"], cuda_config.config["cusparse_library_dir"],
cuda_config.cuda_lib_version, cuda_config.cusparse_version,
static = False, static = False,
), ),
} }
@ -646,18 +646,37 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor) cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. if int(cuda_major) >= 11:
# It changed from 'x.y' to just 'x' in CUDA 10.1. cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
if (int(cuda_major), int(cuda_minor)) >= (10, 1): cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
# cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
# It changed from 'x.y' to just 'x' in CUDA 10.1.
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
cublas_version = cuda_lib_version
cusolver_version = cuda_lib_version
curand_version = cuda_lib_version
cufft_version = cuda_lib_version
cusparse_version = cuda_lib_version
else: else:
cuda_lib_version = cuda_version cublas_version = cuda_version
cusolver_version = cuda_version
curand_version = cuda_version
cufft_version = cuda_version
cusparse_version = cuda_version
return struct( return struct(
cuda_toolkit_path = toolkit_path, cuda_toolkit_path = toolkit_path,
cuda_version = cuda_version, cuda_version = cuda_version,
cublas_version = cublas_version,
cusolver_version = cusolver_version,
curand_version = curand_version,
cufft_version = cufft_version,
cusparse_version = cusparse_version,
cudnn_version = cudnn_version, cudnn_version = cudnn_version,
cuda_lib_version = cuda_lib_version,
compute_capabilities = compute_capabilities(repository_ctx), compute_capabilities = compute_capabilities(repository_ctx),
cpu_value = cpu_value, cpu_value = cpu_value,
config = config, config = config,
@ -739,6 +758,10 @@ def _create_dummy_repository(repository_ctx):
"%{copy_rules}": """ "%{copy_rules}": """
filegroup(name="cuda-include") filegroup(name="cuda-include")
filegroup(name="cublas-include") filegroup(name="cublas-include")
filegroup(name="cusolver-include")
filegroup(name="cufft-include")
filegroup(name="cusparse-include")
filegroup(name="curand-include")
filegroup(name="cudnn-include") filegroup(name="cudnn-include")
""", """,
}, },
@ -770,7 +793,11 @@ filegroup(name="cudnn-include")
"cuda:cuda_config.h", "cuda:cuda_config.h",
{ {
"%{cuda_version}": "", "%{cuda_version}": "",
"%{cuda_lib_version}": "", "%{cublas_version}": "",
"%{cusolver_version}": "",
"%{curand_version}": "",
"%{cufft_version}": "",
"%{cusparse_version}": "",
"%{cudnn_version}": "", "%{cudnn_version}": "",
"%{cuda_toolkit_path}": "", "%{cuda_toolkit_path}": "",
}, },
@ -935,6 +962,56 @@ def _create_local_cuda_repository(repository_ctx):
], ],
)) ))
cusolver_include_path = cuda_config.config["cusolver_include_dir"]
copy_rules.append(make_copy_files_rule(
repository_ctx,
name = "cusolver-include",
srcs = [
cusolver_include_path + "/cusolver_common.h",
cusolver_include_path + "/cusolverDn.h",
],
outs = [
"cusolver/include/cusolver_common.h",
"cusolver/include/cusolverDn.h",
],
))
cufft_include_path = cuda_config.config["cufft_include_dir"]
copy_rules.append(make_copy_files_rule(
repository_ctx,
name = "cufft-include",
srcs = [
cufft_include_path + "/cufft.h",
],
outs = [
"cufft/include/cufft.h",
],
))
cusparse_include_path = cuda_config.config["cusparse_include_dir"]
copy_rules.append(make_copy_files_rule(
repository_ctx,
name = "cusparse-include",
srcs = [
cusparse_include_path + "/cusparse.h",
],
outs = [
"cusparse/include/cusparse.h",
],
))
curand_include_path = cuda_config.config["curand_include_dir"]
copy_rules.append(make_copy_files_rule(
repository_ctx,
name = "curand-include",
srcs = [
curand_include_path + "/curand.h",
],
outs = [
"curand/include/curand.h",
],
))
check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py")) check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config) cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
cuda_lib_srcs = [] cuda_lib_srcs = []
@ -1143,7 +1220,11 @@ def _create_local_cuda_repository(repository_ctx):
tpl_paths["cuda:cuda_config.h"], tpl_paths["cuda:cuda_config.h"],
{ {
"%{cuda_version}": cuda_config.cuda_version, "%{cuda_version}": cuda_config.cuda_version,
"%{cuda_lib_version}": cuda_config.cuda_lib_version, "%{cublas_version}": cuda_config.cublas_version,
"%{cusolver_version}": cuda_config.cusolver_version,
"%{curand_version}": cuda_config.curand_version,
"%{cufft_version}": cuda_config.cufft_version,
"%{cusparse_version}": cuda_config.cusparse_version,
"%{cudnn_version}": cuda_config.cudnn_version, "%{cudnn_version}": cuda_config.cudnn_version,
"%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
}, },

View File

@ -318,12 +318,9 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
# cuBLAS uses the major version only. # cuBLAS uses the major version only.
cublas_version = header_version.split(".")[0] cublas_version = header_version.split(".")[0]
if not _matches_version(cuda_version, cublas_version):
raise ConfigError("cuBLAS version %s does not match CUDA version %s" %
(cublas_version, cuda_version))
else: else:
# There is no version info available before CUDA 10.1, just find the file. # There is no version info available before CUDA 10.1, just find the file.
header_version = cuda_version
header_path = _find_file(base_paths, _header_paths(), "cublas_api.h") header_path = _find_file(base_paths, _header_paths(), "cublas_api.h")
# cuBLAS version is the same as CUDA version (x.y). # cuBLAS version is the same as CUDA version (x.y).
cublas_version = required_version cublas_version = required_version
@ -331,10 +328,98 @@ def _find_cublas_config(base_paths, required_version, cuda_version):
library_path = _find_library(base_paths, "cublas", cublas_version) library_path = _find_library(base_paths, "cublas", cublas_version)
return { return {
"cublas_version": header_version,
"cublas_include_dir": os.path.dirname(header_path), "cublas_include_dir": os.path.dirname(header_path),
"cublas_library_dir": os.path.dirname(library_path), "cublas_library_dir": os.path.dirname(library_path),
} }
def _find_cusolver_config(base_paths, required_version, cuda_version):
if _at_least_version(cuda_version, "11.0"):
def get_header_version(path):
version = (
_get_header_version(path, name)
for name in ("CUSOLVER_VER_MAJOR", "CUSOLVER_VER_MINOR",
"CUSOLVER_VER_PATCH"))
return ".".join(version)
header_path, header_version = _find_header(base_paths, "cusolver_common.h",
required_version,
get_header_version)
cusolver_version = header_version.split(".")[0]
else:
header_version = cuda_version
header_path = _find_file(base_paths, _header_paths(), "cusolver_common.h")
cusolver_version = required_version
library_path = _find_library(base_paths, "cusolver", cusolver_version)
return {
"cusolver_version": header_version,
"cusolver_include_dir": os.path.dirname(header_path),
"cusolver_library_dir": os.path.dirname(library_path),
}
def _find_curand_config(base_paths, required_version, cuda_version):
if _at_least_version(cuda_version, "11.0"):
def get_header_version(path):
version = (
_get_header_version(path, name)
for name in ("CURAND_VER_MAJOR", "CURAND_VER_MINOR",
"CURAND_VER_PATCH"))
return ".".join(version)
header_path, header_version = _find_header(base_paths, "curand.h",
required_version,
get_header_version)
curand_version = header_version.split(".")[0]
else:
header_version = cuda_version
header_path = _find_file(base_paths, _header_paths(), "curand.h")
curand_version = required_version
library_path = _find_library(base_paths, "curand", curand_version)
return {
"curand_version": header_version,
"curand_include_dir": os.path.dirname(header_path),
"curand_library_dir": os.path.dirname(library_path),
}
def _find_cufft_config(base_paths, required_version, cuda_version):
if _at_least_version(cuda_version, "11.0"):
def get_header_version(path):
version = (
_get_header_version(path, name)
for name in ("CUFFT_VER_MAJOR", "CUFFT_VER_MINOR",
"CUFFT_VER_PATCH"))
return ".".join(version)
header_path, header_version = _find_header(base_paths, "cufft.h",
required_version,
get_header_version)
cufft_version = header_version.split(".")[0]
else:
header_version = cuda_version
header_path = _find_file(base_paths, _header_paths(), "cufft.h")
cufft_version = required_version
library_path = _find_library(base_paths, "cufft", cufft_version)
return {
"cufft_version": header_version,
"cufft_include_dir": os.path.dirname(header_path),
"cufft_library_dir": os.path.dirname(library_path),
}
def _find_cudnn_config(base_paths, required_version): def _find_cudnn_config(base_paths, required_version):
@ -358,6 +443,36 @@ def _find_cudnn_config(base_paths, required_version):
} }
def _find_cusparse_config(base_paths, required_version, cuda_version):
if _at_least_version(cuda_version, "11.0"):
def get_header_version(path):
version = (
_get_header_version(path, name)
for name in ("CUSPARSE_VER_MAJOR", "CUSPARSE_VER_MINOR",
"CUSPARSE_VER_PATCH"))
return ".".join(version)
header_path, header_version = _find_header(base_paths, "cusparse.h",
required_version,
get_header_version)
cusparse_version = header_version.split(".")[0]
else:
header_version = cuda_version
header_path = _find_file(base_paths, _header_paths(), "cusparse.h")
cusparse_version = required_version
library_path = _find_library(base_paths, "cusparse", cusparse_version)
return {
"cusparse_version": header_version,
"cusparse_include_dir": os.path.dirname(header_path),
"cusparse_library_dir": os.path.dirname(library_path),
}
def _find_nccl_config(base_paths, required_version): def _find_nccl_config(base_paths, required_version):
def get_header_version(path): def get_header_version(path):
@ -465,6 +580,34 @@ def find_cuda_config():
result.update( result.update(
_find_cublas_config(cublas_paths, cublas_version, cuda_version)) _find_cublas_config(cublas_paths, cublas_version, cuda_version))
cusolver_paths = base_paths
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
cusolver_paths = cuda_paths
cusolver_version = os.environ.get("TF_CUSOLVER_VERSION", "")
result.update(
_find_cusolver_config(cusolver_paths, cusolver_version, cuda_version))
curand_paths = base_paths
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
curand_paths = cuda_paths
curand_version = os.environ.get("TF_CURAND_VERSION", "")
result.update(
_find_curand_config(curand_paths, curand_version, cuda_version))
cufft_paths = base_paths
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
cufft_paths = cuda_paths
cufft_version = os.environ.get("TF_CUFFT_VERSION", "")
result.update(
_find_cufft_config(cufft_paths, cufft_version, cuda_version))
cusparse_paths = base_paths
if tuple(int(v) for v in cuda_version.split(".")) < (11, 0):
cusparse_paths = cuda_paths
cusparse_version = os.environ.get("TF_CUSPARSE_VERSION", "")
result.update(
_find_cusparse_config(cusparse_paths, cusparse_version, cuda_version))
if "cudnn" in libraries: if "cudnn" in libraries:
cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths) cudnn_paths = _get_legacy_path("CUDNN_INSTALL_PATH", base_paths)
cudnn_version = os.environ.get("TF_CUDNN_VERSION", "") cudnn_version = os.environ.get("TF_CUDNN_VERSION", "")

File diff suppressed because one or more lines are too long