[StreamExecutor] Re-apply cublasLt changes that were mistakenly rolled-back.

PiperOrigin-RevId: 340297228
Change-Id: If23aefda775268c8d1bdae76090f274109650d62
This commit is contained in:
Tim Shen 2020-11-02 12:56:31 -08:00 committed by TensorFlower Gardener
parent f8ba2a8d9b
commit 592947d1a6
2 changed files with 20 additions and 0 deletions

View File

@ -127,6 +127,13 @@ cc_library(
linkstatic = 1, linkstatic = 1,
) )
cc_library(
name = "cublasLt",
srcs = ["cuda/lib/%{cublasLt_lib}"],
data = ["cuda/lib/%{cublasLt_lib}"],
linkstatic = 1,
)
cc_library( cc_library(
name = "cusolver", name = "cusolver",
srcs = ["cuda/lib/%{cusolver_lib}"], srcs = ["cuda/lib/%{cusolver_lib}"],
@ -168,6 +175,7 @@ cc_library(
name = "cuda", name = "cuda",
deps = [ deps = [
":cublas", ":cublas",
":cublasLt",
":cuda_headers", ":cuda_headers",
":cudart", ":cudart",
":cudnn", ":cudnn",

View File

@ -551,6 +551,13 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
cuda_config.cublas_version, cuda_config.cublas_version,
static = False, static = False,
), ),
"cublasLt": _check_cuda_lib_params(
"cublasLt",
cpu_value,
cuda_config.config["cublas_library_dir"],
cuda_config.cublas_version,
static = False,
),
"cusolver": _check_cuda_lib_params( "cusolver": _check_cuda_lib_params(
"cusolver", "cusolver",
cpu_value, cpu_value,
@ -780,6 +787,7 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": lib_name("cudart", cpu_value), "%{cudart_lib}": lib_name("cudart", cpu_value),
"%{cublas_lib}": lib_name("cublas", cpu_value), "%{cublas_lib}": lib_name("cublas", cpu_value),
"%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
"%{cusolver_lib}": lib_name("cusolver", cpu_value), "%{cusolver_lib}": lib_name("cusolver", cpu_value),
"%{cudnn_lib}": lib_name("cudnn", cpu_value), "%{cudnn_lib}": lib_name("cudnn", cpu_value),
"%{cufft_lib}": lib_name("cufft", cpu_value), "%{cufft_lib}": lib_name("cufft", cpu_value),
@ -811,6 +819,7 @@ filegroup(name="cudnn-include")
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value), "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
) )
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
@ -1002,11 +1011,13 @@ def _create_local_cuda_repository(repository_ctx):
cublas_include_path + "/cublas.h", cublas_include_path + "/cublas.h",
cublas_include_path + "/cublas_v2.h", cublas_include_path + "/cublas_v2.h",
cublas_include_path + "/cublas_api.h", cublas_include_path + "/cublas_api.h",
cublas_include_path + "/cublasLt.h",
], ],
outs = [ outs = [
"cublas/include/cublas.h", "cublas/include/cublas.h",
"cublas/include/cublas_v2.h", "cublas/include/cublas_v2.h",
"cublas/include/cublas_api.h", "cublas/include/cublas_api.h",
"cublas/include/cublasLt.h",
], ],
)) ))
@ -1147,6 +1158,7 @@ def _create_local_cuda_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value), "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
"%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]), "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
"%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]), "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
"%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
"%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]), "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
"%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]), "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
"%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]), "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),