Remove kernels' explicit dependency on cusolver and cusparse.
- Add dsoload stubs for cusolver and cusparse. - Use stub when if_static is not set. PiperOrigin-RevId: 243877673
This commit is contained in:
parent
e7c9ad64eb
commit
a4fae76a79
@ -5,6 +5,7 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
|
||||
@ -598,15 +599,17 @@ cc_library(
|
||||
srcs = ["cusolver_context.cc"],
|
||||
hdrs = ["cusolver_context.h"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:blas",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cusolver",
|
||||
],
|
||||
] + if_static(
|
||||
["@local_config_cuda//cuda:cusolver"],
|
||||
["//tensorflow/stream_executor/cuda:cusolver_stub"],
|
||||
),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -53,6 +53,7 @@ load(
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load(
|
||||
@ -3030,9 +3031,16 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
||||
"@local_config_cuda//cuda:cublas",
|
||||
"@local_config_cuda//cuda:cusolver",
|
||||
],
|
||||
] + if_static(
|
||||
[
|
||||
"@local_config_cuda//cuda:cusolver",
|
||||
"@local_config_cuda//cuda:cublas",
|
||||
],
|
||||
[
|
||||
"//tensorflow/stream_executor/cuda:cusolver_stub",
|
||||
"//tensorflow/stream_executor/cuda:cublas_stub",
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -3042,8 +3050,10 @@ tf_kernel_library(
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@local_config_cuda//cuda:cusparse",
|
||||
],
|
||||
] + if_static(
|
||||
["@local_config_cuda//cuda:cusparse"],
|
||||
["//tensorflow/stream_executor/cuda:cusparse_stub"],
|
||||
),
|
||||
)
|
||||
|
||||
LINALG_DEPS = [
|
||||
|
@ -350,6 +350,28 @@ cc_library(
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusolver_stub",
|
||||
srcs = if_cuda_is_configured(["cusolver_stub.cc"]),
|
||||
textual_hdrs = ["cusolver_dense_10_0.inc"],
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cusparse_stub",
|
||||
srcs = if_cuda_is_configured(["cusparse_stub.cc"]),
|
||||
textual_hdrs = glob(["cusparse_*.inc"]),
|
||||
deps = if_cuda_is_configured([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cuda_kernel",
|
||||
srcs = if_cuda_is_configured(["cuda_kernel.cc"]),
|
||||
|
2283
tensorflow/stream_executor/cuda/cusolver_dense_10_0.inc
Normal file
2283
tensorflow/stream_executor/cuda/cusolver_dense_10_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
53
tensorflow/stream_executor/cuda/cusolver_stub.cc
Normal file
53
tensorflow/stream_executor/cuda/cusolver_stub.cc
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "cuda/include/cusolverDn.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
// Implements the cusolver API by forwarding to cusolver loaded from the DSO.
|
||||
|
||||
namespace {
|
||||
// Returns DSO handle or null if loading the DSO fails.
|
||||
void* GetDsoHandle() {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
return nullptr;
|
||||
#else
|
||||
static auto handle = []() -> void* {
|
||||
auto handle_or =
|
||||
stream_executor::internal::DsoLoader::GetCusolverDsoHandle();
|
||||
if (!handle_or.ok()) return nullptr;
|
||||
return handle_or.ValueOrDie();
|
||||
}();
|
||||
return handle;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T LoadSymbol(const char* symbol_name) {
|
||||
void* symbol = nullptr;
|
||||
if (auto handle = GetDsoHandle()) {
|
||||
stream_executor::port::Env::Default()
|
||||
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||
.IgnoreError();
|
||||
}
|
||||
return reinterpret_cast<T>(symbol);
|
||||
}
|
||||
|
||||
cusolverStatus_t GetSymbolNotFoundError() {
|
||||
return CUSOLVER_STATUS_INTERNAL_ERROR;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#include "tensorflow/stream_executor/cuda/cusolver_dense_10_0.inc"
|
7800
tensorflow/stream_executor/cuda/cusparse_10_0.inc
Normal file
7800
tensorflow/stream_executor/cuda/cusparse_10_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
7119
tensorflow/stream_executor/cuda/cusparse_9_0.inc
Normal file
7119
tensorflow/stream_executor/cuda/cusparse_9_0.inc
Normal file
File diff suppressed because it is too large
Load Diff
57
tensorflow/stream_executor/cuda/cusparse_stub.cc
Normal file
57
tensorflow/stream_executor/cuda/cusparse_stub.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "cuda/include/cusparse.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
|
||||
// Implements the cusparse API by forwarding to cusparse loaded from the DSO.
|
||||
|
||||
namespace {
|
||||
// Returns DSO handle or null if loading the DSO fails.
|
||||
void* GetDsoHandle() {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
return nullptr;
|
||||
#else
|
||||
static auto handle = []() -> void* {
|
||||
auto handle_or =
|
||||
stream_executor::internal::DsoLoader::GetCusparseDsoHandle();
|
||||
if (!handle_or.ok()) return nullptr;
|
||||
return handle_or.ValueOrDie();
|
||||
}();
|
||||
return handle;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T LoadSymbol(const char* symbol_name) {
|
||||
void* symbol = nullptr;
|
||||
if (auto handle = GetDsoHandle()) {
|
||||
stream_executor::port::Env::Default()
|
||||
->GetSymbolFromLibrary(handle, symbol_name, &symbol)
|
||||
.IgnoreError();
|
||||
}
|
||||
return reinterpret_cast<T>(symbol);
|
||||
}
|
||||
|
||||
cusparseStatus_t GetSymbolNotFoundError() {
|
||||
return CUSPARSE_STATUS_INTERNAL_ERROR;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#if CUDA_VERSION < 9020
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_9_0.inc"
|
||||
#else
|
||||
#include "tensorflow/stream_executor/cuda/cusparse_10_0.inc"
|
||||
#endif
|
@ -83,6 +83,14 @@ port::StatusOr<void*> GetCufftDsoHandle() {
|
||||
return GetDsoHandle("cufft", GetCudaLibVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusolverDsoHandle() {
|
||||
return GetDsoHandle("cusolver", GetCudaVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusparseDsoHandle() {
|
||||
return GetDsoHandle("cusparse", GetCudaVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCurandDsoHandle() {
|
||||
return GetDsoHandle("curand", GetCudaLibVersion());
|
||||
}
|
||||
@ -147,6 +155,16 @@ port::StatusOr<void*> GetCufftDsoHandle() {
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusolverDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCusolverDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCusparseDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCusparseDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetCuptiDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCuptiDsoHandle());
|
||||
return *result;
|
||||
|
@ -39,6 +39,8 @@ port::StatusOr<void*> GetCudaRuntimeDsoHandle();
|
||||
port::StatusOr<void*> GetCublasDsoHandle();
|
||||
port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCusolverDsoHandle();
|
||||
port::StatusOr<void*> GetCusparseDsoHandle();
|
||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||
|
||||
@ -58,6 +60,8 @@ port::StatusOr<void*> GetCudaRuntimeDsoHandle();
|
||||
port::StatusOr<void*> GetCublasDsoHandle();
|
||||
port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCusolverDsoHandle();
|
||||
port::StatusOr<void*> GetCusparseDsoHandle();
|
||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user