changing rocblas, rocfft and rocrand from being dynamically linked to being dynmically loaded

This commit is contained in:
Deven Desai 2019-02-01 02:47:14 +00:00
parent 234f47031b
commit 07b3f34199
6 changed files with 169 additions and 14 deletions

View File

@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
port::Status status =
port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
if (status.ok()) {
LOG(INFO) << "Successfully opened CUDA library " << filename;
LOG(INFO) << "Successfully opened dynamic library " << filename;
return dso_handle;
}
@ -54,6 +54,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
return port::Status(port::error::FAILED_PRECONDITION, message);
}
} // namespace
namespace DsoLoader {
port::StatusOr<void*> GetCudaDriverDsoHandle() {
@ -99,6 +100,27 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
port::StatusOr<void*> GetCudnnDsoHandle() {
return GetDsoHandle("cudnn", GetCudnnVersion());
}
port::StatusOr<void*> GetRocblasDsoHandle() {
return GetDsoHandle("rocblas", "");
}
port::StatusOr<void*> GetMiopenDsoHandle() {
return GetDsoHandle("MIOpen", "");
}
port::StatusOr<void*> GetRocfftDsoHandle() {
return GetDsoHandle("rocfft", "");
}
port::StatusOr<void*> GetRocrandDsoHandle() {
return GetDsoHandle("rocrand", "");
}
port::StatusOr<void*> GetHipDsoHandle() {
return GetDsoHandle("hip_hcc", "");
}
} // namespace DsoLoader
namespace CachedDsoLoader {
@ -131,11 +153,37 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
static auto result = new auto(DsoLoader::GetCuptiDsoHandle());
return *result;
}
port::StatusOr<void*> GetCudnnDsoHandle() {
static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
return *result;
}
port::StatusOr<void*> GetRocblasDsoHandle() {
static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
return result;
}
port::StatusOr<void*> GetMiopenDsoHandle() {
static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
return result;
}
port::StatusOr<void*> GetRocfftDsoHandle() {
static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
return result;
}
port::StatusOr<void*> GetRocrandDsoHandle() {
static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
return result;
}
port::StatusOr<void*> GetHipDsoHandle() {
static auto result = new auto(DsoLoader::GetHipDsoHandle());
return result;
}
} // namespace CachedDsoLoader
} // namespace internal
} // namespace stream_executor

View File

@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
} // namespace DsoLoader
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
} // namespace CachedDsoLoader
} // namespace internal
} // namespace stream_executor

View File

@ -167,10 +167,6 @@ cc_library(
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocblas"
# Delete the following line once we switch the rocblas library from
# being dynamically linked (current behaviour) to being dynamically
# loaded (future behaviour)
], ["@local_config_rocm//rocm:rocblas"
])),
alwayslink = True,
)
@ -197,10 +193,6 @@ cc_library(
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocfft"
# Delete the following line once we switch the rocfft library from
# being dynamically linked (current behaviour) to being dynamically
# loaded (future behaviour)
], ["@local_config_rocm//rocm:rocfft"
])),
alwayslink = True,
)
@ -265,10 +257,6 @@ cc_library(
"//tensorflow/stream_executor/platform:dso_loader",
] + if_static([
"@local_config_rocm//rocm:hiprand"
# Delete the following line once we switch the hiprand library from
# being dynamically linked (current behaviour) to being dynamically
# loaded (future behaviour)
], ["@local_config_rocm//rocm:hiprand"
])),
alwayslink = True,
)

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
@ -49,6 +50,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
namespace wrap {
#ifdef PLATFORM_GOOGLE
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct WrapperShim__##__name { \
static const char* kName; \
@ -63,6 +65,41 @@ namespace wrap {
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
#else
#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocblas DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
rocblas_status operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
#endif
#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \
__macro(rocblas_dznrm2) */ \

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
@ -38,6 +39,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
namespace wrap {
#ifdef PLATFORM_GOOGLE
// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
@ -52,6 +54,38 @@ namespace wrap {
} \
} __name;
#else
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void *f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocfft DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#endif
#define ROCFFT_ROUTINE_EACH(__macro) \
__macro(hipfftDestroy) \
__macro(hipfftSetStream) \

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
@ -61,6 +62,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
namespace wrap {
#ifdef PLATFORM_GOOGLE
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
@ -70,6 +73,38 @@ namespace wrap {
} \
} __name;
#else
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void *f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocrand DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hiprandStatus operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#endif
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);