From 07b3f341995ff0b012bf9297f500a8e75c682f31 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Fri, 1 Feb 2019 02:47:14 +0000 Subject: [PATCH] changing rocblas, rocfft and rocrand from being dynamically linked to being dynmically loaded --- .../platform/default/dso_loader.cc | 52 ++++++++++++++++++- .../platform/default/dso_loader.h | 13 +++++ tensorflow/stream_executor/rocm/BUILD | 12 ----- tensorflow/stream_executor/rocm/rocm_blas.cc | 37 +++++++++++++ tensorflow/stream_executor/rocm/rocm_fft.cc | 34 ++++++++++++ tensorflow/stream_executor/rocm/rocm_rng.cc | 35 +++++++++++++ 6 files changed, 169 insertions(+), 14 deletions(-) diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index ad8112b831a..3f2bb5e4dd0 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -39,7 +39,7 @@ port::StatusOr GetDsoHandle(const string& name, const string& version) { port::Status status = port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle); if (status.ok()) { - LOG(INFO) << "Successfully opened CUDA library " << filename; + LOG(INFO) << "Successfully opened dynamic library " << filename; return dso_handle; } @@ -54,6 +54,7 @@ port::StatusOr GetDsoHandle(const string& name, const string& version) { return port::Status(port::error::FAILED_PRECONDITION, message); } } // namespace + namespace DsoLoader { port::StatusOr GetCudaDriverDsoHandle() { @@ -99,6 +100,27 @@ port::StatusOr GetCuptiDsoHandle() { port::StatusOr GetCudnnDsoHandle() { return GetDsoHandle("cudnn", GetCudnnVersion()); } + +port::StatusOr GetRocblasDsoHandle() { + return GetDsoHandle("rocblas", ""); +} + +port::StatusOr GetMiopenDsoHandle() { + return GetDsoHandle("MIOpen", ""); +} + +port::StatusOr GetRocfftDsoHandle() { + return GetDsoHandle("rocfft", ""); +} + +port::StatusOr GetRocrandDsoHandle() { + return GetDsoHandle("rocrand", ""); +} + +port::StatusOr GetHipDsoHandle() { + return GetDsoHandle("hip_hcc", ""); +} + } // namespace DsoLoader namespace CachedDsoLoader { @@ -131,11 +153,37 @@ port::StatusOr GetCuptiDsoHandle() { static auto result = new auto(DsoLoader::GetCuptiDsoHandle()); return *result; } - + port::StatusOr GetCudnnDsoHandle() { static auto result = new auto(DsoLoader::GetCudnnDsoHandle()); return *result; } + +port::StatusOr GetRocblasDsoHandle() { + static auto result = new auto(DsoLoader::GetRocblasDsoHandle()); + return result; +} + +port::StatusOr GetMiopenDsoHandle() { + static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); + return result; +} + +port::StatusOr GetRocfftDsoHandle() { + static auto result = new auto(DsoLoader::GetRocfftDsoHandle()); + return result; +} + +port::StatusOr GetRocrandDsoHandle() { + static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); + return result; +} + +port::StatusOr GetHipDsoHandle() { + static auto result = new auto(DsoLoader::GetHipDsoHandle()); + return result; +} + } // namespace CachedDsoLoader } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 45a8315b436..8da8ea7be66 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -41,6 +41,12 @@ port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCuptiDsoHandle(); port::StatusOr GetCudnnDsoHandle(); + +port::StatusOr GetRocblasDsoHandle(); +port::StatusOr GetMiopenDsoHandle(); +port::StatusOr GetRocfftDsoHandle(); +port::StatusOr GetRocrandDsoHandle(); +port::StatusOr GetHipDsoHandle(); } // namespace DsoLoader // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs @@ -54,7 +60,14 @@ port::StatusOr GetCufftDsoHandle(); port::StatusOr GetCurandDsoHandle(); port::StatusOr GetCuptiDsoHandle(); port::StatusOr GetCudnnDsoHandle(); + +port::StatusOr GetRocblasDsoHandle(); +port::StatusOr GetMiopenDsoHandle(); +port::StatusOr GetRocfftDsoHandle(); +port::StatusOr GetRocrandDsoHandle(); +port::StatusOr GetHipDsoHandle(); } // namespace CachedDsoLoader + } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 85ab38d4242..f0b05822703 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -167,10 +167,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", ] + if_static([ "@local_config_rocm//rocm:rocblas" - # Delete the following line once we switch the rocblas library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:rocblas" ])), alwayslink = True, ) @@ -197,10 +193,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", ] + if_static([ "@local_config_rocm//rocm:rocfft" - # Delete the following line once we switch the rocfft library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:rocfft" ])), alwayslink = True, ) @@ -265,10 +257,6 @@ cc_library( "//tensorflow/stream_executor/platform:dso_loader", ] + if_static([ "@local_config_rocm//rocm:hiprand" - # Delete the following line once we switch the hiprand library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:hiprand" ])), alwayslink = True, ) diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index b2e225433e5..a626d168c26 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/stringprintf.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -49,6 +50,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \ struct WrapperShim__##__name { \ static const char* kName; \ @@ -63,6 +65,41 @@ namespace wrap { #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \ STREAM_EXECUTOR_ROCBLAS_WRAP(__name) +#else + +#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char* kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocblas DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + rocblas_status operator()(GpuExecutor* parent, Args... args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char* DynLoadShim__##__name::kName = #__name; + +#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \ + STREAM_EXECUTOR_ROCBLAS_WRAP(__name) + +#endif + #define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \ __macro(rocblas_dznrm2) */ \ diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc index dd30911eadd..e8a72f61d3f 100644 --- a/tensorflow/stream_executor/rocm/rocm_fft.cc +++ b/tensorflow/stream_executor/rocm/rocm_fft.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -38,6 +39,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE // This macro wraps a global identifier, given by __name, in a callable // structure that loads the DLL symbol out of the DSO handle in a thread-safe // manner on first use. This dynamic loading technique is used to avoid DSO @@ -52,6 +54,38 @@ namespace wrap { } \ } __name; +#else + +#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char *kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void *GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void *f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocfft DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + hipfftResult operator()(GpuExecutor *parent, Args... args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char *DynLoadShim__##__name::kName = #__name; + +#endif + #define ROCFFT_ROUTINE_EACH(__macro) \ __macro(hipfftDestroy) \ __macro(hipfftSetStream) \ diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 79250579087..27797845700 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/rng.h" @@ -61,6 +62,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE + #define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \ struct WrapperShim__##__name { \ template \ @@ -70,6 +73,38 @@ namespace wrap { } \ } __name; +#else + +#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char *kName; \ + using FuncPtrT = std::add_pointer::type; \ + static void *GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void *f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocrand DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + hiprandStatus operator()(GpuExecutor *parent, Args... args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char *DynLoadShim__##__name::kName = #__name; + +#endif + STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator); STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator); STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);