adding a check in in the initialization routines for rocblas, rocfft, rocrand to avoid duplicate registrations

This commit is contained in:
Deven Desai 2019-02-01 18:32:40 +00:00
parent 834a3f7395
commit 7066b1c04d
3 changed files with 95 additions and 77 deletions

View File

@ -2324,35 +2324,43 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
} // namespace gpu
void initialize_rocblas() {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
[](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the rocBLAS "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
if (!blas->Init()) {
// Note: Init() will log a more specific error.
delete blas;
return nullptr;
}
return blas;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocBLAS factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
if (!rocBlasAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()
->RegisterFactory<PluginRegistry::BlasFactory>(
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
[](internal::StreamExecutorInterface* parent)
-> blas::BlasSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the "
"rocBLAS "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
if (!blas->Init()) {
// Note: Init() will log a more specific error.
delete blas;
return nullptr;
}
return blas;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocBLAS factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
}
}
} // namespace stream_executor

View File

@ -592,28 +592,33 @@ STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
} // namespace gpu
void initialize_rocfft() {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
gpu::GpuExecutor *rocm_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the rocFFT "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
return new gpu::ROCMFft(rocm_executor);
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocFFT factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
if (!rocFftAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
[](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the rocFFT "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
return new gpu::ROCMFft(rocm_executor);
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocFFT factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
}
}
} // namespace stream_executor

View File

@ -283,35 +283,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
} // namespace gpu
void initialize_rocrand() {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the hipRAND "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
if (!rng->Init()) {
// Note: Init() will log a more specific error.
delete rng;
return nullptr;
}
return rng;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocRAND factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
if (!rocRandAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the hipRAND "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
if (!rng->Init()) {
// Note: Init() will log a more specific error.
delete rng;
return nullptr;
}
return rng;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocRAND factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
}
}
} // namespace stream_executor