From aa26dce3be923a3833f52384ec194c2cfac76d43 Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Thu, 31 Jan 2019 15:16:57 +0000 Subject: [PATCH 1/7] proper fix the earlier compile error for --config=rocm, + some minor changes --- tensorflow/stream_executor/rocm/BUILD | 2 +- .../stream_executor/rocm/rocm_diagnostics.cc | 24 +++++++---- .../stream_executor/rocm/rocm_diagnostics.h | 41 +++++++++++++++++++ .../stream_executor/rocm/rocm_gpu_executor.cc | 19 ++++----- .../stream_executor/rocm/rocm_platform.cc | 2 +- .../stream_executor/rocm/rocm_platform_id.cc | 2 +- .../stream_executor/rocm/rocm_platform_id.h | 6 +-- tensorflow/stream_executor/rocm/rocm_rng.cc | 4 +- 8 files changed, 72 insertions(+), 28 deletions(-) create mode 100644 tensorflow/stream_executor/rocm/rocm_diagnostics.h diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index dd08b36308c..38cd19b3cb8 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -33,7 +33,7 @@ filegroup( cc_library( name = "rocm_diagnostics", srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]), - hdrs = [], + hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]), deps = if_rocm_is_configured([ "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", diff --git a/tensorflow/stream_executor/rocm/rocm_diagnostics.cc b/tensorflow/stream_executor/rocm/rocm_diagnostics.cc index c6da7f9e3da..01492a8ef97 100644 --- a/tensorflow/stream_executor/rocm/rocm_diagnostics.cc +++ b/tensorflow/stream_executor/rocm/rocm_diagnostics.cc @@ -30,7 +30,7 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h" +#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/numbers.h" #include "tensorflow/stream_executor/lib/process_state.h" @@ -40,7 +40,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" namespace stream_executor { -namespace gpu { +namespace rocm { string DriverVersionToString(DriverVersion version) { return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version), @@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) { return result; } +} // namespace rocm +} // namespace stream_executor + +namespace stream_executor { +namespace gpu { + // -- class Diagnostician string Diagnostician::GetDevNodePath(int dev_node_ordinal) { @@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() { } port::StatusOr<DriverVersion> dso_version = FindDsoVersion(); LOG(INFO) << "librocm reported version is: " - << DriverVersionStatusToString(dso_version); + << rocm::DriverVersionStatusToString(dso_version); port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion(); LOG(INFO) << "kernel reported version is: " - << DriverVersionStatusToString(kernel_version); + << rocm::DriverVersionStatusToString(kernel_version); if (kernel_version.ok() && dso_version.ok()) { WarnOnDsoKernelMismatch(dso_version, kernel_version); @@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() { // TODO(b/22689637): Eliminate the explicit namespace if possible. 
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64"); auto result = static_cast<port::StatusOr<DriverVersion>*>(data); - *result = StringToDriverVersion(stripped_dso_version); + *result = rocm::StringToDriverVersion(stripped_dso_version); return 1; } return 0; @@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion( // TODO(b/22689637): Eliminate the explicit namespace if possible. auto stripped_kernel_version = port::StripSuffixString(kernel_version, ".ld64"); - return StringToDriverVersion(stripped_kernel_version); + return rocm::StringToDriverVersion(stripped_kernel_version); } void Diagnostician::WarnOnDsoKernelMismatch( @@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch( if (kernel_version.ok() && dso_version.ok() && dso_version.ValueOrDie() == kernel_version.ValueOrDie()) { LOG(INFO) << "kernel version seems to match DSO: " - << DriverVersionToString(kernel_version.ValueOrDie()); + << rocm::DriverVersionToString(kernel_version.ValueOrDie()); } else { LOG(ERROR) << "kernel version " - << DriverVersionStatusToString(kernel_version) + << rocm::DriverVersionStatusToString(kernel_version) << " does not match DSO version " - << DriverVersionStatusToString(dso_version) + << rocm::DriverVersionStatusToString(dso_version) << " -- cannot find working devices in this configuration"; } } diff --git a/tensorflow/stream_executor/rocm/rocm_diagnostics.h b/tensorflow/stream_executor/rocm/rocm_diagnostics.h new file mode 100644 index 00000000000..233c6bdade6 --- /dev/null +++ b/tensorflow/stream_executor/rocm/rocm_diagnostics.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_ + +#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h" + +namespace stream_executor { +namespace rocm { + +// e.g. DriverVersion{346, 3, 4} +using DriverVersion = gpu::DriverVersion; + +// Converts a parsed driver version to string form. +string DriverVersionToString(DriverVersion version); + +// Converts a parsed driver version or status value to natural string form. +string DriverVersionStatusToString(port::StatusOr<DriverVersion> version); + +// Converts a string of a form like "331.79" to a DriverVersion{331, 79}. +port::StatusOr<DriverVersion> StringToDriverVersion(const string& value); + +using Diagnostician = gpu::Diagnostician; + +} // namespace rocm +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_ diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc index 0e38556ee7d..684172cfb5a 100644 --- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/base/casts.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h" #include "tensorflow/stream_executor/gpu/gpu_driver.h" #include "tensorflow/stream_executor/gpu/gpu_event.h" #include "tensorflow/stream_executor/gpu/gpu_executor.h" @@ -41,6 +40,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h" #include "tensorflow/stream_executor/rocm/rocm_platform_id.h" #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { blas::BlasSupport* GpuExecutor::CreateBlas() { PluginRegistry* registry = PluginRegistry::Instance(); port::StatusOr<PluginRegistry::BlasFactory> status = - registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId, + registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId, plugin_config_.blas()); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve BLAS factory: " @@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() { dnn::DnnSupport* GpuExecutor::CreateDnn() { PluginRegistry* registry = PluginRegistry::Instance(); port::StatusOr<PluginRegistry::DnnFactory> status = - registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId, + registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId, plugin_config_.dnn()); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve DNN factory: " @@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() { fft::FftSupport* GpuExecutor::CreateFft() { PluginRegistry* registry = PluginRegistry::Instance(); port::StatusOr<PluginRegistry::FftFactory> status = - registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId, + 
registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId, plugin_config_.fft()); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve FFT factory: " @@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() { rng::RngSupport* GpuExecutor::CreateRng() { PluginRegistry* registry = PluginRegistry::Instance(); port::StatusOr<PluginRegistry::RngFactory> status = - registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId, + registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId, plugin_config_.rng()); if (!status.ok()) { LOG(ERROR) << "Unable to retrieve RNG factory: " @@ -878,12 +878,9 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const { { int driver_version = 0; (void)GpuDriver::GetDriverVersion(&driver_version); - string augmented_driver_version = - absl::StrFormat("%d (%s)", driver_version, "__FIXME__"); - // FIXME: - // uncomment the line below once the "DriverVersionStatusToString" - // routine is moved from the "cuda" namespace to the "gpu" naemspace - // DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str()); + string augmented_driver_version = absl::StrFormat( + "%d (%s)", driver_version, + rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str()); builder.set_driver_version(augmented_driver_version); } diff --git a/tensorflow/stream_executor/rocm/rocm_platform.cc b/tensorflow/stream_executor/rocm/rocm_platform.cc index 113371dd553..ce091658da4 100644 --- a/tensorflow/stream_executor/rocm/rocm_platform.cc +++ b/tensorflow/stream_executor/rocm/rocm_platform.cc @@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus( absl::StrFormat("Executor for bus %d not found.", bus_ordinal)}; } -Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; } +Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; } int ROCmPlatform::VisibleDeviceCount() const { // Throw away the result - it logs internally, and this [containing] function 
diff --git a/tensorflow/stream_executor/rocm/rocm_platform_id.cc b/tensorflow/stream_executor/rocm/rocm_platform_id.cc index daa42ab022a..16f48bf12d2 100644 --- a/tensorflow/stream_executor/rocm/rocm_platform_id.cc +++ b/tensorflow/stream_executor/rocm/rocm_platform_id.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/stream_executor/rocm/rocm_platform_id.h" namespace stream_executor { -namespace gpu { +namespace rocm { PLATFORM_DEFINE_ID(kROCmPlatformId); diff --git a/tensorflow/stream_executor/rocm/rocm_platform_id.h b/tensorflow/stream_executor/rocm/rocm_platform_id.h index 71c760b8277..a17d4f97bbc 100644 --- a/tensorflow/stream_executor/rocm/rocm_platform_id.h +++ b/tensorflow/stream_executor/rocm/rocm_platform_id.h @@ -19,16 +19,16 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" namespace stream_executor { -namespace gpu { +namespace rocm { // Opaque and unique identifier for the ROCm platform. // This is needed so that plugins can refer to/identify this platform without // instantiating a ROCmPlatform object. // This is broken out here to avoid a circular dependency between ROCmPlatform -// and GpuExecutor. +// and ROCmExecutor. 
extern const Platform::Id kROCmPlatformId; -} // namespace gpu +} // namespace rocm } // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_ diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 2048c8ff644..65acd03c92b 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -253,7 +253,7 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, { se::port::Status status = se::PluginRegistry::Instance() ->RegisterFactory<se::PluginRegistry::RngFactory>( - se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND", + se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND", [](se::internal::StreamExecutorInterface* parent) -> se::rng::RngSupport* { se::gpu::GpuExecutor* rocm_executor = @@ -280,5 +280,5 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, { } se::PluginRegistry::Instance()->SetDefaultFactory( - se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin); + se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin); }); From 298383d00b5d385d0f58e5ca5bd860fc4508d37a Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Thu, 31 Jan 2019 19:23:21 +0000 Subject: [PATCH 2/7] adding code for rocblas plugin --- tensorflow/stream_executor/rocm/BUILD | 64 +- tensorflow/stream_executor/rocm/rocm_blas.cc | 2324 ++++++++++++++++++ tensorflow/stream_executor/rocm/rocm_blas.h | 159 ++ 3 files changed, 2518 insertions(+), 29 deletions(-) create mode 100644 tensorflow/stream_executor/rocm/rocm_blas.cc create mode 100644 tensorflow/stream_executor/rocm/rocm_blas.h diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 38cd19b3cb8..737a4429469 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -140,34 +140,40 @@ cc_library( deps = ["//tensorflow/stream_executor:platform"], ) -# FIXME: enable in 
future PRs -#cc_library( -# name = "rocblas_plugin", -# srcs = ["rocm_blas.cc"], -# hdrs = ["rocm_blas.h"], -# visibility = ["//visibility:public"], -# deps = [ -# ":rocm_gpu_executor", -# ":rocm_platform_id", -# "//third_party/eigen3", -# "//tensorflow/core:lib_internal", -# "//tensorflow/stream_executor", -# "//tensorflow/stream_executor:event", -# "//tensorflow/stream_executor:host_or_device_scalar", -# "//tensorflow/stream_executor:plugin_registry", -# "//tensorflow/stream_executor:scratch_allocator", -# "//tensorflow/stream_executor:timer", -# "//tenosrflow/stream_executor/gpu:gpu_activation_header", -# "//tenosrflow/stream_executor/gpu:gpu_stream_header", -# "//tenosrflow/stream_executor/gpu:gpu_timer_header", -# "//tensorflow/stream_executor/lib", -# "//tensorflow/stream_executor/platform", -# "//tensorflow/stream_executor/platform:dso_loader", -# "@com_google_absl//absl/strings", -# "@local_config_rocm//rocm:rocm_headers", -# ] + if_static(["@local_config_rocm//rocm:rocblas"]), -# alwayslink = True, -#) +cc_library( + name = "rocblas_plugin", + srcs = if_rocm_is_configured(["rocm_blas.cc"]), + hdrs = if_rocm_is_configured(["rocm_blas.h"]), + visibility = ["//visibility:public"], + deps = if_rocm_is_configured([ + ":rocm_gpu_executor", + ":rocm_platform_id", + "//third_party/eigen3", + "//tensorflow/core:lib_internal", + "//tensorflow/stream_executor", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:host_or_device_scalar", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:scratch_allocator", + "//tensorflow/stream_executor:timer", + "//tensorflow/stream_executor/gpu:gpu_activation", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/gpu:gpu_timer_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + 
"@com_google_absl//absl/strings", + "@local_config_rocm//rocm:rocm_headers", + ] + if_static([ + "@local_config_rocm//rocm:rocblas" + # Delete the following line once we switch the rocblas library from + # being dynamically linked (current behaviour) to being dynamically + # loaded (future behaviour) + ], ["@local_config_rocm//rocm:rocblas" + ])), + alwayslink = True, +) # FIXME: enable in future PRs #cc_library( @@ -258,7 +264,7 @@ cc_library( # FIXME: enable in future PRs #":miopen_plugin", #":rocfft_plugin", - #":rocblas_plugin", + ":rocblas_plugin", #":rocrand_plugin", ":rocm_driver", ":rocm_platform", diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc new file mode 100644 index 00000000000..b2e225433e5 --- /dev/null +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -0,0 +1,2324 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "rocm/include/rocblas.h" + +#include "tensorflow/stream_executor/rocm/rocm_blas.h" + +#define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include <assert.h> +#include <complex> + +#include "absl/strings/str_cat.h" +#include "tensorflow/stream_executor/gpu/gpu_activation.h" +#include "tensorflow/stream_executor/gpu/gpu_executor.h" +#include "tensorflow/stream_executor/gpu/gpu_helpers.h" +#include "tensorflow/stream_executor/gpu/gpu_stream.h" +#include "tensorflow/stream_executor/gpu/gpu_timer.h" +#include "tensorflow/stream_executor/rocm/rocm_platform_id.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/env.h" +#include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/status_macros.h" +#include "tensorflow/stream_executor/lib/stringprintf.h" +#include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/scratch_allocator.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace stream_executor { +namespace gpu { + +PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin); + +namespace wrap { + +#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \ + struct WrapperShim__##__name { \ + static const char* kName; \ + template <typename... Args> \ + rocblas_status operator()(GpuExecutor* parent, Args... 
args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return ::__name(args...); \ + } \ + } __name; \ + const char* WrapperShim__##__name::kName = #__name; + +#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \ + STREAM_EXECUTOR_ROCBLAS_WRAP(__name) + +#define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \ + __macro(rocblas_dznrm2) */ \ + __macro(rocblas_sdot) \ + __macro(rocblas_ddot) /* __macro(rocblas_cdotu) \ + __macro(rocblas_cdotc) \ + __macro(rocblas_zdotu) \ + __macro(rocblas_zdotc) */ \ + __macro(rocblas_sscal) \ + __macro(rocblas_dscal) /* __macro(rocblas_cscal) \ + __macro(rocblas_csscal) \ + __macro(rocblas_zscal) \ + __macro(rocblas_zdscal) */ \ + __macro(rocblas_saxpy) \ + __macro(rocblas_daxpy) /* __macro(rocblas_caxpy) \ + __macro(rocblas_zaxpy) */ \ + __macro(rocblas_scopy) \ + __macro(rocblas_dcopy) /* __macro(rocblas_ccopy) \ + __macro(rocblas_zcopy) */ \ + __macro(rocblas_sswap) \ + __macro(rocblas_dswap) /* __macro(rocblas_cswap) \ + __macro(rocblas_zswap) */ \ + __macro(rocblas_isamax) \ + __macro(rocblas_idamax) /* __macro(rocblas_icamax) \ + __macro(rocblas_izamax) */ \ + __macro(rocblas_isamin) \ + __macro(rocblas_idamin) /* __macro(rocblas_icamin) \ + __macro(rocblas_izamin) */ \ + __macro(rocblas_sasum) \ + __macro(rocblas_dasum) /* __macro(rocblas_scasum) \ + __macro(rocblas_dzasum) \ + __macro(rocblas_srot) \ + __macro(rocblas_drot) \ + __macro(rocblas_crot) \ + __macro(rocblas_csrot) \ + __macro(rocblas_zrot) \ + __macro(rocblas_zdrot) \ + __macro(rocblas_srotg) \ + __macro(rocblas_drotg) \ + __macro(rocblas_Crotg) \ + __macro(rocblas_crotg) \ + __macro(rocblas_zrotm) \ + __macro(rocblas_drotm) \ + __macro(rocblas_srotmg) \ + __macro(rocblas_drotmg) */ \ + __macro(rocblas_sgemv) \ + __macro(rocblas_dgemv) /* __macro(rocblas_cgemv) \ + __macro(rocblas_zgemv) \ + __macro(rocblas_sgbmv) \ + __macro(rocblas_dgbmv) \ + __macro(rocblas_cgbmv) \ + 
__macro(rocblas_zgbmv) \ + __macro(rocblas_strmv) \ + __macro(rocblas_dtrmv) \ + __macro(rocblas_ctrmv) \ + __macro(rocblas_ztrmv) \ + __macro(rocblas_stbmv) \ + __macro(rocblas_dtbmv) \ + __macro(rocblas_ctbmv) \ + __macro(rocblas_ztbmv) \ + __macro(rocblas_stpmv) \ + __macro(rocblas_dtpmv) \ + __macro(rocblas_ctpmv) \ + __macro(rocblas_ztpmv) \ + __macro(rocblas_strsv) \ + __macro(rocblas_dtrsv) \ + __macro(rocblas_ctrsv) \ + __macro(rocblas_ztrsv) \ + __macro(rocblas_stpsv) \ + __macro(rocblas_dtpsv) \ + __macro(rocblas_ctpsv) \ + __macro(rocblas_ztpsv) \ + __macro(rocblas_stbsv) \ + __macro(rocblas_dtbsv) \ + __macro(rocblas_ctbsv) \ + __macro(rocblas_ztbsv) \ + __macro(rocblas_ssymv) \ + __macro(rocblas_dsymv) \ + __macro(rocblas_csymv) \ + __macro(rocblas_zsymv) \ + __macro(rocblas_chemv) \ + __macro(rocblas_zhemv) \ + __macro(rocblas_ssbmv) \ + __macro(rocblas_dsbmv) \ + __macro(rocblas_chbmv) \ + __macro(rocblas_zhbmv) \ + __macro(rocblas_sspmv) \ + __macro(rocblas_dspmv) \ + __macro(rocblas_chpmv) \ + __macro(rocblas_zhpmv) */ \ + __macro(rocblas_sger) \ + __macro(rocblas_dger) /* __macro(rocblas_cgeru) \ + __macro(rocblas_cgerc) \ + __macro(rocblas_zgeru) \ + __macro(rocblas_zgerc) */ \ + __macro(rocblas_ssyr) \ + __macro(rocblas_dsyr) /* __macro(rocblas_csyr) \ + __macro(rocblas_zsyr) \ + __macro(rocblas_cher) \ + __macro(rocblas_zher) \ + __macro(rocblas_sspr) \ + __macro(rocblas_dspr) \ + __macro(rocblas_chpr) \ + __macro(rocblas_zhpr) \ + __macro(rocblas_ssyr2) \ + __macro(rocblas_dsyr2) \ + __macro(rocblas_csyr2) \ + __macro(rocblas_zsyr2) \ + __macro(rocblas_cher2) \ + __macro(rocblas_zher2) \ + __macro(rocblas_sspr2) \ + __macro(rocblas_dspr2) \ + __macro(rocblas_chpr2) \ + __macro(rocblas_zhpr2) */ \ + __macro(rocblas_sgemm) __macro(rocblas_dgemm) \ + __macro(rocblas_hgemm) /* __macro(rocblas_cgemm) \ + __macro(rocblas_zgemm) \ + __macro(rocblas_ssyrk) \ + __macro(rocblas_dsyrk) \ + __macro(rocblas_csyrk) \ + __macro(rocblas_zsyrk) \ + 
__macro(rocblas_cherk) \ + __macro(rocblas_zherk) \ + __macro(rocblas_ssyr2k) \ + __macro(rocblas_dsyr2k) \ + __macro(rocblas_csyr2k) \ + __macro(rocblas_zsyr2k) \ + __macro(rocblas_cher2k) \ + __macro(rocblas_zher2k) \ + __macro(rocblas_ssyrkx) \ + __macro(rocblas_dsyrkx) \ + __macro(rocblas_csyrkx) \ + __macro(rocblas_zsyrkx) \ + __macro(rocblas_cherkx) \ + __macro(rocblas_zherkx) \ + __macro(rocblas_ssymm) \ + __macro(rocblas_dsymm) \ + __macro(rocblas_csymm) \ + __macro(rocblas_zsymm) \ + __macro(rocblas_chemm) \ + __macro(rocblas_zhemm) */ \ + __macro(rocblas_strsm) \ + __macro(rocblas_dtrsm) /* __macro(rocblas_ctrsm) \ + __macro(rocblas_ztrsm) \ + __macro(rocblas_strmm) \ + __macro(rocblas_dtrmm) \ + __macro(rocblas_ctrmm) \ + __macro(rocblas_ztrmm) */ \ + __macro(rocblas_sgeam) \ + __macro(rocblas_dgeam) /* __macro(rocblas_cgeam) \ + __macro(rocblas_zgeam) \ + __macro(rocblas_sdgmm) \ + __macro(rocblas_ddgmm) \ + __macro(rocblas_cdgmm) \ + __macro(rocblas_zdgmm) */ + +STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle) +STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle) +STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream) +// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode) +// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode) +// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched) +STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched) +STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched) +// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched) +STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched) +// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched) +// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched) +ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP) + +} // namespace wrap + +static string ToString(rocblas_status status) { + switch (status) { + case rocblas_status_success: + return "rocblas_status_success"; + case rocblas_status_invalid_handle: + return "rocblas_status_invalid_handle"; + case 
rocblas_status_not_implemented: + return "rocblas_status_not_implemented"; + case rocblas_status_invalid_pointer: + return "rocblas_status_invalid_pointer"; + case rocblas_status_invalid_size: + return "rocblas_status_invalid_size"; + case rocblas_status_memory_error: + return "rocblas_status_memory_error"; + case rocblas_status_internal_error: + return "rocblas_status_internal_error"; + default: + return absl::StrCat("<invalid rocBLAS status: ", status, ">"); + } +} + +bool ROCMBlas::Init() { + rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); + return false; + } + + return true; +} + +ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent) + : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {} + +ROCMBlas::~ROCMBlas() { + if (blas_ != nullptr) { + wrap::rocblas_destroy_handle(parent_, blas_); + } +} + +bool ROCMBlas::SetStream(Stream *stream) { + CHECK(stream != nullptr); + CHECK(AsGpuStreamValue(stream) != nullptr); + CHECK(blas_ != nullptr); + rocblas_status ret = + wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream)); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); + return false; + } + + return true; +} + +namespace { + +// Helper functions transforming blas arguments into rocBLAS arguments. 
+ +rocblas_operation ROCMBlasTranspose(blas::Transpose trans) { + switch (trans) { + case blas::Transpose::kNoTranspose: + return rocblas_operation_none; + case blas::Transpose::kTranspose: + return rocblas_operation_transpose; + case blas::Transpose::kConjugateTranspose: + return rocblas_operation_conjugate_transpose; + default: + LOG(FATAL) << "Invalid value of blas::Transpose."; + } +} + +rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) { + switch (uplo) { + case blas::UpperLower::kUpper: + return rocblas_fill_upper; + case blas::UpperLower::kLower: + return rocblas_fill_lower; + default: + LOG(FATAL) << "Invalid value of blas::UpperLower."; + } +} + +rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) { + switch (diag) { + case blas::Diagonal::kUnit: + return rocblas_diagonal_unit; + case blas::Diagonal::kNonUnit: + return rocblas_diagonal_non_unit; + default: + LOG(FATAL) << "Invalid value of blas::Diagonal."; + } +} + +rocblas_side ROCMBlasSide(blas::Side side) { + switch (side) { + case blas::Side::kLeft: + return rocblas_side_left; + case blas::Side::kRight: + return rocblas_side_right; + default: + LOG(FATAL) << "Invalid value of blas::Side."; + } +} + +} // namespace + +template <typename FuncT, typename... Args> +bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args... 
args) { + mutex_lock lock{mu_}; + + CHECK(blas_ != nullptr); + if (!SetStream(stream)) { + return false; + } + + rocblas_status ret = rocblas_func(parent_, blas_, args...); + if (err_on_failure && ret != rocblas_status_success) { + LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " + << ToString(ret); + } + return ret == rocblas_status_success; +} + +bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, + const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *result) { + return DoBlasInternal(wrap::rocblas_sasum, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + DeviceMemory<double> *result) { + return DoBlasInternal(wrap::rocblas_dasum, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<float> *result) { + LOG(ERROR) << "rocBLAS does not currently support the ASUM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<double> *result) { + LOG(ERROR) << "rocBLAS does not currently support the ASUM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, + const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *y, int incy) { + return DoBlasInternal(wrap::rocblas_saxpy, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + GpuMemory(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, + const DeviceMemory<double> &x, int incx, + 
DeviceMemory<double> *y, int incy) { + return DoBlasInternal(wrap::rocblas_daxpy, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + GpuMemory(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the AXPY operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the AXPY operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, + const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *y, int incy) { + return DoBlasInternal(wrap::rocblas_scopy, stream, + true /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + DeviceMemory<double> *y, int incy) { + return DoBlasInternal(wrap::rocblas_dcopy, stream, + true /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the COPY operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently 
support the COPY operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, + const DeviceMemory<float> &x, int incx, + const DeviceMemory<float> &y, int incy, + DeviceMemory<float> *result) { + return DoBlasInternal( + wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + const DeviceMemory<double> &y, int incy, + DeviceMemory<double> *result) { + return DoBlasInternal( + wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *result) { + LOG(ERROR) << "rocBLAS does not currently support the DOT operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, int incx, + const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *result) { + LOG(ERROR) << "rocBLAS does not currently support the DOT operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *result) { + LOG(ERROR) << "rocBLAS does not currently support the DOT operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, int incx, 
+ const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *result) { + LOG(ERROR) << "rocBLAS does not currently support the DOT operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, + const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *result) { + return DoBlasInternal(wrap::rocblas_snrm2, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + DeviceMemory<double> *result) { + return DoBlasInternal(wrap::rocblas_dnrm2, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<float> *result) { + LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<double> *result) { + LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, + DeviceMemory<float> *x, int incx, + DeviceMemory<float> *y, int incy, float c, float s) { + LOG(ERROR) << "rocBLAS does not currently support the ROT operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, + DeviceMemory<double> *x, int incx, + DeviceMemory<double> *y, int incy, double c, + double s) { + LOG(ERROR) << "rocBLAS does not currently support the ROT operation " + << "for the \"double\" dataype" ; + return false; +} + +bool 
ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, + DeviceMemory<std::complex<float>> *x, int incx, + DeviceMemory<std::complex<float>> *y, int incy, + float c, float s) { + LOG(ERROR) << "rocBLAS does not currently support the ROT operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, + DeviceMemory<std::complex<double>> *x, int incx, + DeviceMemory<std::complex<double>> *y, int incy, + double c, double s) { + LOG(ERROR) << "rocBLAS does not currently support the ROT operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a, + DeviceMemory<float> *b, DeviceMemory<float> *c, + DeviceMemory<float> *s) { + LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a, + DeviceMemory<double> *b, DeviceMemory<double> *c, + DeviceMemory<double> *s) { + LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, + DeviceMemory<std::complex<float>> *b, + DeviceMemory<float> *c, + DeviceMemory<std::complex<float>> *s) { + LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, + DeviceMemory<std::complex<double>> *b, + DeviceMemory<double> *c, + DeviceMemory<std::complex<double>> *s) { + LOG(ERROR) << "rocBLAS does not currently support the ROTG operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, + DeviceMemory<float> *x, int incx, + DeviceMemory<float> *y, int incy, + const 
DeviceMemory<float> &param) { + LOG(ERROR) << "rocBLAS does not currently support the ROTM operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, + DeviceMemory<double> *x, int incx, + DeviceMemory<double> *y, int incy, + const DeviceMemory<double> &param) { + LOG(ERROR) << "rocBLAS does not currently support the ROTM operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, + DeviceMemory<float> *d2, DeviceMemory<float> *x1, + const DeviceMemory<float> &y1, + DeviceMemory<float> *param) { + LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, + DeviceMemory<double> *d2, DeviceMemory<double> *x1, + const DeviceMemory<double> &y1, + DeviceMemory<double> *param) { + LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, + DeviceMemory<float> *x, int incx) { + return DoBlasInternal(wrap::rocblas_sscal, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + GpuMemoryMutable(x), incx); +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, + DeviceMemory<double> *x, int incx) { + return DoBlasInternal(wrap::rocblas_dscal, stream, + true /* = pointer_mode_host */, elem_count, &alpha, + GpuMemoryMutable(x), incx); +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, + DeviceMemory<std::complex<double>> *x, int incx) { 
+ LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, + std::complex<float> alpha, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, + std::complex<double> alpha, + DeviceMemory<std::complex<double>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, + DeviceMemory<float> *x, int incx, + DeviceMemory<float> *y, int incy) { + return DoBlasInternal(wrap::rocblas_sswap, stream, + true /* = pointer_mode_host */, elem_count, + GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, + DeviceMemory<double> *x, int incx, + DeviceMemory<double> *y, int incy) { + return DoBlasInternal(wrap::rocblas_dswap, stream, + true /* = pointer_mode_host */, elem_count, + GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, + DeviceMemory<std::complex<float>> *x, int incx, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SWAP operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, + DeviceMemory<std::complex<double>> *x, int incx, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SWAP operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, + const DeviceMemory<float> 
&x, int incx, + DeviceMemory<int> *result) { + return DoBlasInternal(wrap::rocblas_isamax, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + DeviceMemory<int> *result) { + return DoBlasInternal(wrap::rocblas_idamax, stream, + false /* = pointer_mode_host */, elem_count, + GpuMemory(x), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<int> *result) { + LOG(ERROR) << "rocBLAS does not currently support the AMAX operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, + int incx, DeviceMemory<int> *result) { + LOG(ERROR) << "rocBLAS does not currently support the AMAX operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, + const DeviceMemory<float> &x, int incx, + DeviceMemory<int> *result) { + return DoBlasInternal( + wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count, + GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, + const DeviceMemory<double> &x, int incx, + DeviceMemory<int> *result) { + return DoBlasInternal( + wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count, + GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result)); +} + +bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<int> *result) { + LOG(ERROR) << "rocBLAS does not currently support the AMIN operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool 
ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, + const DeviceMemory<std::complex<double>> &x, + int incx, DeviceMemory<int> *result) { + LOG(ERROR) << "rocBLAS does not currently support the AMIN operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, uint64 kl, uint64 ku, float alpha, + const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &x, int incx, float beta, + DeviceMemory<float> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, uint64 kl, uint64 ku, double alpha, + const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &x, int incx, double beta, + DeviceMemory<double> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, uint64 kl, uint64 ku, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, uint64 kl, uint64 ku, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GBMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} 
+ +bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, float alpha, const DeviceMemory<float> &a, + int lda, const DeviceMemory<float> &x, int incx, + float beta, DeviceMemory<float> *y, int incy) { + return DoBlasInternal( + wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), + incx, &beta, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, double alpha, const DeviceMemory<double> &a, + int lda, const DeviceMemory<double> &x, int incx, + double beta, DeviceMemory<double> *y, int incy) { + return DoBlasInternal( + wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), + incx, &beta, GpuMemoryMutable(y), incy); +} + +bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GEMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, + uint64 n, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the GEMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &x, int incx, + const DeviceMemory<float> &y, int incy, + DeviceMemory<float> *a, int lda) { + return 
DoBlasInternal( + wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha, + GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda); +} + +bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &x, int incx, + const DeviceMemory<double> &y, int incy, + DeviceMemory<double> *a, int lda) { + return DoBlasInternal( + wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha, + GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda); +} + +bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the GER operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the GER operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the GERU operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *a, int lda) { + LOG(ERROR) << 
"rocBLAS does not currently support the GERU operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, + uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HBMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, + uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HBMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HEMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HEMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, + const 
DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<std::complex<float>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the HER operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<std::complex<double>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the HER operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the HER2 operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the HER2 operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &ap, + const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HPMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &ap, + const 
DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the HPMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + DeviceMemory<std::complex<float>> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the HPR operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + DeviceMemory<std::complex<double>> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the HPR operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &x, int incx, + const DeviceMemory<std::complex<float>> &y, int incy, + DeviceMemory<std::complex<float>> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &x, int incx, + const DeviceMemory<std::complex<double>> &y, int incy, + DeviceMemory<std::complex<double>> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, + uint64 k, float alpha, const DeviceMemory<float> &a, + int lda, const DeviceMemory<float> &x, int incx, + float beta, DeviceMemory<float> *y, int incy) { + LOG(ERROR) << "rocBLAS does not 
currently support the SBMV operation " + << "for the \"complex<float>\" dataype" ; + + return false; +} + +bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, + uint64 k, double alpha, const DeviceMemory<double> &a, + int lda, const DeviceMemory<double> &x, int incx, + double beta, DeviceMemory<double> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SBMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &ap, + const DeviceMemory<float> &x, int incx, float beta, + DeviceMemory<float> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SPMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &ap, + const DeviceMemory<double> &x, int incx, double beta, + DeviceMemory<double> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SPMV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the SPR operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &x, int incx, + DeviceMemory<double> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the SPR operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &x, int incx, + const DeviceMemory<float> &y, int incy, + DeviceMemory<float> *ap) { + LOG(ERROR) << "rocBLAS does not currently 
support the SPR2 operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &x, int incx, + const DeviceMemory<double> &y, int incy, + DeviceMemory<double> *ap) { + LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &x, int incx, float beta, + DeviceMemory<float> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SYMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &x, int incx, double beta, + DeviceMemory<double> *y, int incy) { + LOG(ERROR) << "rocBLAS does not currently support the SYMV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &x, int incx, + DeviceMemory<float> *a, int lda) { + return DoBlasInternal(wrap::rocblas_ssyr, stream, + true /* = pointer_mode_host */, + ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, + GpuMemoryMutable(a), lda); +} + +bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &x, int incx, + DeviceMemory<double> *a, int lda) { + return DoBlasInternal(wrap::rocblas_dsyr, stream, + true /* = pointer_mode_host */, + ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx, + GpuMemoryMutable(a), lda); +} + +bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, + float alpha, const DeviceMemory<float> &x, int incx, + const DeviceMemory<float> &y, int incy, + 
DeviceMemory<float> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, + double alpha, const DeviceMemory<double> &x, int incx, + const DeviceMemory<double> &y, int incy, + DeviceMemory<double> *a, int lda) { + LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<std::complex<float>> &a, + int lda, DeviceMemory<std::complex<float>> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<std::complex<double>> &a, + int lda, DeviceMemory<std::complex<double>> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbsv(Stream *stream, 
blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<std::complex<float>> &a, + int lda, DeviceMemory<std::complex<float>> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + uint64 k, const DeviceMemory<std::complex<double>> &a, + int lda, DeviceMemory<std::complex<double>> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TBSV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<float> &ap, DeviceMemory<float> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<double> &ap, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " + << 
"for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<float>> &ap, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<double>> &ap, + DeviceMemory<std::complex<double>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<float> &ap, DeviceMemory<float> *x, + int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<double> &ap, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<float>> &ap, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<double>> &ap, + DeviceMemory<std::complex<double>> *x, int incx) 
{ + LOG(ERROR) << "rocBLAS does not currently support the TPSV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<float>> &a, int lda, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<double>> &a, int lda, + DeviceMemory<std::complex<double>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRMV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + 
const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<float>> &a, int lda, + DeviceMemory<std::complex<float>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, blas::Diagonal diag, uint64 n, + const DeviceMemory<std::complex<double>> &a, int lda, + DeviceMemory<std::complex<double>> *x, int incx) { + LOG(ERROR) << "rocBLAS does not currently support the TRSV operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemm( + Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + float alpha, const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, float beta, + DeviceMemory<Eigen::half> *c, int ldc) { + VLOG(1) << port::Printf( + "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " + "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " + "c=%p ldc=%d", + static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); + if (transa == blas::Transpose::kNoTranspose) { + if (lda < static_cast<int64>(m)) { + LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " + "precondition violation"; + } + } else { + if (lda < static_cast<int64>(k)) { + LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k + << ") (transpose case); precondition violation"; + } + } + if (transb == blas::Transpose::kNoTranspose) { + if (ldb < static_cast<int64>(k)) { + LOG(WARNING) << "GEMM ldb (" << 
ldb << ") was smaller than k (" << k + << ") (no transpose case); precondition violation"; + } + } else { + if (ldb < static_cast<int64>(n)) { + LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); " + "precondition violation"; + } + } + const Eigen::half alpha_half(alpha); + const Eigen::half beta_half(beta); + return DoBlasInternal( + wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, + reinterpret_cast<const rocblas_half*>(&alpha_half), + reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda, + reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb, + reinterpret_cast<const rocblas_half*>(&beta_half), + reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc); +} + +bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, + DeviceMemory<float> *c, int ldc) { + VLOG(1) << port::Printf( + "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " + "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " + "c=%p ldc=%d", + static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha, + a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc); + if (transa == blas::Transpose::kNoTranspose) { + if (lda < static_cast<int64>(m)) { + LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); " + "precondition violation"; + } + } else { + if (lda < static_cast<int64>(k)) { + LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k + << ") (transpose case); precondition violation"; + } + } + if (transb == blas::Transpose::kNoTranspose) { + if (ldb < static_cast<int64>(k)) { + LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k + << ") (no transpose case); precondition violation"; + } + } else { + if (ldb < static_cast<int64>(n)) { + LOG(WARNING) << "GEMM ldb was smaller than n 
(transpose case); " + "precondition violation"; + } + } + return DoBlasInternal( + wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, + GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); +} + +bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc) { + return DoBlasInternal( + wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */, + ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, + GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc); +} + +bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the GEMM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the GEMM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, + int incx, float beta, 
DeviceMemory<float> *y, int incy, + blas::ProfileResult *output_profile_result) { + return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, + incx, beta, y, incy, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, + int incx, double beta, DeviceMemory<double> *y, int incy, + blas::ProfileResult *output_profile_result) { + return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, + incx, beta, y, incy, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, + std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, + int lda, const DeviceMemory<std::complex<float>> &x, int incx, + std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, + blas::ProfileResult *output_profile_result) { + return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, + incx, beta, y, incy, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemvWithProfiling( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, + std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, + int lda, const DeviceMemory<std::complex<double>> &x, int incx, + std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy, + blas::ProfileResult *output_profile_result) { + return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x, + incx, beta, y, incy, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta, + DeviceMemory<Eigen::half> *c, int ldc, + blas::ProfileResult *output_profile_result) { + return 
DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + int ldc, blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc, + blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + blas::ProfileResult *output_profile_result) { + return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +bool ROCMBlas::DoBlasGemmWithProfiling( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + blas::ProfileResult 
*output_profile_result) { + return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc, + output_profile_result); +} + +template <typename T> +bool ROCMBlas::DoBlasGemvWithProfilingImpl( + Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha, + const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx, + const T &beta, DeviceMemory<T> *y, int incy, + blas::ProfileResult *output_profile_result) { + // ROCM TODO: properly implement the interface + return false; +} + +template <typename T, typename ParamType> +bool ROCMBlas::DoBlasGemmWithProfilingImpl( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, + int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, + DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) { + // ROCM TODO: properly implement the interface + return false; +} + +template <typename InT, typename OutT, typename CompT> +bool ROCMBlas::DoBlasGemmWithAlgorithmImpl( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda, + const DeviceMemory<InT> &b, int ldb, const CompT &beta, + DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + // ROCM TODO: properly implement the interface + return false; +} + +bool ROCMBlas::GetBlasGemmAlgorithms( + std::vector<blas::AlgorithmType> *out_algorithms) { + // ROCM TODO: properly implement the interface + return true; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, + const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, + int ldb, const HostOrDeviceScalar<int> &beta, 
DeviceMemory<int32> *c, + int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"int8\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, + const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c, + int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"half\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, + int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, + int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, + int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c, + int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << 
"rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + const HostOrDeviceScalar<std::complex<float>> &beta, + DeviceMemory<std::complex<float>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + const HostOrDeviceScalar<std::complex<double>> &beta, + DeviceMemory<std::complex<double>> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +template <typename T> +struct EigenHalfToRocBlasHalf { + using type = T; +}; + +template <> +struct EigenHalfToRocBlasHalf<Eigen::half> { + using type = rocblas_half; +}; + + template <typename T, typename FuncT> +port::Status ROCMBlas::DoBlasGemmBatchedInternal( + FuncT rocblas_func, Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda, + const port::ArraySlice<DeviceMemory<T> 
*> &b_ptrs_to_wrappers, int ldb,
+    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+
+  // MAPPED_T will be the same as T for all types except Eigen::half
+  // for T = Eigen::half, MAPPED_T = rocblas_half
+  using MAPPED_T = typename EigenHalfToRocBlasHalf<T>::type;
+
+  // Allocate local vectors to hold device pointers to matrices
+  std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
+  for (int i = 0; i < batch_count; ++i) {
+    // static_cast does not work when converting Eigen::half* to rocblas_half*,
+    // hence the use of reinterpret_cast
+    a_raw_ptrs.push_back(reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
+    b_raw_ptrs.push_back(reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
+    c_raw_ptrs.push_back(reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
+  }
+
+  // batch_count <= 1 is base case, no definable matrix stride, set it same as ld*
+  long long bsa = lda;
+  long long bsb = ldb;
+  long long bsc = ldc;
+  bool bsa_is_constant = true;
+  bool bsb_is_constant = true;
+  bool bsc_is_constant = true;
+
+  if( batch_count > 1 )
+  {
+    // Remember first stride; if any other stride is different from this one, KABLAM
+    bsa = a_raw_ptrs[1] - a_raw_ptrs[0];
+    bsb = b_raw_ptrs[1] - b_raw_ptrs[0];
+    bsc = c_raw_ptrs[1] - c_raw_ptrs[0];
+
+    // Loop to verify that batched strides are constant
+    // All the test cases from batch_matmul_op_test.py seem to satisfy this requirement of a constant
+    // stride.
If this can be proven globally, then this loop check can be safely removed
+    for( int i=1; i < batch_count-1; ++i )
+    {
+      long long iterative_bsa = a_raw_ptrs[i+1] - a_raw_ptrs[i];
+      if( iterative_bsa != bsa)
+      {
+        bsa_is_constant = false;
+        break;
+      }
+
+      long long iterative_bsb = b_raw_ptrs[i+1] - b_raw_ptrs[i];
+      if( iterative_bsb != bsb)
+      {
+        bsb_is_constant = false;
+        break;
+      }
+
+      long long iterative_bsc = c_raw_ptrs[i+1] - c_raw_ptrs[i];
+      if( iterative_bsc != bsc)
+      {
+        bsc_is_constant = false;
+        break;
+      }
+    }
+  }
+
+  assert(!(ldc < m || bsc < ldc * n));
+
+  if (ROCMBlasTranspose(transa) == rocblas_operation_none)
+    assert(!(lda < m || bsa < lda * k));
+  else
+    assert(!(lda < k || bsa < lda * m));
+
+  if (ROCMBlasTranspose(transb) == rocblas_operation_none)
+    assert(!(ldb < k || bsb < ldb * n));
+  else
+    assert(!(ldb < n || bsb < ldb * k));
+
+
+  MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
+  MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
+
+  if(bsa_is_constant && bsb_is_constant && bsc_is_constant)
+  {
+    bool ok = DoBlasInternal(
+        rocblas_func, stream, true /* = pointer_mode_host */,
+        ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+        GpuComplex(alpha_ptr), a_raw_ptrs[ 0 ], lda, bsa,
+        b_raw_ptrs[ 0 ], ldb, bsb, GpuComplex(beta_ptr),
+        c_raw_ptrs[ 0 ], ldc, bsc, batch_count);
+
+    if (ok) {
+      return port::Status::OK();
+    }
+  }
+
+  return port::Status(port::error::INTERNAL,
+                      "failed BLAS call, see log for details");
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+
+  const Eigen::half alpha_half(alpha);
+  const Eigen::half beta_half(beta);
+
port::Status status = DoBlasGemmBatchedInternal( + wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k, + alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count, + scratch_allocator); + if (!status.ok()) { + LOG(ERROR) << status; + } + + return status.ok(); +} + + +bool ROCMBlas::DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, + const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda, + const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta, + const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc, + int batch_count, ScratchAllocator *scratch_allocator) { + port::Status status = DoBlasGemmBatchedInternal( + wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, alpha, + a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, + scratch_allocator); + if (!status.ok()) { + LOG(ERROR) << status; + } + return status.ok(); +} + +bool ROCMBlas::DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, + const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda, + const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb, + double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array, + int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + port::Status status = DoBlasGemmBatchedInternal( + wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, alpha, + a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, + scratch_allocator); + if (!status.ok()) { + LOG(ERROR) << status; + } + return status.ok(); +} + +bool ROCMBlas::DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array, + int lda, + const 
port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array, + int ldb, std::complex<float> beta, + const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array, + int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasGemmBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array, + int lda, + const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array, + int ldb, std::complex<double> beta, + const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array, + int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HEMM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HEMM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHerk(Stream *stream, 
blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + float alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + float beta, DeviceMemory<std::complex<float>> *c, + int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HERK operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + double alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + double beta, DeviceMemory<std::complex<double>> *c, + int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HERK operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + float beta, DeviceMemory<std::complex<float>> *c, + int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HER2K operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + double beta, DeviceMemory<std::complex<double>> *c, + int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the HER2K operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, + DeviceMemory<float> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " + << "for the \"float\" 
dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, + blas::UpperLower uplo, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYMM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + float alpha, const DeviceMemory<float> &a, int lda, + float beta, DeviceMemory<float> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + double alpha, const DeviceMemory<double> &a, int lda, + double beta, DeviceMemory<double> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " + << "for the 
\"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYRK operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + float alpha, const DeviceMemory<float> &a, int lda, + const DeviceMemory<float> &b, int ldb, float beta, + DeviceMemory<float> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + double alpha, const DeviceMemory<double> &a, int lda, + const DeviceMemory<double> &b, int ldb, double beta, + DeviceMemory<double> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + const DeviceMemory<std::complex<float>> &b, int ldb, + std::complex<float> beta, + DeviceMemory<std::complex<float>> *c, int ldc) { + LOG(ERROR) << 
"rocBLAS does not currently support the SYR2K operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, + blas::Transpose trans, uint64 n, uint64 k, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + const DeviceMemory<std::complex<double>> &b, int ldb, + std::complex<double> beta, + DeviceMemory<std::complex<double>> *c, int ldc) { + LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " + << "for the \"float\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " + << "for the \"double\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + DeviceMemory<std::complex<float>> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, 
int lda, + DeviceMemory<std::complex<double>> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not currently support the TRMM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, float alpha, + const DeviceMemory<float> &a, int lda, + DeviceMemory<float> *b, int ldb) { + return DoBlasInternal( + wrap::rocblas_strsm, stream, true /* = pointer_mode_host */, + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), + ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)), + lda, GpuMemoryMutable(b), ldb); +} + +bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, double alpha, + const DeviceMemory<double> &a, int lda, + DeviceMemory<double> *b, int ldb) { + return DoBlasInternal( + wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */, + ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), + ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)), + lda, GpuMemoryMutable(b), ldb); +} + +bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, + std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, + DeviceMemory<std::complex<float>> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not currently support the TRSM operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + +bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, + blas::UpperLower uplo, blas::Transpose transa, + blas::Diagonal diag, uint64 m, uint64 n, + std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, + DeviceMemory<std::complex<double>> *b, int ldb) { + LOG(ERROR) << "rocBLAS does not 
currently support the TRSM operation " + << "for the \"complex<double>\" dataype" ; + return false; +} +bool ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, + int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, + int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, + int64 stride_c, int batch_count) { + LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation " + << "for the \"Eigen::half\" dataype" ; + return false; +} + bool ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, + int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, + float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, + int batch_count) { + LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation " + << "for the \"float\" dataype" ; + return false; +} + bool ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda, + int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, + double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, + int batch_count) { + LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation " + << "for the \"double\" dataype" ; + return false; +} + bool ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<float> alpha, + const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, + std::complex<float> beta, DeviceMemory<std::complex<float>> *c, 
int ldc, + int64 stride_c, int batch_count) { + LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation " + << "for the \"complex<float>\" dataype" ; + return false; +} + bool ROCMBlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex<double> alpha, + const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, + const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, + std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + int64 stride_c, int batch_count) { + LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation " + << "for the \"complex<double>\" dataype" ; + return false; +} +} // namespace gpu + +void initialize_rocblas() { + port::Status status = + PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>( + rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS", + [](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* { + gpu::GpuExecutor* rocm_executor = + dynamic_cast<gpu::GpuExecutor*>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the rocBLAS " + << "support library with a non-ROCM StreamExecutor"; + return nullptr; + } + + gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor); + if (!blas->Init()) { + // Note: Init() will log a more specific error. 
+ delete blas; + return nullptr; + } + return blas; + }); + + if (!status.ok()) { + LOG(ERROR) << "Unable to register rocBLAS factory: " + << status.error_message(); + } + + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin); +} + +} // namespace stream_executor + +REGISTER_MODULE_INITIALIZER(register_rocblas, + { stream_executor::initialize_rocblas(); }); diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h new file mode 100644 index 00000000000..8e577127450 --- /dev/null +++ b/tensorflow/stream_executor/rocm/rocm_blas.h @@ -0,0 +1,159 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS library +// capabilities, and is only included into ROCM implementation code -- it will +// not introduce rocm headers into other code. 
+ +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ + +#include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/platform/mutex.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" +#include "tensorflow/stream_executor/plugin_registry.h" + +namespace stream_executor { + +class Stream; + +namespace gpu { + +// Opaque and unique identifier for the rocBLAS plugin. +extern const PluginId kRocBlasPlugin; + +class GpuExecutor; + +// BLAS plugin for ROCM platform via rocBLAS library. +// +// This satisfies the platform-agnostic BlasSupport interface. +// +// Note that the rocBLAS handle that this encapsulates is implicitly tied to the +// context (and, as a result, the device) that the parent GpuExecutor is tied +// to. This simply happens as an artifact of creating the rocBLAS handle when a +// ROCM context is active. +// +// Thread-safe post-initialization. +class ROCMBlas : public blas::BlasSupport { + public: + explicit ROCMBlas(GpuExecutor* parent); + + // Allocates a rocBLAS handle. + bool Init(); + + // Releases the rocBLAS handle, if present. + ~ROCMBlas() override; + + TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES + + private: + // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. + // + // rocBLAS is stateful, and only be associated with one stream (in order to + // enqueue dispatch) at a given time. As a result, this generally must be + // invoked before calling into rocBLAS. + bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // A helper function that calls the real rocBLAS function together with error + // handling. + // + // rocblas_func: rocBLAS function pointer. + // rocblas_name: rocBLAS function name. + // stream: Stream to enqueue the BLAS operation onto. 
+ // pointer_mode_host: Indicate if the pointer to a scalar value is from host + // (true) or device (false). + // err_on_failure: Whether to print an error if the rocBLAS function fails. + // args: Arguments of rocBLAS function. + template <typename FuncT, typename... Args> + bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, bool err_on_failure, + Args... args); + + // Convenience functions that call DoBlasInternalImpl with different values + // for err_on_failure. + template <typename FuncT, typename... Args> + bool DoBlasInternal(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, + Args... args) { + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/true, args...); + } + template <typename FuncT, typename... Args> + bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, + bool pointer_mode_host, Args... args) { + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, + /*err_on_failure=*/false, args...); + } + + // A helper function to implement DoBlasGemmBatched interfaces for generic + // types. + template <typename T, typename FuncT> + port::Status DoBlasGemmBatchedInternal( + FuncT rocblas_func, Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, + const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda, + const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta, + const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc, + int batch_count, ScratchAllocator *scratch_allocator); + + // Helper function for implementing DoBlasGemmWithAlgorithm. + // + // We take alpha and beta by const reference because T might be Eigen::half, + // and we want to avoid pulling in a dependency on Eigen. When we pass the + // references to rocBLAS, we essentially reinterpret_cast to __half, which is + // safe because Eigen::half inherits from __half. 
+ template <typename InT, typename OutT, typename CompT> + bool DoBlasGemmWithAlgorithmImpl( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, + int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta, + DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result); + + // Helper function for implementing DoBlasGemmWithProfiling. + template <typename T, typename ParamType> + bool DoBlasGemmWithProfilingImpl( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a, + int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta, + DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result); + + // Helper function for implementing DoBlasGemvWithProfiling. + template <typename T> + bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, + uint64 m, uint64 n, const T &alpha, + const DeviceMemory<T> &a, int lda, + const DeviceMemory<T> &x, int incx, + const T &beta, DeviceMemory<T> *y, int incy, + blas::ProfileResult *output_profile_result); + + // mutex that guards the rocBLAS handle for this device. + mutex mu_; + + // GpuExecutor which instantiated this ROCMBlas. + // Immutable post-initialization. + GpuExecutor* parent_; + + // rocBLAS library handle on the device. 
+ rocblas_handle blas_ GUARDED_BY(mu_); + + SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); +}; + +} // namespace gpu +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_ From 19d3ff647d2174184d49ebfe2e505ffd266d0f09 Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Thu, 31 Jan 2019 19:49:46 +0000 Subject: [PATCH 3/7] adding code for rocfft plugin --- tensorflow/stream_executor/rocm/BUILD | 51 +- tensorflow/stream_executor/rocm/rocm_fft.cc | 588 ++++++++++++++++++++ tensorflow/stream_executor/rocm/rocm_fft.h | 132 +++++ 3 files changed, 750 insertions(+), 21 deletions(-) create mode 100644 tensorflow/stream_executor/rocm/rocm_fft.cc create mode 100644 tensorflow/stream_executor/rocm/rocm_fft.h diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 737a4429469..c0da35121f2 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -175,26 +175,35 @@ cc_library( alwayslink = True, ) -# FIXME: enable in future PRs -#cc_library( -# name = "rocfft_plugin", -# srcs = ["rocm_fft.cc"], -# hdrs = [], -# visibility = ["//visibility:public"], -# deps = [ -# ":rocm_platform_id", -# "//tensorflow/stream_executor:event", -# "//tensorflow/stream_executor:fft", -# "//tensorflow/stream_executor:plugin_registry", -# "//tensorflow/stream_executor:scratch_allocator", -# "//tenosrflow/stream_executor/gpu:gpu_stream_header", -# "//tensorflow/stream_executor/lib", -# "//tensorflow/stream_executor/platform", -# "//tensorflow/stream_executor/platform:dso_loader", -# "@local_config_rocm//rocm:rocm_headers", -# ] + if_static(["@local_config_rocm//rocm:rocfft"]), -# alwayslink = True, -#) +cc_library( + name = "rocfft_plugin", + srcs = if_rocm_is_configured(["rocm_fft.cc"]), + hdrs = if_rocm_is_configured(["rocm_fft.h"]), + visibility = ["//visibility:public"], + deps = if_rocm_is_configured([ + ":rocm_platform_id", + "//tensorflow/stream_executor:event", + 
"//tensorflow/stream_executor:fft", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:scratch_allocator", + "//tensorflow/stream_executor/gpu:gpu_activation", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_executor_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/gpu:gpu_kernel_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", + ] + if_static([ + "@local_config_rocm//rocm:rocfft" + # Delete the following line once we switch the rocfft library from + # being dynamically linked (current behaviour) to being dynamically + # loaded (future behaviour) + ], ["@local_config_rocm//rocm:rocfft" + ])), + alwayslink = True, +) # FIXME: enable in future PRs #cc_library( @@ -263,7 +272,7 @@ cc_library( deps = if_rocm_is_configured([ # FIXME: enable in future PRs #":miopen_plugin", - #":rocfft_plugin", + ":rocfft_plugin", ":rocblas_plugin", #":rocrand_plugin", ":rocm_driver", diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc new file mode 100644 index 00000000000..dd30911eadd --- /dev/null +++ b/tensorflow/stream_executor/rocm/rocm_fft.cc @@ -0,0 +1,588 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/stream_executor/rocm/rocm_fft.h" + +#include <complex> + +#include "tensorflow/stream_executor/gpu/gpu_activation.h" +#include "tensorflow/stream_executor/gpu/gpu_executor.h" +#include "tensorflow/stream_executor/gpu/gpu_helpers.h" +#include "tensorflow/stream_executor/rocm/rocm_platform_id.h" +#include "tensorflow/stream_executor/gpu/gpu_stream.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/env.h" +#include "tensorflow/stream_executor/lib/initialize.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" + +namespace stream_executor { +namespace gpu { + +PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin); + +namespace wrap { + +// This macro wraps a global identifier, given by __name, in a callable +// structure that loads the DLL symbol out of the DSO handle in a thread-safe +// manner on first use. This dynamic loading technique is used to avoid DSO +// dependencies on vendor libraries which may or may not be available in the +// deployed binary environment. +#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ + struct WrapperShim__##__name { \ + template <typename... Args> \ + hipfftResult operator()(GpuExecutor* parent, Args... 
args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return ::__name(args...); \ + } \ + } __name; + +#define ROCFFT_ROUTINE_EACH(__macro) \ + __macro(hipfftDestroy) \ + __macro(hipfftSetStream) \ + __macro(hipfftPlan1d) \ + __macro(hipfftPlan2d) \ + __macro(hipfftPlan3d) \ + __macro(hipfftPlanMany) \ + __macro(hipfftCreate) \ + __macro(hipfftSetAutoAllocation) \ + __macro(hipfftSetWorkArea) \ + __macro(hipfftGetSize1d) \ + __macro(hipfftMakePlan1d) \ + __macro(hipfftGetSize2d) \ + __macro(hipfftMakePlan2d) \ + __macro(hipfftGetSize3d) \ + __macro(hipfftMakePlan3d) \ + __macro(hipfftGetSizeMany) \ + __macro(hipfftMakePlanMany) \ + __macro(hipfftExecD2Z) \ + __macro(hipfftExecZ2D) \ + __macro(hipfftExecC2C) \ + __macro(hipfftExecC2R) \ + __macro(hipfftExecZ2Z) \ + __macro(hipfftExecR2C) \ + +ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP) + +} // namespace wrap + +namespace { + +// A helper function transforming gpu_fft arguments into rocFFT arguments. +hipfftType ROCMFftType(fft::Type type) { + switch (type) { + case fft::Type::kC2CForward: + case fft::Type::kC2CInverse: + return HIPFFT_C2C; + case fft::Type::kC2R: + return HIPFFT_C2R; + case fft::Type::kR2C: + return HIPFFT_R2C; + case fft::Type::kZ2ZForward: + case fft::Type::kZ2ZInverse: + return HIPFFT_Z2Z; + case fft::Type::kZ2D: + return HIPFFT_Z2D; + case fft::Type::kD2Z: + return HIPFFT_D2Z; + default: + LOG(FATAL) << "Invalid value of fft::Type."; + } +} + +// Associates the given stream with the given rocFFT plan. 
+bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) { + auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream)); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret; + return false; + } + return true; +} + +} // namespace + +port::Status ROCMFftPlan::Initialize( + GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count, + uint64 *input_embed, uint64 input_stride, uint64 input_distance, + uint64 *output_embed, uint64 output_stride, uint64 output_distance, + fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) { + if (IsInitialized()) { + LOG(FATAL) << "Try to repeatedly initialize."; + } + is_initialized_ = true; + int elem_count_[3], input_embed_[3], output_embed_[3]; + for (int i = 0; i < rank; ++i) { + elem_count_[i] = elem_count[i]; + if (input_embed) { + input_embed_[i] = input_embed[i]; + } + if (output_embed) { + output_embed_[i] = output_embed[i]; + } + } + parent_ = parent; + fft_type_ = type; + if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) { + hipfftResult_t ret; + if (scratch_allocator == nullptr) { + switch (rank) { + case 1: + // hipfftPlan1d + ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0], + ROCMFftType(type), 1 /* = batch */); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT 1d plan."}; + } + return port::Status::OK(); + case 2: + // hipfftPlan2d + ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0], + elem_count_[1], ROCMFftType(type)); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT 2d plan."}; + } + return port::Status::OK(); + case 3: + // hipfftPlan3d + ret = + wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1], + elem_count_[2], ROCMFftType(type)); + if 
(ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT 3d plan."}; + } + return port::Status::OK(); + default: + LOG(ERROR) << "Invalid rank value for hipfftPlan. " + "Requested 1, 2, or 3, given: " + << rank; + return port::Status{port::error::INVALID_ARGUMENT, + "hipfftPlan only takes rank 1, 2, or 3."}; + } + } else { + ret = wrap::hipfftCreate(parent, &plan_); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT plan."}; + } + ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to set auto allocation for rocFFT plan."}; + } + size_t size_in_bytes; + switch (rank) { + case 1: + ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0], + ROCMFftType(type), /*batch=*/1, + &size_in_bytes); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to make rocFFT 1d plan."}; + } + break; + case 2: + ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0], + elem_count_[1], ROCMFftType(type), + &size_in_bytes); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to make rocFFT 2d plan."}; + } + break; + case 3: + ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0], + elem_count_[1], elem_count_[2], + ROCMFftType(type), &size_in_bytes); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to make rocFFT 3d plan."}; + } + break; + default: + LOG(ERROR) << "Invalid rank value for hipfftPlan. 
" + "Requested 1, 2, or 3, given: " + << rank; + return port::Status{port::error::INVALID_ARGUMENT, + "hipfftPlan only takes rank 1, 2, or 3."}; + } + // TODO(yangzihao): refactor this code and the one with the same function + // in the batch mode. + if (size_in_bytes != 0) { + auto allocated = + scratch_allocator->AllocateBytes(stream, size_in_bytes); + if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "failed to allocate work area."; + return allocated.status(); + } + } + // Connect work area with allocated space. + ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque()); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to set work area for rocFFT plan."}; + } + return port::Status::OK(); + } + } else { + // For either multiple batches or rank higher than 3, use hipfftPlanMany(). + if (scratch_allocator == nullptr) { + auto ret = wrap::hipfftPlanMany( + parent, &plan_, rank, elem_count_, + input_embed ? input_embed_ : nullptr, input_stride, input_distance, + output_embed ? 
output_embed_ : nullptr, output_stride, + output_distance, ROCMFftType(type), batch_count); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT batched plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT bacthed plan."}; + } + } else { + auto ret = wrap::hipfftCreate(parent, &plan_); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to create rocFFT batched plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to create rocFFT bacthed plan."}; + } + ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:" + << ret; + return port::Status{ + port::error::INTERNAL, + "Failed to set auto allocation for rocFFT bacthed plan."}; + } + size_t size_in_bytes; + ret = wrap::hipfftMakePlanMany( + parent, plan_, rank, elem_count_, + input_embed ? input_embed_ : nullptr, input_stride, input_distance, + output_embed ? output_embed_ : nullptr, output_stride, + output_distance, ROCMFftType(type), batch_count, &size_in_bytes); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to make rocFFT batched plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to make rocFFT bacthed plan."}; + } + if (size_in_bytes != 0) { + auto allocated = + scratch_allocator->AllocateBytes(stream, size_in_bytes); + if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) { + LOG(ERROR) << "failed to allocate work area."; + return allocated.status(); + } + } + // Connect work area with allocated space. 
+ ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque()); + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret; + return port::Status{port::error::INTERNAL, + "Failed to set work area for rocFFT bacthed plan."}; + } + } + } + return port::Status::OK(); +} + +port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream, + int rank, uint64 *elem_count, + fft::Type type, + ScratchAllocator *scratch_allocator) { + return Initialize(parent_, stream, rank, elem_count, + /*input_embed=*/nullptr, /*input_stride=*/0, + /*input_distance=*/0, + /*output_embed=*/nullptr, /*output_stride=*/0, + /*output_distance=*/0, type, 1, scratch_allocator); +} + +ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); } + +int ROCMFftPlan::GetFftDirection() const { + if (!IsInitialized()) { + LOG(FATAL) << "Try to get fft direction before initialization."; + } else { + switch (fft_type_) { + case fft::Type::kC2CForward: + case fft::Type::kZ2ZForward: + case fft::Type::kR2C: + case fft::Type::kD2Z: + return HIPFFT_FORWARD; + case fft::Type::kC2CInverse: + case fft::Type::kZ2ZInverse: + case fft::Type::kC2R: + case fft::Type::kZ2D: + return HIPFFT_BACKWARD; + default: + LOG(FATAL) << "Invalid value of fft::Type."; + } + } +} + +std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x, + fft::Type type, + bool in_place_fft) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[1] = {num_x}; + port::Status status = fft_plan_ptr->Initialize( + parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr); + // TODO(yangzihao): In the future, send error msg back to TensorFlow + // so it can fail gracefully, + if (!status.ok()) { + LOG(FATAL) << "failed to initialize hipfft 1d plan: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator( + Stream *stream, uint64 
num_x, fft::Type type, bool in_place_fft, + ScratchAllocator *scratch_allocator) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[1] = {num_x}; + port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count, + type, scratch_allocator); + if (!status.ok()) { + LOG(FATAL) + << "failed to initialize hipfft 1d plan with customized allocator: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x, + uint64 num_y, fft::Type type, + bool in_place_fft) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[2] = {num_x, num_y}; + port::Status status = fft_plan_ptr->Initialize( + parent_, stream, 2, elem_count, type, /*scratch_allocator=*/nullptr); + if (!status.ok()) { + LOG(FATAL) << "failed to initialize hipfft 2d plan: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator( + Stream *stream, uint64 num_x, uint64 num_y, fft::Type type, + bool in_place_fft, ScratchAllocator *scratch_allocator) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[2] = {num_x, num_y}; + port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count, + type, scratch_allocator); + if (!status.ok()) { + LOG(FATAL) + << "failed to initialize hipfft 2d plan with customized allocator: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x, + uint64 num_y, uint64 num_z, + fft::Type type, + bool in_place_fft) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[3] = {num_x, num_y, num_z}; + port::Status status = fft_plan_ptr->Initialize( + parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr); + if (!status.ok()) { + LOG(FATAL) << "failed to 
initialize hipfft 3d plan: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator( + Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type, + bool in_place_fft, ScratchAllocator *scratch_allocator) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + uint64 elem_count[3] = {num_x, num_y, num_z}; + port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count, + type, scratch_allocator); + if (!status.ok()) { + LOG(FATAL) + << "failed to initialize hipfft 3d plan with customized allocator: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan( + Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, + uint64 input_stride, uint64 input_distance, uint64 *output_embed, + uint64 output_stride, uint64 output_distance, fft::Type type, + bool in_place_fft, int batch_count) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + port::Status status = fft_plan_ptr->Initialize( + parent_, stream, rank, elem_count, input_embed, input_stride, + input_distance, output_embed, output_stride, output_distance, type, + batch_count, /*scratch_allocator=*/nullptr); + if (!status.ok()) { + LOG(FATAL) << "failed to initialize batched hipfft plan: " + << status.error_message(); + } + + return std::move(fft_plan_ptr); +} + +std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator( + Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, + uint64 input_stride, uint64 input_distance, uint64 *output_embed, + uint64 output_stride, uint64 output_distance, fft::Type type, + bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) { + std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()}; + port::Status status = fft_plan_ptr->Initialize( + parent_, stream, rank, elem_count, input_embed, input_stride, + 
input_distance, output_embed, output_stride, output_distance, type, + batch_count, scratch_allocator); + if (!status.ok()) { + LOG(FATAL) + << "failed to initialize batched hipfft plan with customized allocator: " + << status.error_message(); + } + return std::move(fft_plan_ptr); +} + +void ROCMFft::UpdatePlanWithScratchAllocator( + Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) { + LOG(ERROR) << "update plan with scratch allocator not implemented"; +} + +template <typename FuncT, typename InputT, typename OutputT> +bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output) { + ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan); + if (rocm_fft_plan == nullptr) { + LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object."; + return false; + } + + if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) { + return false; + } + + auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(), + GpuComplex(const_cast<InputT*>(GpuMemory(input))), + GpuComplex(GpuMemoryMutable(output))); + + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << "failed to run rocFFT routine: " << ret; + return false; + } + + return true; +} + +template <typename FuncT, typename InputT, typename OutputT> +bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, + FuncT hipfftExec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output) { + ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan); + if (rocm_fft_plan == nullptr) { + LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object."; + return false; + } + + if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) { + return false; + } + + auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(), + GpuComplex(const_cast<InputT*>(GpuMemory(input))), + GpuComplex(GpuMemoryMutable(output)), + rocm_fft_plan->GetFftDirection()); + + if (ret != HIPFFT_SUCCESS) { + LOG(ERROR) << 
"failed to run rocFFT routine: " << ret; + return false; + } + + return true; +} + +#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \ + __fft_type3) \ + bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<__type>> &input, \ + DeviceMemory<std::complex<__type>> *output) { \ + return DoFftWithDirectionInternal( \ + stream, plan, wrap::hipfftExec##__fft_type1, input, output); \ + } \ + bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<__type> &input, \ + DeviceMemory<std::complex<__type>> *output) { \ + return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \ + output); \ + } \ + bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<std::complex<__type>> &input, \ + DeviceMemory<__type> *output) { \ + return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \ + output); \ + } + +STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R) +STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D) + +#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT + +} // namespace gpu + +void initialize_rocfft() { + port::Status status = + PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>( + rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT", + [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * { + gpu::GpuExecutor *rocm_executor = + dynamic_cast<gpu::GpuExecutor *>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the rocFFT " + << "support library with a non-ROCM StreamExecutor"; + return nullptr; + } + + return new gpu::ROCMFft(rocm_executor); + }); + if (!status.ok()) { + LOG(ERROR) << "Unable to register rocFFT factory: " + << status.error_message(); + } + + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin); +} + +} // namespace stream_executor + +REGISTER_MODULE_INITIALIZER(register_rocfft, + { 
stream_executor::initialize_rocfft(); }); diff --git a/tensorflow/stream_executor/rocm/rocm_fft.h b/tensorflow/stream_executor/rocm/rocm_fft.h new file mode 100644 index 00000000000..3dbe5800b74 --- /dev/null +++ b/tensorflow/stream_executor/rocm/rocm_fft.h @@ -0,0 +1,132 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ROCM-specific support for FFT functionality -- this wraps the rocFFT library +// capabilities, and is only included into ROCM implementation code -- it will +// not introduce rocm headers into other code. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_ + +#include "tensorflow/stream_executor/fft.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/scratch_allocator.h" +#include "rocm/include/rocfft/hipfft.h" + +namespace stream_executor { + +class Stream; + +namespace gpu { + +class GpuExecutor; + +// Opaque and unique indentifier for the rocFFT plugin. +extern const PluginId kRocFftPlugin; + +// ROCMFftPlan uses deferred initialization. Only a single call of +// Initialize() is allowed to properly create hipfft plan and set member +// variable is_initialized_ to true. 
Newly added interface that uses member +// variables should first check is_initialized_ to make sure that the values of +// member variables are valid. +class ROCMFftPlan : public fft::Plan { + public: + ROCMFftPlan() + : parent_(nullptr), + plan_(), + fft_type_(fft::Type::kInvalid), + scratch_(nullptr), + is_initialized_(false) {} + ~ROCMFftPlan() override; + + // Get FFT direction in hipFFT based on FFT type. + int GetFftDirection() const; + hipfftHandle GetPlan() const { + if (IsInitialized()) { + return plan_; + } else { + LOG(FATAL) << "Try to get hipfftHandle value before initialization."; + } + } + + // Initialize function for batched plan + port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, + uint64 *elem_count, uint64 *input_embed, + uint64 input_stride, uint64 input_distance, + uint64 *output_embed, uint64 output_stride, + uint64 output_distance, fft::Type type, + int batch_count, ScratchAllocator *scratch_allocator); + + // Initialize function for 1d,2d, and 3d plan + port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, + uint64 *elem_count, fft::Type type, + ScratchAllocator *scratch_allocator); + + protected: + bool IsInitialized() const { return is_initialized_; } + + private: + GpuExecutor *parent_; + hipfftHandle plan_; + fft::Type fft_type_; + DeviceMemory<uint8> scratch_; + bool is_initialized_; +}; + +// FFT support for ROCM platform via rocFFT library. +// +// This satisfies the platform-agnostic FftSupport interface. +// +// Note that the hipFFT handle that this encapsulates is implicitly tied to the +// context (and, as a result, the device) that the parent GpuExecutor is tied +// to. This simply happens as an artifact of creating the hipFFT handle when a +// ROCM context is active. +// +// Thread-safe. The ROCM context associated with all operations is the ROCM +// context of parent_, so all context is explicit. 
+class ROCMFft : public fft::FftSupport { + public: + explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {} + ~ROCMFft() override {} + + TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES + + private: + GpuExecutor *parent_; + + // Two helper functions that execute dynload::hipfftExec?2?. + + // This is for complex to complex FFT, when the direction is required. + template <typename FuncT, typename InputT, typename OutputT> + bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, + FuncT hipfft_exec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output); + + // This is for complex to real or real to complex FFT, when the direction + // is implied. + template <typename FuncT, typename InputT, typename OutputT> + bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec, + const DeviceMemory<InputT> &input, + DeviceMemory<OutputT> *output); + + SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft); +}; + +} // namespace gpu +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_ From 234f47031bdd5346b6a3dc29670c215c94311cc8 Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Thu, 31 Jan 2019 20:09:53 +0000 Subject: [PATCH 4/7] adding code for hiprand plugin --- tensorflow/stream_executor/rocm/BUILD | 50 +++++++----- tensorflow/stream_executor/rocm/rocm_rng.cc | 89 +++++++++++---------- 2 files changed, 74 insertions(+), 65 deletions(-) diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index c0da35121f2..85ab38d4242 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -244,26 +244,34 @@ cc_library( # alwayslink = True, #) -# FIXME: enable in future PRs -#cc_library( -# name = "rocrand_plugin", -# srcs = ["rocm_rng.cc"], -# hdrs = [], -# deps = [ -# ":rocm_gpu_executor", -# ":rocm_platform_id", -# "@local_config_rocm//rocm:rocm_headers", -# "//tensorflow/stream_executor:event", -# 
"//tensorflow/stream_executor:plugin_registry", -# "//tensorflow/stream_executor:rng", -# "//tenosrflow/stream_executor/gpu:gpu_activation_header", -# "//tenosrflow/stream_executor/gpu:gpu_stream_header", -# "//tensorflow/stream_executor/lib", -# "//tensorflow/stream_executor/platform", -# "//tensorflow/stream_executor/platform:dso_loader", -# ] + if_static(["@local_config_rocm//rocm:curand"]), -# alwayslink = True, -#) +cc_library( + name = "rocrand_plugin", + srcs = if_rocm_is_configured(["rocm_rng.cc"]), + hdrs = if_rocm_is_configured([]), + deps = if_rocm_is_configured([ + ":rocm_gpu_executor", + ":rocm_platform_id", + "@local_config_rocm//rocm:rocm_headers", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:rng", + "//tensorflow/stream_executor/gpu:gpu_activation_header", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_executor_header", + "//tensorflow/stream_executor/gpu:gpu_rng_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + ] + if_static([ + "@local_config_rocm//rocm:hiprand" + # Delete the following line once we switch the hiprand library from + # being dynamically linked (current behaviour) to being dynamically + # loaded (future behaviour) + ], ["@local_config_rocm//rocm:hiprand" + ])), + alwayslink = True, +) cc_library( name = "all_runtime", @@ -274,7 +282,7 @@ cc_library( #":miopen_plugin", ":rocfft_plugin", ":rocblas_plugin", - #":rocrand_plugin", + ":rocrand_plugin", ":rocm_driver", ":rocm_platform", ]), diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 65acd03c92b..79250579087 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -14,21 +14,22 @@ limitations under 
the License. ==============================================================================*/ #include "rocm/include/hiprand/hiprand.h" -#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/gpu/gpu_rng.h" + #include "tensorflow/stream_executor/gpu/gpu_activation.h" #include "tensorflow/stream_executor/gpu/gpu_executor.h" #include "tensorflow/stream_executor/gpu/gpu_helpers.h" -#include "tensorflow/stream_executor/gpu/gpu_rng.h" +#include "tensorflow/stream_executor/rocm/rocm_platform_id.h" #include "tensorflow/stream_executor/gpu/gpu_stream.h" +#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/rng.h" -#include "tensorflow/stream_executor/rocm/rocm_platform_id.h" // Formats hiprandStatus_t to output prettified values into a log stream. -std::ostream& operator<<(std::ostream& in, const hiprandStatus_t& status) { +std::ostream &operator<<(std::ostream &in, const hiprandStatus_t &status) { #define OSTREAM_HIPRAND_STATUS(__name) \ case HIPRAND_STATUS_##__name: \ in << "HIPRAND_STATUS_" #__name; \ @@ -60,7 +61,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin); namespace wrap { -#define PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(__name) \ +#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \ struct WrapperShim__##__name { \ template <typename... Args> \ hiprandStatus_t operator()(GpuExecutor* parent, Args... 
args) { \ @@ -69,15 +70,15 @@ namespace wrap { } \ } __name; -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandCreateGenerator); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandDestroyGenerator); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetStream); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniform); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniformDouble); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetGeneratorOffset); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormal); -PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormalDouble); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal); +STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble); } // namespace wrap @@ -245,40 +246,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) { } } // namespace gpu -} // namespace stream_executor -namespace se = ::stream_executor; +void initialize_rocrand() { + port::Status status = + PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>( + rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND", + [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* { + gpu::GpuExecutor* rocm_executor = + dynamic_cast<gpu::GpuExecutor*>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the hipRAND " + << "support library with a non-ROCM StreamExecutor"; + return nullptr; + } -REGISTER_MODULE_INITIALIZER(register_hiprand, { - se::port::Status status = - se::PluginRegistry::Instance() - 
->RegisterFactory<se::PluginRegistry::RngFactory>( - se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND", - [](se::internal::StreamExecutorInterface* parent) - -> se::rng::RngSupport* { - se::gpu::GpuExecutor* rocm_executor = - dynamic_cast<se::gpu::GpuExecutor*>(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the hipRAND " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - se::gpu::GpuRng* rng = new se::gpu::GpuRng(rocm_executor); - if (!rng->Init()) { - // Note: Init() will log a more specific error. - delete rng; - return nullptr; - } - return rng; - }); + gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor); + if (!rng->Init()) { + // Note: Init() will log a more specific error. + delete rng; + return nullptr; + } + return rng; + }); if (!status.ok()) { - LOG(ERROR) << "Unable to register hipRAND factory: " + LOG(ERROR) << "Unable to register rocRAND factory: " << status.error_message(); } - se::PluginRegistry::Instance()->SetDefaultFactory( - se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin); -}); + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin); +} + +} // namespace stream_executor + +REGISTER_MODULE_INITIALIZER(register_rocrand, + { stream_executor::initialize_rocrand(); }); From 07b3f341995ff0b012bf9297f500a8e75c682f31 Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Fri, 1 Feb 2019 02:47:14 +0000 Subject: [PATCH 5/7] changing rocblas, rocfft and rocrand from being dynamically linked to being dynmically loaded --- .../platform/default/dso_loader.cc | 52 ++++++++++++++++++- .../platform/default/dso_loader.h | 13 +++++ tensorflow/stream_executor/rocm/BUILD | 12 ----- tensorflow/stream_executor/rocm/rocm_blas.cc | 37 +++++++++++++ tensorflow/stream_executor/rocm/rocm_fft.cc | 34 ++++++++++++ tensorflow/stream_executor/rocm/rocm_rng.cc | 35 +++++++++++++ 6 
files changed, 169 insertions(+), 14 deletions(-) diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index ad8112b831a..3f2bb5e4dd0 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) { port::Status status = port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle); if (status.ok()) { - LOG(INFO) << "Successfully opened CUDA library " << filename; + LOG(INFO) << "Successfully opened dynamic library " << filename; return dso_handle; } @@ -54,6 +54,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) { return port::Status(port::error::FAILED_PRECONDITION, message); } } // namespace + namespace DsoLoader { port::StatusOr<void*> GetCudaDriverDsoHandle() { @@ -99,6 +100,27 @@ port::StatusOr<void*> GetCuptiDsoHandle() { port::StatusOr<void*> GetCudnnDsoHandle() { return GetDsoHandle("cudnn", GetCudnnVersion()); } + +port::StatusOr<void*> GetRocblasDsoHandle() { + return GetDsoHandle("rocblas", ""); +} + +port::StatusOr<void*> GetMiopenDsoHandle() { + return GetDsoHandle("MIOpen", ""); +} + +port::StatusOr<void*> GetRocfftDsoHandle() { + return GetDsoHandle("rocfft", ""); +} + +port::StatusOr<void*> GetRocrandDsoHandle() { + return GetDsoHandle("rocrand", ""); +} + +port::StatusOr<void*> GetHipDsoHandle() { + return GetDsoHandle("hip_hcc", ""); +} + } // namespace DsoLoader namespace CachedDsoLoader { @@ -131,11 +153,37 @@ port::StatusOr<void*> GetCuptiDsoHandle() { static auto result = new auto(DsoLoader::GetCuptiDsoHandle()); return *result; } - + port::StatusOr<void*> GetCudnnDsoHandle() { static auto result = new auto(DsoLoader::GetCudnnDsoHandle()); return *result; } + +port::StatusOr<void*> GetRocblasDsoHandle() { + static auto result = new 
auto(DsoLoader::GetRocblasDsoHandle()); + return *result; +} + +port::StatusOr<void*> GetMiopenDsoHandle() { + static auto result = new auto(DsoLoader::GetMiopenDsoHandle()); + return *result; +} + +port::StatusOr<void*> GetRocfftDsoHandle() { + static auto result = new auto(DsoLoader::GetRocfftDsoHandle()); + return *result; +} + +port::StatusOr<void*> GetRocrandDsoHandle() { + static auto result = new auto(DsoLoader::GetRocrandDsoHandle()); + return *result; +} + +port::StatusOr<void*> GetHipDsoHandle() { + static auto result = new auto(DsoLoader::GetHipDsoHandle()); + return *result; +} + } // namespace CachedDsoLoader } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index 45a8315b436..8da8ea7be66 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle(); port::StatusOr<void*> GetCurandDsoHandle(); port::StatusOr<void*> GetCuptiDsoHandle(); port::StatusOr<void*> GetCudnnDsoHandle(); + +port::StatusOr<void*> GetRocblasDsoHandle(); +port::StatusOr<void*> GetMiopenDsoHandle(); +port::StatusOr<void*> GetRocfftDsoHandle(); +port::StatusOr<void*> GetRocrandDsoHandle(); +port::StatusOr<void*> GetHipDsoHandle(); } // namespace DsoLoader // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs @@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle(); port::StatusOr<void*> GetCurandDsoHandle(); port::StatusOr<void*> GetCuptiDsoHandle(); port::StatusOr<void*> GetCudnnDsoHandle(); + +port::StatusOr<void*> GetRocblasDsoHandle(); +port::StatusOr<void*> GetMiopenDsoHandle(); +port::StatusOr<void*> GetRocfftDsoHandle(); +port::StatusOr<void*> GetRocrandDsoHandle(); +port::StatusOr<void*> GetHipDsoHandle(); } // namespace CachedDsoLoader + } // namespace internal } // namespace
stream_executor diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 85ab38d4242..f0b05822703 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -167,10 +167,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", ] + if_static([ "@local_config_rocm//rocm:rocblas" - # Delete the following line once we switch the rocblas library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:rocblas" ])), alwayslink = True, ) @@ -197,10 +193,6 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", ] + if_static([ "@local_config_rocm//rocm:rocfft" - # Delete the following line once we switch the rocfft library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:rocfft" ])), alwayslink = True, ) @@ -265,10 +257,6 @@ cc_library( "//tensorflow/stream_executor/platform:dso_loader", ] + if_static([ "@local_config_rocm//rocm:hiprand" - # Delete the following line once we switch the hiprand library from - # being dynamically linked (current behaviour) to being dynamically - # loaded (future behaviour) - ], ["@local_config_rocm//rocm:hiprand" ])), alwayslink = True, ) diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index b2e225433e5..a626d168c26 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/stringprintf.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -49,6 +50,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \ struct WrapperShim__##__name { \ static const char* kName; \ @@ -63,6 +65,41 @@ namespace wrap { #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \ STREAM_EXECUTOR_ROCBLAS_WRAP(__name) +#else + +#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char* kName; \ + using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \ + static void* GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocblasDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocblas DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast<FuncPtrT>(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template <typename... Args> \ + rocblas_status operator()(GpuExecutor* parent, Args... 
args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char* DynLoadShim__##__name::kName = #__name; + +#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \ + STREAM_EXECUTOR_ROCBLAS_WRAP(__name) + +#endif + #define ROCBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /* __macro(rocblas_scnrm2) \ __macro(rocblas_dznrm2) */ \ diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc index dd30911eadd..e8a72f61d3f 100644 --- a/tensorflow/stream_executor/rocm/rocm_fft.cc +++ b/tensorflow/stream_executor/rocm/rocm_fft.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" @@ -38,6 +39,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE // This macro wraps a global identifier, given by __name, in a callable // structure that loads the DLL symbol out of the DSO handle in a thread-safe // manner on first use. 
This dynamic loading technique is used to avoid DSO @@ -52,6 +54,38 @@ namespace wrap { } \ } __name; +#else + +#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char *kName; \ + using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \ + static void *GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void *f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocfft DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast<FuncPtrT>(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template <typename... Args> \ + hipfftResult operator()(GpuExecutor *parent, Args... args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char *DynLoadShim__##__name::kName = #__name; + +#endif + #define ROCFFT_ROUTINE_EACH(__macro) \ __macro(hipfftDestroy) \ __macro(hipfftSetStream) \ diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 79250579087..27797845700 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/rng.h" @@ -61,6 +62,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin); namespace wrap { +#ifdef PLATFORM_GOOGLE + #define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \ struct WrapperShim__##__name { \ template <typename... 
Args> \ @@ -70,6 +73,38 @@ namespace wrap { } \ } __name; +#else + +#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \ + struct DynLoadShim__##__name { \ + static const char *kName; \ + using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \ + static void *GetDsoHandle() { \ + auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \ + return s.ValueOrDie(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void *f; \ + auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocrand DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast<FuncPtrT>(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template <typename... Args> \ + hiprandStatus operator()(GpuExecutor *parent, Args... args) { \ + gpu::ScopedActivateExecutorContext sac{parent}; \ + return DynLoad()(args...); \ + } \ + } __name; \ + const char *DynLoadShim__##__name::kName = #__name; + +#endif + STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator); STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator); STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream); From 834a3f7395a9db748349e0bf9dfff7af558cb4fb Mon Sep 17 00:00:00 2001 From: Deven Desai <deven.desai.amd@gmail.com> Date: Fri, 1 Feb 2019 17:59:02 +0000 Subject: [PATCH 6/7] changing the rocm_driver API from being dynamically linked to being dynamically loaded --- tensorflow/stream_executor/rocm/BUILD | 149 +++++++++--------- .../stream_executor/rocm/rocm_driver.cc | 125 +++++++-------- .../rocm/rocm_driver_wrapper.h | 147 +++++++++++++++++ 3 files changed, 285 insertions(+), 136 deletions(-) create mode 100644 tensorflow/stream_executor/rocm/rocm_driver_wrapper.h diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index f0b05822703..5190b551f80 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -47,7 +47,7 @@ cc_library( 
cc_library( name = "rocm_driver", srcs = if_rocm_is_configured(["rocm_driver.cc"]), - hdrs = [], + hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]), deps = if_rocm_is_configured([ ":rocm_diagnostics", "@com_google_absl//absl/base", @@ -57,6 +57,7 @@ cc_library( "//tensorflow/stream_executor/gpu:gpu_driver_header", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -141,60 +142,60 @@ cc_library( ) cc_library( - name = "rocblas_plugin", - srcs = if_rocm_is_configured(["rocm_blas.cc"]), - hdrs = if_rocm_is_configured(["rocm_blas.h"]), - visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - ":rocm_gpu_executor", - ":rocm_platform_id", - "//third_party/eigen3", - "//tensorflow/core:lib_internal", - "//tensorflow/stream_executor", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:host_or_device_scalar", - "//tensorflow/stream_executor:plugin_registry", - "//tensorflow/stream_executor:scratch_allocator", - "//tensorflow/stream_executor:timer", - "//tensorflow/stream_executor/gpu:gpu_activation", - "//tensorflow/stream_executor/gpu:gpu_helpers_header", - "//tensorflow/stream_executor/gpu:gpu_stream_header", - "//tensorflow/stream_executor/gpu:gpu_timer_header", - "//tensorflow/stream_executor/lib", - "//tensorflow/stream_executor/platform", - "//tensorflow/stream_executor/platform:dso_loader", - "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:rocm_headers", - ] + if_static([ - "@local_config_rocm//rocm:rocblas" - ])), - alwayslink = True, + name = "rocblas_plugin", + srcs = if_rocm_is_configured(["rocm_blas.cc"]), + hdrs = if_rocm_is_configured(["rocm_blas.h"]), + visibility = ["//visibility:public"], + deps = if_rocm_is_configured([ + ":rocm_gpu_executor", + ":rocm_platform_id", + "//third_party/eigen3", + "//tensorflow/core:lib_internal", + "//tensorflow/stream_executor", + 
"//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:host_or_device_scalar", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:scratch_allocator", + "//tensorflow/stream_executor:timer", + "//tensorflow/stream_executor/gpu:gpu_activation", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/gpu:gpu_timer_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:rocm_headers", + ] + if_static([ + "@local_config_rocm//rocm:rocblas", + ])), + alwayslink = True, ) cc_library( - name = "rocfft_plugin", - srcs = if_rocm_is_configured(["rocm_fft.cc"]), - hdrs = if_rocm_is_configured(["rocm_fft.h"]), - visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - ":rocm_platform_id", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:fft", - "//tensorflow/stream_executor:plugin_registry", - "//tensorflow/stream_executor:scratch_allocator", - "//tensorflow/stream_executor/gpu:gpu_activation", - "//tensorflow/stream_executor/gpu:gpu_helpers_header", - "//tensorflow/stream_executor/gpu:gpu_executor_header", - "//tensorflow/stream_executor/gpu:gpu_stream_header", - "//tensorflow/stream_executor/gpu:gpu_kernel_header", - "//tensorflow/stream_executor/lib", - "//tensorflow/stream_executor/platform", - "//tensorflow/stream_executor/platform:dso_loader", - "@local_config_rocm//rocm:rocm_headers", - ] + if_static([ - "@local_config_rocm//rocm:rocfft" - ])), - alwayslink = True, + name = "rocfft_plugin", + srcs = if_rocm_is_configured(["rocm_fft.cc"]), + hdrs = if_rocm_is_configured(["rocm_fft.h"]), + visibility = ["//visibility:public"], + deps = if_rocm_is_configured([ + ":rocm_platform_id", + "//tensorflow/stream_executor:event", + 
"//tensorflow/stream_executor:fft", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:scratch_allocator", + "//tensorflow/stream_executor/gpu:gpu_activation", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_executor_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/gpu:gpu_kernel_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + "@local_config_rocm//rocm:rocm_headers", + ] + if_static([ + "@local_config_rocm//rocm:rocfft", + ])), + alwayslink = True, ) # FIXME: enable in future PRs @@ -237,28 +238,28 @@ cc_library( #) cc_library( - name = "rocrand_plugin", - srcs = if_rocm_is_configured(["rocm_rng.cc"]), - hdrs = if_rocm_is_configured([]), - deps = if_rocm_is_configured([ - ":rocm_gpu_executor", - ":rocm_platform_id", - "@local_config_rocm//rocm:rocm_headers", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:plugin_registry", - "//tensorflow/stream_executor:rng", - "//tensorflow/stream_executor/gpu:gpu_activation_header", - "//tensorflow/stream_executor/gpu:gpu_helpers_header", - "//tensorflow/stream_executor/gpu:gpu_executor_header", - "//tensorflow/stream_executor/gpu:gpu_rng_header", - "//tensorflow/stream_executor/gpu:gpu_stream_header", - "//tensorflow/stream_executor/lib", - "//tensorflow/stream_executor/platform", - "//tensorflow/stream_executor/platform:dso_loader", - ] + if_static([ - "@local_config_rocm//rocm:hiprand" - ])), - alwayslink = True, + name = "rocrand_plugin", + srcs = if_rocm_is_configured(["rocm_rng.cc"]), + hdrs = if_rocm_is_configured([]), + deps = if_rocm_is_configured([ + ":rocm_gpu_executor", + ":rocm_platform_id", + "@local_config_rocm//rocm:rocm_headers", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:plugin_registry", + "//tensorflow/stream_executor:rng", + 
"//tensorflow/stream_executor/gpu:gpu_activation_header", + "//tensorflow/stream_executor/gpu:gpu_helpers_header", + "//tensorflow/stream_executor/gpu:gpu_executor_header", + "//tensorflow/stream_executor/gpu:gpu_rng_header", + "//tensorflow/stream_executor/gpu:gpu_stream_header", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/platform", + "//tensorflow/stream_executor/platform:dso_loader", + ] + if_static([ + "@local_config_rocm//rocm:hiprand", + ])), + alwayslink = True, ) cc_library( diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc index 39d52d28304..73b1b350f71 100644 --- a/tensorflow/stream_executor/rocm/rocm_driver.cc +++ b/tensorflow/stream_executor/rocm/rocm_driver.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h" bool FLAGS_gpuexec_rocm_driver_inject_init_error = false; bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false; @@ -143,7 +144,7 @@ string MemorySpaceString(MemorySpace memory_space) { // HIP driver (e.g., this value is not our cached view of the current device). static int CurrentDeviceOrDie() { int current = -1; - hipError_t result = hipGetDevice(¤t); + hipError_t result = tensorflow::wrap::hipGetDevice(¤t); if (result != hipSuccess) { LOG(FATAL) << "failed to query current device: " << ToString(result); } @@ -154,7 +155,7 @@ namespace { // Call hipDeviceSynchronize and crash if it doesn't succeed. 
void SynchronizeOrDie() { - auto res = hipDeviceSynchronize(); + auto res = tensorflow::wrap::hipDeviceSynchronize(); if (res != hipSuccess) { LOG(FATAL) << "Synchronize found " << ToString(res) << " :: " << port::CurrentStackTrace(); @@ -197,7 +198,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* context) { << tls->current_device_ordinal << " to " << context->device_ordinal(); // Set the device and update thread local. - CHECK_EQ(hipSuccess, hipSetDevice(context->device_ordinal())); + CHECK_EQ(hipSuccess, tensorflow::wrap::hipSetDevice(context->device_ordinal())); tls->current_device_ordinal = context->device_ordinal(); } @@ -225,7 +226,7 @@ ScopedActivateContext::~ScopedActivateContext() { << to_restore_->device_ordinal(); // Set context and update thread local. - CHECK_EQ(hipSuccess, hipSetDevice(to_restore_->device_ordinal())); + CHECK_EQ(hipSuccess, tensorflow::wrap::hipSetDevice(to_restore_->device_ordinal())); tls->current_device_ordinal = to_restore_->device_ordinal(); } @@ -261,7 +262,7 @@ string ROCMPointerToMemorySpaceString(hipDeviceptr_t pointer) { // in the process of querying. 
string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) { hipPointerAttribute_t from_pointerAttributes; - hipError_t result = hipPointerGetAttributes(&from_pointerAttributes, from); + hipError_t result = tensorflow::wrap::hipPointerGetAttributes(&from_pointerAttributes, from); if (result != hipSuccess) { LOG(ERROR) << "could not retrieve source pointer's device: " << ToString(result); @@ -269,7 +270,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) { } hipPointerAttribute_t to_pointerAttributes; - result = hipPointerGetAttributes(&to_pointerAttributes, to); + result = tensorflow::wrap::hipPointerGetAttributes(&to_pointerAttributes, to); if (result != hipSuccess) { LOG(ERROR) << "could not retrieve destination pointer's device: " << ToString(result); @@ -289,7 +290,7 @@ static port::Status InternalInit() { if (FLAGS_gpuexec_rocm_driver_inject_init_error) { LOG(ERROR) << "injecting ROCM init error; initialization will fail"; } else { - res = hipInit(0 /* = flags */); + res = tensorflow::wrap::hipInit(0 /* = flags */); } if (res == hipSuccess) { @@ -322,7 +323,7 @@ static port::Status InternalInit() { /* static */ port::Status GpuDriver::GetDevice(int device_ordinal, hipDevice_t* device) { - hipError_t res = hipDeviceGet(device, device_ordinal); + hipError_t res = tensorflow::wrap::hipDeviceGet(device, device_ordinal); if (res == hipSuccess) { return port::Status::OK(); } @@ -336,7 +337,7 @@ static port::Status InternalInit() { string* device_name) { static const size_t kCharLimit = 64; absl::InlinedVector<char, 4> chars(kCharLimit); - hipError_t res = hipDeviceGetName(chars.begin(), kCharLimit - 1, device); + hipError_t res = tensorflow::wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device); if (res != hipSuccess) { LOG(ERROR) << "failed to get device name for " << device << ": " << ToString(res); @@ -382,7 +383,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, /* static */ bool 
GpuDriver::FuncSetCacheConfig(hipFunction_t function, hipFuncCache_t cache_config) { - hipError_t res = hipFuncSetCacheConfig(function, cache_config); + hipError_t res = tensorflow::wrap::hipFuncSetCacheConfig(function, cache_config); if (res != hipSuccess) { LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function << ", config: " << cache_config << ", result: " << ToString(res); @@ -396,7 +397,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options, GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipSharedMemConfig shared_mem_config; ScopedActivateContext activation{context}; - hipError_t result = hipDeviceGetSharedMemConfig(&shared_mem_config); + hipError_t result = tensorflow::wrap::hipDeviceGetSharedMemConfig(&shared_mem_config); if (result != hipSuccess) { LOG(ERROR) << "failed to get ROCM device shared memory config. " << "Context device ID: " << context->device_ordinal() @@ -411,7 +412,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::Status GpuDriver::ContextSetSharedMemConfig( GpuContext* context, hipSharedMemConfig shared_mem_config) { ScopedActivateContext activation{context}; - hipError_t result = hipDeviceSetSharedMemConfig(shared_mem_config); + hipError_t result = tensorflow::wrap::hipDeviceSetSharedMemConfig(shared_mem_config); if (result != hipSuccess) { LOG(ERROR) << "failed to set ROCM device shared memory config. 
" << "Context device ID: " << context->device_ordinal() @@ -435,7 +436,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z << " bdx: " << block_dim_x << " bdy: " << block_dim_y << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes; - hipError_t res = hipModuleLaunchKernel( + hipError_t res = tensorflow::wrap::hipModuleLaunchKernel( function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem_bytes, stream, kernel_params, extra); if (res != hipSuccess) { @@ -471,7 +472,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { ScopedActivateContext activation{context}; void* hsaco_data = const_cast<char*>(hsaco_contents); - hipError_t res = hipModuleLoadData(module, hsaco_data); + hipError_t res = tensorflow::wrap::hipModuleLoadData(module, hsaco_data); if (res != hipSuccess) { LOG(ERROR) << "failed to load HSACO: " << ToString(res); @@ -491,7 +492,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipDeviceptr_t location, uint8 value, size_t size) { ScopedActivateContext activation{context}; - hipError_t res = hipMemset(location, value, size); + hipError_t res = tensorflow::wrap::hipMemset(location, value, size); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -513,7 +514,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { return false; } hipError_t res = - hipMemset(pointer, static_cast<int>(value), uint32_count * 4); + tensorflow::wrap::hipMemset(pointer, static_cast<int>(value), uint32_count * 4); if (res != hipSuccess) { LOG(ERROR) << "failed to memset memory: " << ToString(res); return false; @@ -527,7 +528,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { size_t uint32_count, GpuStreamHandle stream) { ScopedActivateContext activation{context}; - hipError_t res = hipMemsetAsync(location, value, uint32_count, stream); + hipError_t res = 
tensorflow::wrap::hipMemsetAsync(location, value, uint32_count, stream); if (res != hipSuccess) { LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res); return false; @@ -552,7 +553,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { LOG(ERROR) << "failed to memset memory"; return false; } - hipError_t res = hipMemsetAsync(pointer, value, uint32_count * 4, stream); + hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value, uint32_count * 4, stream); if (res != hipSuccess) { LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res); return false; @@ -565,7 +566,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream, StreamCallback callback, void* data) { - hipError_t res = hipStreamAddCallback(stream, (hipStreamCallback_t)callback, + hipError_t res = tensorflow::wrap::hipStreamAddCallback(stream, (hipStreamCallback_t)callback, data, 0 /* = flags */); if (res != hipSuccess) { LOG(ERROR) << "unable to add host callback: " << ToString(res); @@ -580,7 +581,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { hipFunction_t* function) { ScopedActivateContext activated{context}; CHECK(module != nullptr && kernel_name != nullptr); - hipError_t res = hipModuleGetFunction(function, module, kernel_name); + hipError_t res = tensorflow::wrap::hipModuleGetFunction(function, module, kernel_name); if (res != hipSuccess) { LOG(ERROR) << "failed to get kernel \"" << kernel_name << "\" from module: " << ToString(res); @@ -598,7 +599,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { ScopedActivateContext activated{context}; CHECK(module != nullptr && symbol_name != nullptr && (dptr != nullptr || bytes != nullptr)); - hipError_t res = hipModuleGetGlobal(dptr, bytes, module, symbol_name); + hipError_t res = tensorflow::wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name); if (res != hipSuccess) { // symbol may not be found in the current module, but it may 
reside in // another module. @@ -613,7 +614,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ void GpuDriver::UnloadModule(GpuContext* context, hipModule_t module) { ScopedActivateContext activated{context}; - hipError_t res = hipModuleUnload(module); + hipError_t res = tensorflow::wrap::hipModuleUnload(module); if (res != hipSuccess) { LOG(ERROR) << "failed to unload module " << module << "; leaking: " << ToString(res); @@ -623,7 +624,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ bool GpuDriver::CreateStream(GpuContext* context, GpuStreamHandle* stream) { ScopedActivateContext activated{context}; - hipError_t res = hipStreamCreateWithFlags( + hipError_t res = tensorflow::wrap::hipStreamCreateWithFlags( stream, hipStreamDefault); // switch to hipStreamNonBlocking? if (res != hipSuccess) { LOG(ERROR) << "could not allocate ROCM stream for device " @@ -643,7 +644,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { } ScopedActivateContext activated{context}; - hipError_t res = hipStreamDestroy(*stream); + hipError_t res = tensorflow::wrap::hipStreamDestroy(*stream); if (res != hipSuccess) { LOG(ERROR) << "failed to destroy ROCM stream for device " << context->device_ordinal() << ": " << ToString(res); @@ -658,7 +659,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { uint64 bytes) { ScopedActivateContext activated{context}; hipDeviceptr_t result = 0; - hipError_t res = hipMalloc(&result, bytes); + hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes); if (res != hipSuccess) { LOG(ERROR) << "failed to allocate " << port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes @@ -675,7 +676,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { void* location) { ScopedActivateContext activation{context}; hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location); - hipError_t res = hipFree(pointer); + hipError_t res = tensorflow::wrap::hipFree(pointer); 
if (res != hipSuccess) { LOG(ERROR) << "failed to free device memory at " << location << "; result: " << ToString(res); @@ -704,7 +705,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { ScopedActivateContext activation{context}; void* host_mem = nullptr; // "Portable" memory is visible to all ROCM contexts. Safe for our use model. - hipError_t res = hipHostMalloc(&host_mem, bytes, hipHostMallocPortable); + hipError_t res = tensorflow::wrap::hipHostMallocVanilla(&host_mem, bytes, hipHostMallocPortable); if (res != hipSuccess) { LOG(ERROR) << "failed to alloc " << bytes << " bytes on host: " << ToString(res); @@ -715,7 +716,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ void GpuDriver::HostDeallocate(GpuContext* context, void* location) { ScopedActivateContext activation{context}; - hipError_t res = hipHostFree(location); + hipError_t res = tensorflow::wrap::hipHostFree(location); if (res != hipSuccess) { LOG(ERROR) << "error deallocating host memory at " << location << ": " << ToString(res); @@ -726,7 +727,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { uint64 bytes) { ScopedActivateContext activation{context}; // "Portable" memory is visible to all ROCM contexts. Safe for our use model. 
- hipError_t res = hipHostRegister(location, bytes, hipHostRegisterPortable); + hipError_t res = tensorflow::wrap::hipHostRegister(location, bytes, hipHostRegisterPortable); if (res != hipSuccess) { LOG(ERROR) << "error registering host memory at " << location << ": " << ToString(res); @@ -738,7 +739,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ bool GpuDriver::HostUnregister(GpuContext* context, void* location) { ScopedActivateContext activation{context}; - hipError_t res = hipHostUnregister(location); + hipError_t res = tensorflow::wrap::hipHostUnregister(location); if (res != hipSuccess) { LOG(ERROR) << "error unregistering host memory at " << location << ": " << ToString(res); @@ -755,7 +756,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { } ScopedActivateContext activated{context}; - hipError_t res = hipEventDestroy(*event); + hipError_t res = tensorflow::wrap::hipEventDestroy(*event); *event = nullptr; switch (res) { @@ -779,7 +780,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuEventHandle event, GpuStreamHandle stream) { ScopedActivateContext activated{context}; - hipError_t res = hipEventRecord(event, stream); + hipError_t res = tensorflow::wrap::hipEventRecord(event, stream); switch (res) { case hipSuccess: return port::Status::OK(); @@ -800,7 +801,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent( GpuContext* context, GpuEventHandle event) { ScopedActivateContext activated{context}; - hipError_t res = hipEventQuery(event); + hipError_t res = tensorflow::wrap::hipEventQuery(event); if (res != hipSuccess && res != hipErrorNotReady) { return port::Status{ port::error::INTERNAL, @@ -817,12 +818,12 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { ScopedActivateContext activated{context}; // The stop event must have completed in order for hipEventElapsedTime to // work. 
- hipError_t res = hipEventSynchronize(stop); + hipError_t res = tensorflow::wrap::hipEventSynchronize(stop); if (res != hipSuccess) { LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res); return false; } - res = hipEventElapsedTime(elapsed_milliseconds, start, stop); + res = tensorflow::wrap::hipEventElapsedTime(elapsed_milliseconds, start, stop); if (res != hipSuccess) { LOG(ERROR) << "failed to get elapsed time between events: " << ToString(res); @@ -836,7 +837,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream, GpuEventHandle event) { ScopedActivateContext activation{context}; - hipError_t res = hipStreamWaitEvent(stream, event, 0 /* = flags */); + hipError_t res = tensorflow::wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */); if (res != hipSuccess) { LOG(ERROR) << "could not wait stream on event: " << ToString(res); return false; @@ -847,7 +848,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) { ScopedActivateContext activation{context}; - hipError_t res = hipDeviceSynchronize(); + hipError_t res = tensorflow::wrap::hipDeviceSynchronize(); if (res != hipSuccess) { LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res) << " :: " << port::CurrentStackTrace(); @@ -861,7 +862,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); - hipError_t res = hipStreamSynchronize(stream); + hipError_t res = tensorflow::wrap::hipStreamSynchronize(stream); if (res != hipSuccess) { port::Status status = port::InternalError( absl::StrCat("could not synchronize on ROCM stream: ", ToString(res))); @@ -877,7 +878,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream) { ScopedActivateContext activated{context}; CHECK(stream != nullptr); - hipError_t res = hipStreamQuery(stream); + 
hipError_t res = tensorflow::wrap::hipStreamQuery(stream); if (res == hipSuccess) { return true; } @@ -891,7 +892,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::Status GpuDriver::SynchronousMemcpyD2H( GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) { ScopedActivateContext activation{context}; - hipError_t res = hipMemcpyDtoH(host_dst, gpu_src, size); + hipError_t res = tensorflow::wrap::hipMemcpyDtoH(host_dst, gpu_src, size); if (res != hipSuccess) { return port::InternalError( absl::StrFormat("failed to synchronous memcpy from device to host: %s; " @@ -908,7 +909,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src, uint64 size) { ScopedActivateContext activation{context}; - hipError_t res = hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size); + hipError_t res = tensorflow::wrap::hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size); if (res != hipSuccess) { return port::InternalError(absl::StrFormat( "failed to synchronous memcpy from host to device: %s; Gpu dst: %p;" @@ -924,7 +925,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src, uint64 size) { ScopedActivateContext activation{context}; - hipError_t res = hipMemcpyDtoD(gpu_dst, gpu_src, size); + hipError_t res = tensorflow::wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size); if (res != hipSuccess) { return port::InternalError(absl::StrFormat( "failed to synchronous memcpy from host to device: %s; Gpu dst: %p; " @@ -942,7 +943,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { uint64 size, GpuStreamHandle stream) { ScopedActivateContext activation{context}; - hipError_t res = hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream); + hipError_t res = tensorflow::wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream); if (res != hipSuccess) { LOG(ERROR) << absl::StrFormat( 
"failed to enqueue async memcpy from device to host: %s; host dst: %p; " @@ -964,7 +965,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { GpuStreamHandle stream) { ScopedActivateContext activation{context}; hipError_t res = - hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream); + tensorflow::wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream); if (res != hipSuccess) { LOG(ERROR) << absl::StrFormat( "failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; " @@ -984,7 +985,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { uint64 size, GpuStreamHandle stream) { ScopedActivateContext activation{context}; - hipError_t result = hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream); + hipError_t result = tensorflow::wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream); if (result != hipSuccess) { LOG(ERROR) << absl::StrFormat( "failed to enqueue async memcpy from device to device: %s" @@ -1021,7 +1022,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { } ScopedActivateContext activated{context}; - hipError_t res = hipEventCreateWithFlags(event, hipflags); + hipError_t res = tensorflow::wrap::hipEventCreateWithFlags(event, hipflags); if (res == hipSuccess) { return port::Status::OK(); @@ -1037,7 +1038,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ int GpuDriver::GetDeviceCount() { int device_count = 0; - hipError_t res = hipGetDeviceCount(&device_count); + hipError_t res = tensorflow::wrap::hipGetDeviceCount(&device_count); if (res != hipSuccess) { LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res); return 0; @@ -1061,7 +1062,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::Status GpuDriver::GetPointerAddressRange( hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) { - hipError_t result = hipMemGetAddressRange(base, size, dptr); + hipError_t result = 
tensorflow::wrap::hipMemGetAddressRange(base, size, dptr); if (result == hipSuccess) { return port::Status::OK(); } else if (result == hipErrorNotFound) { @@ -1106,7 +1107,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice( hipDeviceptr_t pointer) { hipPointerAttribute_t pointerAttributes; - hipError_t result = hipPointerGetAttributes(&pointerAttributes, pointer); + hipError_t result = tensorflow::wrap::hipPointerGetAttributes(&pointerAttributes, pointer); if (result != hipSuccess) { return port::Status{ port::error::INTERNAL, @@ -1114,7 +1115,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { } hipDevice_t device; - result = hipDeviceGet(&device, pointerAttributes.device); + result = tensorflow::wrap::hipDeviceGet(&device, pointerAttributes.device); if (result != hipSuccess) { return port::Status{ port::error::INTERNAL, @@ -1127,7 +1128,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) { /* static */ port::Status GpuDriver::GetGpuISAVersion(int* version, hipDevice_t device) { hipDeviceProp_t props; - hipError_t result = hipGetDeviceProperties(&props, device); + hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device); if (result == hipSuccess) { *version = props.gcnArch; return port::Status::OK(); @@ -1145,7 +1146,7 @@ template <typename T> static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, hipDeviceAttribute_t attribute) { int value = -1; - hipError_t result = hipDeviceGetAttribute(&value, attribute, device); + hipError_t result = tensorflow::wrap::hipDeviceGetAttribute(&value, attribute, device); if (result != hipSuccess) { return port::Status{ port::error::NOT_FOUND, @@ -1200,21 +1201,21 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, hipDevice_t device) { int value; hipError_t res = - hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device); + 
tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query max grid dim x: " << ToString(res); return false; } *x = value; - res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device); + res = tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query max grid dim y: " << ToString(res); return false; } *y = value; - res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device); + res = tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query max grid dim z: " << ToString(res); return false; @@ -1224,7 +1225,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, } /* static */ bool GpuDriver::GetDriverVersion(int* driver_version) { - hipError_t res = hipDriverGetVersion(driver_version); + hipError_t res = tensorflow::wrap::hipDriverGetVersion(driver_version); if (res != hipSuccess) { LOG(ERROR) << "failed to query driver version: " << ToString(res); return false; @@ -1235,7 +1236,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, /* static */ bool GpuDriver::GetDeviceProperties( hipDeviceProp_t* device_properties, int device_ordinal) { - hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal); + hipError_t res = tensorflow::wrap::hipGetDeviceProperties(device_properties, device_ordinal); if (res != hipSuccess) { LOG(ERROR) << "failed to query device properties: " << ToString(res); return false; @@ -1268,7 +1269,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, ScopedActivateContext activation{context}; size_t free = 0; size_t total = 0; - hipError_t res = hipMemGetInfo(&free, &total); + hipError_t res = tensorflow::wrap::hipMemGetInfo(&free, &total); if (res != hipSuccess) { LOG(ERROR) << "failed to query device 
memory info: " << ToString(res); return false; @@ -1282,7 +1283,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, /* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device, uint64* result) { size_t value = -1; - hipError_t res = hipDeviceTotalMem(&value, device); + hipError_t res = tensorflow::wrap::hipDeviceTotalMem(&value, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query total available memory: " << ToString(res); return false; @@ -1297,7 +1298,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, static const int kBufferSize = 64; absl::InlinedVector<char, 4> chars(kBufferSize); chars[kBufferSize - 1] = '\0'; - hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); + hipError_t res = tensorflow::wrap::hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); if (res != hipSuccess) { LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res); return pci_bus_id; @@ -1313,7 +1314,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, } int can_access_peer = -1; - hipError_t res = hipDeviceCanAccessPeer( + hipError_t res = tensorflow::wrap::hipDeviceCanAccessPeer( &can_access_peer, from->device_ordinal(), to->device_ordinal()); if (res != hipSuccess) { LOG(ERROR) << "failed to detect peer access capability: " << ToString(res); @@ -1331,7 +1332,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device, ScopedActivateContext activated{from}; hipError_t result = - hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */); + tensorflow::wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */); if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) { return port::Status{ port::error::INTERNAL, diff --git a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h new file mode 100644 index 00000000000..0a0ab3ae745 --- /dev/null +++ 
b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h @@ -0,0 +1,147 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file wraps rocm driver calls with dso loader so that we don't need to +// have explicit linking to librocm. All TF rocm driver usage should route +// through this wrapper. + +#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ + +#include "tensorflow/stream_executor/lib/env.h" +#include "tensorflow/stream_executor/platform/dso_loader.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "rocm/include/hip/hip_runtime.h" + + +#if defined(TENSORFLOW_USE_ROCM) + +#endif + +namespace tensorflow { +namespace wrap { +#ifdef PLATFORM_GOOGLE +// Use static linked library +#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ + template <typename... Args> \ + auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \ + return ::hipSymbolName(args...); \ + } + +// This macro wraps a global identifier, given by hipSymbolName, in a callable +// structure that loads the DLL symbol out of the DSO handle in a thread-safe +// manner on first use. This dynamic loading technique is used to avoid DSO +// dependencies on vendor libraries which may or may not be available in the +// deployed binary environment. 
+#else +#define TO_STR_(x) #x +#define TO_STR(x) TO_STR_(x) + +// hipMalloc and hipHostMalloc are defined as funtion templates in the +// HIP header files, and hence their names get mangled and the attempt +// to resolve their name when trying to dynamically load them will fail +// Updating the HIP header files to make them C functions is underway. +// Until that change flows through, we will workaround the issue by +// creating dummy wrappers for them here + +hipError_t hipMallocVanilla(void** ptr, size_t size) { + return hipErrorNotInitialized; +} + +hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) { + return hipErrorNotInitialized; +} + +#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ + template <typename... Args> \ + auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \ + using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char *kName = TO_STR(hipSymbolName); \ + void *f; \ + auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \ + stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ + .ValueOrDie(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in HIP DSO; dlerror: " << s.error_message(); \ + return reinterpret_cast<FuncPtrT>(f); \ + }(); \ + return loaded(args...); \ + } +#endif + +// clang-format off +#define HIP_ROUTINE_EACH(__macro) \ + __macro(hipDeviceCanAccessPeer) \ + __macro(hipDeviceEnablePeerAccess) \ + __macro(hipDeviceGet) \ + __macro(hipDeviceGetAttribute) \ + __macro(hipDeviceGetName) \ + __macro(hipDeviceGetPCIBusId) \ + __macro(hipDeviceGetSharedMemConfig) \ + __macro(hipDeviceSetSharedMemConfig) \ + __macro(hipDeviceSynchronize) \ + __macro(hipDeviceTotalMem) \ + __macro(hipDriverGetVersion) \ + __macro(hipEventCreateWithFlags) \ + __macro(hipEventElapsedTime) \ + __macro(hipEventDestroy) \ + __macro(hipEventQuery) \ + __macro(hipEventRecord) \ + 
__macro(hipEventSynchronize) \
+  __macro(hipFree) \
+  __macro(hipFuncSetCacheConfig) \
+  __macro(hipGetDevice) \
+  __macro(hipGetDeviceCount) \
+  __macro(hipGetDeviceProperties) \
+  __macro(hipHostFree) \
+  __macro(hipHostRegister) \
+  __macro(hipHostUnregister) \
+  __macro(hipInit) \
+  __macro(hipMemGetAddressRange) \
+  __macro(hipMemGetInfo) \
+  __macro(hipMemcpyDtoD) \
+  __macro(hipMemcpyDtoDAsync) \
+  __macro(hipMemcpyDtoH) \
+  __macro(hipMemcpyDtoHAsync) \
+  __macro(hipMemcpyHtoD) \
+  __macro(hipMemcpyHtoDAsync) \
+  __macro(hipMemset) \
+  __macro(hipMemsetAsync) \
+  __macro(hipModuleGetFunction) \
+  __macro(hipModuleGetGlobal) \
+  __macro(hipModuleLaunchKernel) \
+  __macro(hipModuleLoadData) \
+  __macro(hipModuleUnload) \
+  __macro(hipPointerGetAttributes) \
+  __macro(hipSetDevice) \
+  __macro(hipStreamAddCallback) \
+  __macro(hipStreamCreateWithFlags) \
+  __macro(hipStreamDestroy) \
+  __macro(hipStreamQuery) \
+  __macro(hipStreamSynchronize) \
+  __macro(hipStreamWaitEvent) \
+// clang-format on
+
+HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
+#undef HIP_ROUTINE_EACH
+#undef STREAM_EXECUTOR_HIP_WRAP
+#undef TO_STR
+#undef TO_STR_
+} // namespace wrap
+} // namespace tensorflow
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_

From 7066b1c04dab63bd44c1450fc93032e58727e3e6 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Fri, 1 Feb 2019 18:32:40 +0000
Subject: [PATCH 7/7] adding a check in the initialization routines for
 rocblas, rocfft, rocrand to avoid duplicate registrations

---
 tensorflow/stream_executor/rocm/rocm_blas.cc | 64 +++++++++++---------
 tensorflow/stream_executor/rocm/rocm_fft.cc  | 47 +++++++-------
 tensorflow/stream_executor/rocm/rocm_rng.cc  | 61 ++++++++++---------
 3 files changed, 95 insertions(+), 77 deletions(-)

diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index a626d168c26..2137bc00275 100644
---
a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -2324,35 +2324,43 @@ bool ROCMBlas::DoBlasGemmStridedBatched( } // namespace gpu void initialize_rocblas() { - port::Status status = - PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>( - rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS", - [](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* { - gpu::GpuExecutor* rocm_executor = - dynamic_cast<gpu::GpuExecutor*>(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the rocBLAS " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor); - if (!blas->Init()) { - // Note: Init() will log a more specific error. - delete blas; - return nullptr; - } - return blas; - }); - - if (!status.ok()) { - LOG(ERROR) << "Unable to register rocBLAS factory: " - << status.error_message(); - } - - PluginRegistry::Instance()->SetDefaultFactory( + auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory( rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin); + + if (!rocBlasAlreadyRegistered) { + port::Status status = + PluginRegistry::Instance() + ->RegisterFactory<PluginRegistry::BlasFactory>( + rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS", + [](internal::StreamExecutorInterface* parent) + -> blas::BlasSupport* { + gpu::GpuExecutor* rocm_executor = + dynamic_cast<gpu::GpuExecutor*>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the " + "rocBLAS " + << "support library with a non-ROCM StreamExecutor"; + return nullptr; + } + + gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor); + if (!blas->Init()) { + // Note: Init() will log a more specific error. 
+ delete blas; + return nullptr; + } + return blas; + }); + + if (!status.ok()) { + LOG(ERROR) << "Unable to register rocBLAS factory: " + << status.error_message(); + } + + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin); + } } } // namespace stream_executor diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc index e8a72f61d3f..b23e05d9dde 100644 --- a/tensorflow/stream_executor/rocm/rocm_fft.cc +++ b/tensorflow/stream_executor/rocm/rocm_fft.cc @@ -592,28 +592,33 @@ STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D) } // namespace gpu void initialize_rocfft() { - port::Status status = - PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>( - rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT", - [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * { - gpu::GpuExecutor *rocm_executor = - dynamic_cast<gpu::GpuExecutor *>(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the rocFFT " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - return new gpu::ROCMFft(rocm_executor); - }); - if (!status.ok()) { - LOG(ERROR) << "Unable to register rocFFT factory: " - << status.error_message(); - } - - PluginRegistry::Instance()->SetDefaultFactory( + auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory( rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin); + + if (!rocFftAlreadyRegistered) { + port::Status status = + PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>( + rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT", + [](internal::StreamExecutorInterface* parent) -> fft::FftSupport* { + gpu::GpuExecutor* rocm_executor = + dynamic_cast<gpu::GpuExecutor*>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the rocFFT " + << "support library with a 
non-ROCM StreamExecutor"; + return nullptr; + } + + return new gpu::ROCMFft(rocm_executor); + }); + if (!status.ok()) { + LOG(ERROR) << "Unable to register rocFFT factory: " + << status.error_message(); + } + + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin); + } } } // namespace stream_executor diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc index 27797845700..545bfc6d943 100644 --- a/tensorflow/stream_executor/rocm/rocm_rng.cc +++ b/tensorflow/stream_executor/rocm/rocm_rng.cc @@ -283,35 +283,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) { } // namespace gpu void initialize_rocrand() { - port::Status status = - PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>( - rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND", - [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* { - gpu::GpuExecutor* rocm_executor = - dynamic_cast<gpu::GpuExecutor*>(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the hipRAND " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor); - if (!rng->Init()) { - // Note: Init() will log a more specific error. 
- delete rng; - return nullptr; - } - return rng; - }); - - if (!status.ok()) { - LOG(ERROR) << "Unable to register rocRAND factory: " - << status.error_message(); - } - - PluginRegistry::Instance()->SetDefaultFactory( + auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory( rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin); + + if (!rocRandAlreadyRegistered) { + port::Status status = + PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>( + rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND", + [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* { + gpu::GpuExecutor* rocm_executor = + dynamic_cast<gpu::GpuExecutor*>(parent); + if (rocm_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the hipRAND " + << "support library with a non-ROCM StreamExecutor"; + return nullptr; + } + + gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor); + if (!rng->Init()) { + // Note: Init() will log a more specific error. + delete rng; + return nullptr; + } + return rng; + }); + + if (!status.ok()) { + LOG(ERROR) << "Unable to register rocRAND factory: " + << status.error_message(); + } + + PluginRegistry::Instance()->SetDefaultFactory( + rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin); + } } } // namespace stream_executor