Use correct cudart soname in GetDsoHandle

This commit is contained in:
Nathan Luehr 2020-08-14 13:21:58 -07:00
parent 4a64bbe4ff
commit 2642e93e6c
3 changed files with 5 additions and 1 deletions

View File

@ -31,6 +31,7 @@ namespace internal {
namespace { namespace {
string GetCudaVersion() { return TF_CUDA_VERSION; } string GetCudaVersion() { return TF_CUDA_VERSION; }
string GetCudaRtVersion() { return TF_CUDART_VERSION; }
string GetCudnnVersion() { return TF_CUDNN_VERSION; } string GetCudnnVersion() { return TF_CUDNN_VERSION; }
string GetCublasVersion() { return TF_CUBLAS_VERSION; } string GetCublasVersion() { return TF_CUBLAS_VERSION; }
string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } string GetCusolverVersion() { return TF_CUSOLVER_VERSION; }
@ -77,7 +78,7 @@ port::StatusOr<void*> GetCudaDriverDsoHandle() {
} }
port::StatusOr<void*> GetCudaRuntimeDsoHandle() { port::StatusOr<void*> GetCudaRuntimeDsoHandle() {
return GetDsoHandle("cudart", GetCudaVersion()); return GetDsoHandle("cudart", GetCudaRtVersion());
} }
port::StatusOr<void*> GetCublasDsoHandle() { port::StatusOr<void*> GetCublasDsoHandle() {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define CUDA_CUDA_CONFIG_H_ #define CUDA_CUDA_CONFIG_H_
#define TF_CUDA_VERSION "%{cuda_version}" #define TF_CUDA_VERSION "%{cuda_version}"
#define TF_CUDART_VERSION "%{cudart_version}"
#define TF_CUBLAS_VERSION "%{cublas_version}" #define TF_CUBLAS_VERSION "%{cublas_version}"
#define TF_CUSOLVER_VERSION "%{cusolver_version}" #define TF_CUSOLVER_VERSION "%{cusolver_version}"
#define TF_CURAND_VERSION "%{curand_version}" #define TF_CURAND_VERSION "%{curand_version}"

View File

@ -824,6 +824,7 @@ filegroup(name="cudnn-include")
"cuda:cuda_config.h", "cuda:cuda_config.h",
{ {
"%{cuda_version}": "", "%{cuda_version}": "",
"%{cudart_version}": "",
"%{cublas_version}": "", "%{cublas_version}": "",
"%{cusolver_version}": "", "%{cusolver_version}": "",
"%{curand_version}": "", "%{curand_version}": "",
@ -1289,6 +1290,7 @@ def _create_local_cuda_repository(repository_ctx):
tpl_paths["cuda:cuda_config.h"], tpl_paths["cuda:cuda_config.h"],
{ {
"%{cuda_version}": cuda_config.cuda_version, "%{cuda_version}": cuda_config.cuda_version,
"%{cudart_version}": cuda_config.cudart_version,
"%{cublas_version}": cuda_config.cublas_version, "%{cublas_version}": cuda_config.cublas_version,
"%{cusolver_version}": cuda_config.cusolver_version, "%{cusolver_version}": cuda_config.cusolver_version,
"%{curand_version}": cuda_config.curand_version, "%{curand_version}": cuda_config.curand_version,