adding a check in in the initialization routines for rocblas, rocfft, rocrand to avoid duplicate registrations
This commit is contained in:
parent
834a3f7395
commit
7066b1c04d
@ -2324,35 +2324,43 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
} // namespace gpu
|
||||
|
||||
void initialize_rocblas() {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
|
||||
[](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the rocBLAS "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
|
||||
if (!blas->Init()) {
|
||||
// Note: Init() will log a more specific error.
|
||||
delete blas;
|
||||
return nullptr;
|
||||
}
|
||||
return blas;
|
||||
});
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocBLAS factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
|
||||
|
||||
if (!rocBlasAlreadyRegistered) {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()
|
||||
->RegisterFactory<PluginRegistry::BlasFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
|
||||
[](internal::StreamExecutorInterface* parent)
|
||||
-> blas::BlasSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the "
|
||||
"rocBLAS "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
|
||||
if (!blas->Init()) {
|
||||
// Note: Init() will log a more specific error.
|
||||
delete blas;
|
||||
return nullptr;
|
||||
}
|
||||
return blas;
|
||||
});
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocBLAS factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
@ -592,28 +592,33 @@ STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
|
||||
} // namespace gpu
|
||||
|
||||
void initialize_rocfft() {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
||||
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
|
||||
gpu::GpuExecutor *rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor *>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the rocFFT "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new gpu::ROCMFft(rocm_executor);
|
||||
});
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocFFT factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
|
||||
|
||||
if (!rocFftAlreadyRegistered) {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
||||
[](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the rocFFT "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new gpu::ROCMFft(rocm_executor);
|
||||
});
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocFFT factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
@ -283,35 +283,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
|
||||
} // namespace gpu
|
||||
|
||||
void initialize_rocrand() {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
|
||||
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the hipRAND "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
|
||||
if (!rng->Init()) {
|
||||
// Note: Init() will log a more specific error.
|
||||
delete rng;
|
||||
return nullptr;
|
||||
}
|
||||
return rng;
|
||||
});
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocRAND factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
|
||||
|
||||
if (!rocRandAlreadyRegistered) {
|
||||
port::Status status =
|
||||
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
|
||||
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
|
||||
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
|
||||
gpu::GpuExecutor* rocm_executor =
|
||||
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||
if (rocm_executor == nullptr) {
|
||||
LOG(ERROR)
|
||||
<< "Attempting to initialize an instance of the hipRAND "
|
||||
<< "support library with a non-ROCM StreamExecutor";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
|
||||
if (!rng->Init()) {
|
||||
// Note: Init() will log a more specific error.
|
||||
delete rng;
|
||||
return nullptr;
|
||||
}
|
||||
return rng;
|
||||
});
|
||||
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to register rocRAND factory: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
PluginRegistry::Instance()->SetDefaultFactory(
|
||||
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace stream_executor
|
||||
|
Loading…
x
Reference in New Issue
Block a user