changing rocblas, rocfft and rocrand from being dynamically linked to being dynmically loaded
This commit is contained in:
parent
234f47031b
commit
07b3f34199
@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
||||
port::Status status =
|
||||
port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
|
||||
if (status.ok()) {
|
||||
LOG(INFO) << "Successfully opened CUDA library " << filename;
|
||||
LOG(INFO) << "Successfully opened dynamic library " << filename;
|
||||
return dso_handle;
|
||||
}
|
||||
|
||||
@ -54,6 +54,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
||||
return port::Status(port::error::FAILED_PRECONDITION, message);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace DsoLoader {
|
||||
port::StatusOr<void*> GetCudaDriverDsoHandle() {
|
||||
@ -99,6 +100,27 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
|
||||
port::StatusOr<void*> GetCudnnDsoHandle() {
|
||||
return GetDsoHandle("cudnn", GetCudnnVersion());
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocblasDsoHandle() {
|
||||
return GetDsoHandle("rocblas", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetMiopenDsoHandle() {
|
||||
return GetDsoHandle("MIOpen", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocfftDsoHandle() {
|
||||
return GetDsoHandle("rocfft", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||
return GetDsoHandle("rocrand", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipDsoHandle() {
|
||||
return GetDsoHandle("hip_hcc", "");
|
||||
}
|
||||
|
||||
} // namespace DsoLoader
|
||||
|
||||
namespace CachedDsoLoader {
|
||||
@ -131,11 +153,37 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCuptiDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
|
||||
port::StatusOr<void*> GetCudnnDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocblasDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
|
||||
return result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetMiopenDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
|
||||
return result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocfftDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
|
||||
return result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
|
||||
return result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetHipDsoHandle());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace CachedDsoLoader
|
||||
} // namespace internal
|
||||
} // namespace stream_executor
|
||||
|
@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||
|
||||
port::StatusOr<void*> GetRocblasDsoHandle();
|
||||
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||
port::StatusOr<void*> GetHipDsoHandle();
|
||||
} // namespace DsoLoader
|
||||
|
||||
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
|
||||
@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
|
||||
port::StatusOr<void*> GetCurandDsoHandle();
|
||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||
|
||||
port::StatusOr<void*> GetRocblasDsoHandle();
|
||||
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||
port::StatusOr<void*> GetHipDsoHandle();
|
||||
} // namespace CachedDsoLoader
|
||||
|
||||
} // namespace internal
|
||||
} // namespace stream_executor
|
||||
|
||||
|
@ -167,10 +167,6 @@ cc_library(
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:rocblas"
|
||||
# Delete the following line once we switch the rocblas library from
|
||||
# being dynamically linked (current behaviour) to being dynamically
|
||||
# loaded (future behaviour)
|
||||
], ["@local_config_rocm//rocm:rocblas"
|
||||
])),
|
||||
alwayslink = True,
|
||||
)
|
||||
@ -197,10 +193,6 @@ cc_library(
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:rocfft"
|
||||
# Delete the following line once we switch the rocfft library from
|
||||
# being dynamically linked (current behaviour) to being dynamically
|
||||
# loaded (future behaviour)
|
||||
], ["@local_config_rocm//rocm:rocfft"
|
||||
])),
|
||||
alwayslink = True,
|
||||
)
|
||||
@ -265,10 +257,6 @@ cc_library(
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:hiprand"
|
||||
# Delete the following line once we switch the hiprand library from
|
||||
# being dynamically linked (current behaviour) to being dynamically
|
||||
# loaded (future behaviour)
|
||||
], ["@local_config_rocm//rocm:hiprand"
|
||||
])),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/lib/status_macros.h"
|
||||
#include "tensorflow/stream_executor/lib/stringprintf.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||
@ -49,6 +50,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
|
||||
|
||||
namespace wrap {
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||
struct WrapperShim__##__name { \
|
||||
static const char* kName; \
|
||||
@ -63,6 +65,41 @@ namespace wrap {
|
||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||
|
||||
#else
|
||||
|
||||
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
|
||||
struct DynLoadShim__##__name { \
|
||||
static const char* kName; \
|
||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||
static void* GetDsoHandle() { \
|
||||
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
|
||||
return s.ValueOrDie(); \
|
||||
} \
|
||||
static FuncPtrT LoadOrDie() { \
|
||||
void* f; \
|
||||
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||
kName, &f); \
|
||||
CHECK(s.ok()) << "could not find " << kName \
|
||||
<< " in rocblas DSO; dlerror: " << s.error_message(); \
|
||||
return reinterpret_cast<FuncPtrT>(f); \
|
||||
} \
|
||||
static FuncPtrT DynLoad() { \
|
||||
static FuncPtrT f = LoadOrDie(); \
|
||||
return f; \
|
||||
} \
|
||||
template <typename... Args> \
|
||||
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return DynLoad()(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char* DynLoadShim__##__name::kName = #__name;
|
||||
|
||||
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
|
||||
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
|
||||
|
||||
#endif
|
||||
|
||||
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \
|
||||
__macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \
|
||||
__macro(rocblas_dznrm2) */ \
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||
@ -38,6 +39,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
|
||||
|
||||
namespace wrap {
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
// This macro wraps a global identifier, given by __name, in a callable
|
||||
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
|
||||
// manner on first use. This dynamic loading technique is used to avoid DSO
|
||||
@ -52,6 +54,38 @@ namespace wrap {
|
||||
} \
|
||||
} __name;
|
||||
|
||||
#else
|
||||
|
||||
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
||||
struct DynLoadShim__##__name { \
|
||||
static const char *kName; \
|
||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||
static void *GetDsoHandle() { \
|
||||
auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
|
||||
return s.ValueOrDie(); \
|
||||
} \
|
||||
static FuncPtrT LoadOrDie() { \
|
||||
void *f; \
|
||||
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||
kName, &f); \
|
||||
CHECK(s.ok()) << "could not find " << kName \
|
||||
<< " in rocfft DSO; dlerror: " << s.error_message(); \
|
||||
return reinterpret_cast<FuncPtrT>(f); \
|
||||
} \
|
||||
static FuncPtrT DynLoad() { \
|
||||
static FuncPtrT f = LoadOrDie(); \
|
||||
return f; \
|
||||
} \
|
||||
template <typename... Args> \
|
||||
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return DynLoad()(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char *DynLoadShim__##__name::kName = #__name;
|
||||
|
||||
#endif
|
||||
|
||||
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||
__macro(hipfftDestroy) \
|
||||
__macro(hipfftSetStream) \
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/rng.h"
|
||||
|
||||
@ -61,6 +62,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
|
||||
|
||||
namespace wrap {
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
|
||||
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
|
||||
struct WrapperShim__##__name { \
|
||||
template <typename... Args> \
|
||||
@ -70,6 +73,38 @@ namespace wrap {
|
||||
} \
|
||||
} __name;
|
||||
|
||||
#else
|
||||
|
||||
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
|
||||
struct DynLoadShim__##__name { \
|
||||
static const char *kName; \
|
||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||
static void *GetDsoHandle() { \
|
||||
auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
|
||||
return s.ValueOrDie(); \
|
||||
} \
|
||||
static FuncPtrT LoadOrDie() { \
|
||||
void *f; \
|
||||
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||
kName, &f); \
|
||||
CHECK(s.ok()) << "could not find " << kName \
|
||||
<< " in rocrand DSO; dlerror: " << s.error_message(); \
|
||||
return reinterpret_cast<FuncPtrT>(f); \
|
||||
} \
|
||||
static FuncPtrT DynLoad() { \
|
||||
static FuncPtrT f = LoadOrDie(); \
|
||||
return f; \
|
||||
} \
|
||||
template <typename... Args> \
|
||||
hiprandStatus operator()(GpuExecutor *parent, Args... args) { \
|
||||
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||
return DynLoad()(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char *DynLoadShim__##__name::kName = #__name;
|
||||
|
||||
#endif
|
||||
|
||||
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
|
||||
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
|
||||
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
|
||||
|
Loading…
x
Reference in New Issue
Block a user