From aa26dce3be923a3833f52384ec194c2cfac76d43 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Thu, 31 Jan 2019 15:16:57 +0000
Subject: [PATCH 1/7] Properly fix the earlier compile error for
 --config=rocm, plus some minor changes

---
 tensorflow/stream_executor/rocm/BUILD         |  2 +-
 .../stream_executor/rocm/rocm_diagnostics.cc  | 24 +++++++----
 .../stream_executor/rocm/rocm_diagnostics.h   | 41 +++++++++++++++++++
 .../stream_executor/rocm/rocm_gpu_executor.cc | 19 ++++-----
 .../stream_executor/rocm/rocm_platform.cc     |  2 +-
 .../stream_executor/rocm/rocm_platform_id.cc  |  2 +-
 .../stream_executor/rocm/rocm_platform_id.h   |  6 +--
 tensorflow/stream_executor/rocm/rocm_rng.cc   |  4 +-
 8 files changed, 72 insertions(+), 28 deletions(-)
 create mode 100644 tensorflow/stream_executor/rocm/rocm_diagnostics.h
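
Note: this patch moves the ROCm platform id and the driver-version string
helpers out of namespace gpu into a new namespace rocm, and adds
rocm_diagnostics.h so the shared gpu types remain reachable under the
rocm-qualified names. As a minimal sketch of what a call site looks like
after this patch (the helper function below is hypothetical, for
illustration only):

    #include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"

    namespace stream_executor {

    // Hypothetical helper, not part of this patch: format the librocm DSO
    // version for logging via the relocated rocm:: helpers.
    string LoggableDsoVersion() {
      // Diagnostician itself still lives in namespace gpu;
      // rocm_diagnostics.h aliases it as rocm::Diagnostician.
      port::StatusOr<rocm::DriverVersion> v =
          rocm::Diagnostician::FindDsoVersion();
      return rocm::DriverVersionStatusToString(v);
    }

    }  // namespace stream_executor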

diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index dd08b36308c..38cd19b3cb8 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -33,7 +33,7 @@ filegroup(
 cc_library(
     name = "rocm_diagnostics",
     srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
-    hdrs = [],
+    hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
     deps = if_rocm_is_configured([
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/stream_executor/rocm/rocm_diagnostics.cc b/tensorflow/stream_executor/rocm/rocm_diagnostics.cc
index c6da7f9e3da..01492a8ef97 100644
--- a/tensorflow/stream_executor/rocm/rocm_diagnostics.cc
+++ b/tensorflow/stream_executor/rocm/rocm_diagnostics.cc
@@ -30,7 +30,7 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
+#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
 #include "tensorflow/stream_executor/lib/error.h"
 #include "tensorflow/stream_executor/lib/numbers.h"
 #include "tensorflow/stream_executor/lib/process_state.h"
@@ -40,7 +40,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 
 namespace stream_executor {
-namespace gpu {
+namespace rocm {
 
 string DriverVersionToString(DriverVersion version) {
   return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
@@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
   return result;
 }
 
+}  // namespace rocm
+}  // namespace stream_executor
+
+namespace stream_executor {
+namespace gpu {
+
 // -- class Diagnostician
 
 string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
@@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
   }
   port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
   LOG(INFO) << "librocm reported version is: "
-            << DriverVersionStatusToString(dso_version);
+            << rocm::DriverVersionStatusToString(dso_version);
 
   port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
   LOG(INFO) << "kernel reported version is: "
-            << DriverVersionStatusToString(kernel_version);
+            << rocm::DriverVersionStatusToString(kernel_version);
 
   if (kernel_version.ok() && dso_version.ok()) {
     WarnOnDsoKernelMismatch(dso_version, kernel_version);
@@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
       // TODO(b/22689637): Eliminate the explicit namespace if possible.
       auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
       auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
-      *result = StringToDriverVersion(stripped_dso_version);
+      *result = rocm::StringToDriverVersion(stripped_dso_version);
       return 1;
     }
     return 0;
@@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
   // TODO(b/22689637): Eliminate the explicit namespace if possible.
   auto stripped_kernel_version =
       port::StripSuffixString(kernel_version, ".ld64");
-  return StringToDriverVersion(stripped_kernel_version);
+  return rocm::StringToDriverVersion(stripped_kernel_version);
 }
 
 void Diagnostician::WarnOnDsoKernelMismatch(
@@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
   if (kernel_version.ok() && dso_version.ok() &&
       dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
     LOG(INFO) << "kernel version seems to match DSO: "
-              << DriverVersionToString(kernel_version.ValueOrDie());
+              << rocm::DriverVersionToString(kernel_version.ValueOrDie());
   } else {
     LOG(ERROR) << "kernel version "
-               << DriverVersionStatusToString(kernel_version)
+               << rocm::DriverVersionStatusToString(kernel_version)
                << " does not match DSO version "
-               << DriverVersionStatusToString(dso_version)
+               << rocm::DriverVersionStatusToString(dso_version)
                << " -- cannot find working devices in this configuration";
   }
 }
diff --git a/tensorflow/stream_executor/rocm/rocm_diagnostics.h b/tensorflow/stream_executor/rocm/rocm_diagnostics.h
new file mode 100644
index 00000000000..233c6bdade6
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_diagnostics.h
@@ -0,0 +1,41 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
+
+#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
+
+namespace stream_executor {
+namespace rocm {
+
+// e.g. DriverVersion{346, 3, 4}
+using DriverVersion = gpu::DriverVersion;
+
+// Converts a parsed driver version to string form.
+string DriverVersionToString(DriverVersion version);
+
+// Converts a parsed driver version or status value to natural string form.
+string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
+
+// Converts a string of the form "331.79" to a DriverVersion{331, 79}.
+port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
+
+using Diagnostician = gpu::Diagnostician;
+
+}  // namespace rocm
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
index 0e38556ee7d..684172cfb5a 100644
--- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
+++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/base/casts.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
 #include "tensorflow/stream_executor/gpu/gpu_event.h"
 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
@@ -41,6 +40,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
@@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
 blas::BlasSupport* GpuExecutor::CreateBlas() {
   PluginRegistry* registry = PluginRegistry::Instance();
   port::StatusOr<PluginRegistry::BlasFactory> status =
-      registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
+      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
                                                         plugin_config_.blas());
   if (!status.ok()) {
     LOG(ERROR) << "Unable to retrieve BLAS factory: "
@@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
 dnn::DnnSupport* GpuExecutor::CreateDnn() {
   PluginRegistry* registry = PluginRegistry::Instance();
   port::StatusOr<PluginRegistry::DnnFactory> status =
-      registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
+      registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
                                                        plugin_config_.dnn());
   if (!status.ok()) {
     LOG(ERROR) << "Unable to retrieve DNN factory: "
@@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
 fft::FftSupport* GpuExecutor::CreateFft() {
   PluginRegistry* registry = PluginRegistry::Instance();
   port::StatusOr<PluginRegistry::FftFactory> status =
-      registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
+      registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
                                                        plugin_config_.fft());
   if (!status.ok()) {
     LOG(ERROR) << "Unable to retrieve FFT factory: "
@@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
 rng::RngSupport* GpuExecutor::CreateRng() {
   PluginRegistry* registry = PluginRegistry::Instance();
   port::StatusOr<PluginRegistry::RngFactory> status =
-      registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
+      registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
                                                        plugin_config_.rng());
   if (!status.ok()) {
     LOG(ERROR) << "Unable to retrieve RNG factory: "
@@ -878,12 +878,9 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
   {
     int driver_version = 0;
     (void)GpuDriver::GetDriverVersion(&driver_version);
-    string augmented_driver_version =
-        absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
-    // FIXME:
-    // uncomment the line below once the "DriverVersionStatusToString"
-    // routine is moved from the "cuda" namespace to the "gpu" naemspace
-    // DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
+    string augmented_driver_version = absl::StrFormat(
+        "%d (%s)", driver_version,
+        rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
     builder.set_driver_version(augmented_driver_version);
   }
 
diff --git a/tensorflow/stream_executor/rocm/rocm_platform.cc b/tensorflow/stream_executor/rocm/rocm_platform.cc
index 113371dd553..ce091658da4 100644
--- a/tensorflow/stream_executor/rocm/rocm_platform.cc
+++ b/tensorflow/stream_executor/rocm/rocm_platform.cc
@@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
       absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
 }
 
-Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
+Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
 
 int ROCmPlatform::VisibleDeviceCount() const {
   // Throw away the result - it logs internally, and this [containing] function
diff --git a/tensorflow/stream_executor/rocm/rocm_platform_id.cc b/tensorflow/stream_executor/rocm/rocm_platform_id.cc
index daa42ab022a..16f48bf12d2 100644
--- a/tensorflow/stream_executor/rocm/rocm_platform_id.cc
+++ b/tensorflow/stream_executor/rocm/rocm_platform_id.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
 
 namespace stream_executor {
-namespace gpu {
+namespace rocm {
 
 PLATFORM_DEFINE_ID(kROCmPlatformId);
 
diff --git a/tensorflow/stream_executor/rocm/rocm_platform_id.h b/tensorflow/stream_executor/rocm/rocm_platform_id.h
index 71c760b8277..a17d4f97bbc 100644
--- a/tensorflow/stream_executor/rocm/rocm_platform_id.h
+++ b/tensorflow/stream_executor/rocm/rocm_platform_id.h
@@ -19,16 +19,16 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform.h"
 
 namespace stream_executor {
-namespace gpu {
+namespace rocm {
 
 // Opaque and unique identifier for the ROCm platform.
 // This is needed so that plugins can refer to/identify this platform without
 // instantiating a ROCmPlatform object.
 // This is broken out here to avoid a circular dependency between ROCmPlatform
-// and GpuExecutor.
+// and ROCmExecutor.
 extern const Platform::Id kROCmPlatformId;
 
-}  // namespace gpu
+}  // namespace rocm
 }  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_
diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc
index 2048c8ff644..65acd03c92b 100644
--- a/tensorflow/stream_executor/rocm/rocm_rng.cc
+++ b/tensorflow/stream_executor/rocm/rocm_rng.cc
@@ -253,7 +253,7 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
   se::port::Status status =
       se::PluginRegistry::Instance()
           ->RegisterFactory<se::PluginRegistry::RngFactory>(
-              se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
+              se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
               [](se::internal::StreamExecutorInterface* parent)
                   -> se::rng::RngSupport* {
                 se::gpu::GpuExecutor* rocm_executor =
@@ -280,5 +280,5 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
   }
 
   se::PluginRegistry::Instance()->SetDefaultFactory(
-      se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
+      se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
 });

From 298383d00b5d385d0f58e5ca5bd860fc4508d37a Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Thu, 31 Jan 2019 19:23:21 +0000
Subject: [PATCH 2/7] Add code for the rocblas plugin

---
 tensorflow/stream_executor/rocm/BUILD        |   64 +-
 tensorflow/stream_executor/rocm/rocm_blas.cc | 2324 ++++++++++++++++++
 tensorflow/stream_executor/rocm/rocm_blas.h  |  159 ++
 3 files changed, 2518 insertions(+), 29 deletions(-)
 create mode 100644 tensorflow/stream_executor/rocm/rocm_blas.cc
 create mode 100644 tensorflow/stream_executor/rocm/rocm_blas.h
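
Note: rocm_blas.cc reaches every rocBLAS entry point through the
STREAM_EXECUTOR_ROCBLAS_WRAP shim, which makes the owning GpuExecutor's
context current before forwarding the arguments to the library. As a rough,
hand-expanded illustration of what the macro generates for one symbol (the
names below come straight from the macro definition in this patch):

    // Expansion of STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_saxpy), roughly.
    struct WrapperShim__rocblas_saxpy {
      static const char* kName;
      template <typename... Args>
      rocblas_status operator()(GpuExecutor* parent, Args... args) {
        // Activate `parent`'s context for the duration of the library call.
        gpu::ScopedActivateExecutorContext sac{parent};
        return ::rocblas_saxpy(args...);
      }
    } rocblas_saxpy;
    const char* WrapperShim__rocblas_saxpy::kName = "rocblas_saxpy";

DoBlasInternalImpl then invokes a shim as rocblas_func(parent_, blas_,
args...), so each DoBlas* routine reduces to a one-line dispatch plus
rocblas_status error translation.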

diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index 38cd19b3cb8..737a4429469 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -140,34 +140,40 @@ cc_library(
     deps = ["//tensorflow/stream_executor:platform"],
 )
 
-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocblas_plugin",
-#    srcs = ["rocm_blas.cc"],
-#    hdrs = ["rocm_blas.h"],
-#    visibility = ["//visibility:public"],
-#    deps = [
-#        ":rocm_gpu_executor",
-#        ":rocm_platform_id",
-#        "//third_party/eigen3",
-#        "//tensorflow/core:lib_internal",
-#        "//tensorflow/stream_executor",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:host_or_device_scalar",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:scratch_allocator",
-#        "//tensorflow/stream_executor:timer",
-#        "//tenosrflow/stream_executor/gpu:gpu_activation_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_timer_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#        "@com_google_absl//absl/strings",
-#        "@local_config_rocm//rocm:rocm_headers",
-#    ] + if_static(["@local_config_rocm//rocm:rocblas"]),
-#    alwayslink = True,
-#)
+cc_library(
+    name = "rocblas_plugin",
+    srcs = if_rocm_is_configured(["rocm_blas.cc"]),
+    hdrs = if_rocm_is_configured(["rocm_blas.h"]),
+    visibility = ["//visibility:public"],
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "//third_party/eigen3",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:host_or_device_scalar",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:scratch_allocator",
+        "//tensorflow/stream_executor:timer",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/gpu:gpu_timer_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "@com_google_absl//absl/strings",
+        "@local_config_rocm//rocm:rocm_headers",
+    ] + if_static(
+        ["@local_config_rocm//rocm:rocblas"],
+        # Delete the list below once the rocblas library switches from being
+        # dynamically linked (current behaviour) to being dynamically loaded
+        # (future behaviour).
+        ["@local_config_rocm//rocm:rocblas"],
+    )),
+    alwayslink = True,
+)
 
 # FIXME: enable in future PRs
 #cc_library(
@@ -258,7 +264,7 @@ cc_library(
         # FIXME: enable in future PRs
         #":miopen_plugin",
         #":rocfft_plugin",
-        #":rocblas_plugin",
+        ":rocblas_plugin",
         #":rocrand_plugin",
         ":rocm_driver",
         ":rocm_platform",
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
new file mode 100644
index 00000000000..b2e225433e5
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -0,0 +1,2324 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "rocm/include/rocblas.h"
+
+#include "tensorflow/stream_executor/rocm/rocm_blas.h"
+
+#define EIGEN_USE_GPU
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include <assert.h>
+#include <complex>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/stream_executor/gpu/gpu_activation.h"
+#include "tensorflow/stream_executor/gpu/gpu_executor.h"
+#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
+#include "tensorflow/stream_executor/gpu/gpu_stream.h"
+#include "tensorflow/stream_executor/gpu/gpu_timer.h"
+#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/scratch_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace stream_executor {
+namespace gpu {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
+
+namespace wrap {
+
+#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                       \
+  struct WrapperShim__##__name {                                   \
+    static const char* kName;                                      \
+    template <typename... Args>                                    \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) { \
+      gpu::ScopedActivateExecutorContext sac{parent};              \
+      return ::__name(args...);                                    \
+    }                                                              \
+  } __name;                                                        \
+  const char* WrapperShim__##__name::kName = #__name;
+
+#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
+  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
+
+#define ROCBLAS_BLAS_ROUTINE_EACH(__macro)                                     \
+  __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /*  __macro(rocblas_scnrm2)    \
+                                                  __macro(rocblas_dznrm2) */   \
+      __macro(rocblas_sdot)                                                    \
+          __macro(rocblas_ddot) /*  __macro(rocblas_cdotu)                     \
+                                  __macro(rocblas_cdotc)                       \
+                                  __macro(rocblas_zdotu)                       \
+                                  __macro(rocblas_zdotc)                    */ \
+      __macro(rocblas_sscal)                                                   \
+          __macro(rocblas_dscal) /*  __macro(rocblas_cscal)                    \
+                                   __macro(rocblas_csscal)                     \
+                                   __macro(rocblas_zscal)                      \
+                                   __macro(rocblas_zdscal) */                  \
+      __macro(rocblas_saxpy)                                                   \
+          __macro(rocblas_daxpy) /*  __macro(rocblas_caxpy)                    \
+                                   __macro(rocblas_zaxpy) */                   \
+      __macro(rocblas_scopy)                                                   \
+          __macro(rocblas_dcopy) /*  __macro(rocblas_ccopy)                    \
+                                   __macro(rocblas_zcopy) */                   \
+      __macro(rocblas_sswap)                                                   \
+          __macro(rocblas_dswap) /*  __macro(rocblas_cswap)                    \
+                                   __macro(rocblas_zswap) */                   \
+      __macro(rocblas_isamax)                                                  \
+          __macro(rocblas_idamax) /*  __macro(rocblas_icamax)                  \
+                                    __macro(rocblas_izamax) */                 \
+      __macro(rocblas_isamin)                                                  \
+          __macro(rocblas_idamin) /*  __macro(rocblas_icamin)                  \
+                                    __macro(rocblas_izamin) */                 \
+      __macro(rocblas_sasum)                                                   \
+          __macro(rocblas_dasum) /*  __macro(rocblas_scasum)                   \
+                                   __macro(rocblas_dzasum)                     \
+                                   __macro(rocblas_srot)                       \
+                                   __macro(rocblas_drot)                       \
+                                   __macro(rocblas_crot)                       \
+                                   __macro(rocblas_csrot)                      \
+                                   __macro(rocblas_zrot)                       \
+                                   __macro(rocblas_zdrot)                      \
+                                   __macro(rocblas_srotg)                      \
+                                   __macro(rocblas_drotg)                      \
+                                   __macro(rocblas_crotg)                      \
+                                   __macro(rocblas_zrotg)                      \
+                                   __macro(rocblas_srotm)                      \
+                                   __macro(rocblas_drotm)                      \
+                                   __macro(rocblas_srotmg)                     \
+                                   __macro(rocblas_drotmg) */                  \
+      __macro(rocblas_sgemv)                                                   \
+          __macro(rocblas_dgemv) /*  __macro(rocblas_cgemv)                    \
+                                   __macro(rocblas_zgemv)                      \
+                                   __macro(rocblas_sgbmv)                      \
+                                   __macro(rocblas_dgbmv)                      \
+                                   __macro(rocblas_cgbmv)                      \
+                                   __macro(rocblas_zgbmv)                      \
+                                   __macro(rocblas_strmv)                      \
+                                   __macro(rocblas_dtrmv)                      \
+                                   __macro(rocblas_ctrmv)                      \
+                                   __macro(rocblas_ztrmv)                      \
+                                   __macro(rocblas_stbmv)                      \
+                                   __macro(rocblas_dtbmv)                      \
+                                   __macro(rocblas_ctbmv)                      \
+                                   __macro(rocblas_ztbmv)                      \
+                                   __macro(rocblas_stpmv)                      \
+                                   __macro(rocblas_dtpmv)                      \
+                                   __macro(rocblas_ctpmv)                      \
+                                   __macro(rocblas_ztpmv)                      \
+                                   __macro(rocblas_strsv)                      \
+                                   __macro(rocblas_dtrsv)                      \
+                                   __macro(rocblas_ctrsv)                      \
+                                   __macro(rocblas_ztrsv)                      \
+                                   __macro(rocblas_stpsv)                      \
+                                   __macro(rocblas_dtpsv)                      \
+                                   __macro(rocblas_ctpsv)                      \
+                                   __macro(rocblas_ztpsv)                      \
+                                   __macro(rocblas_stbsv)                      \
+                                   __macro(rocblas_dtbsv)                      \
+                                   __macro(rocblas_ctbsv)                      \
+                                   __macro(rocblas_ztbsv)                      \
+                                   __macro(rocblas_ssymv)                      \
+                                   __macro(rocblas_dsymv)                      \
+                                   __macro(rocblas_csymv)                      \
+                                   __macro(rocblas_zsymv)                      \
+                                   __macro(rocblas_chemv)                      \
+                                   __macro(rocblas_zhemv)                      \
+                                   __macro(rocblas_ssbmv)                      \
+                                   __macro(rocblas_dsbmv)                      \
+                                   __macro(rocblas_chbmv)                      \
+                                   __macro(rocblas_zhbmv)                      \
+                                   __macro(rocblas_sspmv)                      \
+                                   __macro(rocblas_dspmv)                      \
+                                   __macro(rocblas_chpmv)                      \
+                                   __macro(rocblas_zhpmv) */                   \
+      __macro(rocblas_sger)                                                    \
+          __macro(rocblas_dger) /*  __macro(rocblas_cgeru)                     \
+                                  __macro(rocblas_cgerc)                       \
+                                  __macro(rocblas_zgeru)                       \
+                                  __macro(rocblas_zgerc)                    */ \
+      __macro(rocblas_ssyr)                                                    \
+          __macro(rocblas_dsyr) /*  __macro(rocblas_csyr)                      \
+                                  __macro(rocblas_zsyr)                        \
+                                  __macro(rocblas_cher)                        \
+                                  __macro(rocblas_zher)                        \
+                                  __macro(rocblas_sspr)                        \
+                                  __macro(rocblas_dspr)                        \
+                                  __macro(rocblas_chpr)                        \
+                                  __macro(rocblas_zhpr)                        \
+                                  __macro(rocblas_ssyr2)                       \
+                                  __macro(rocblas_dsyr2)                       \
+                                  __macro(rocblas_csyr2)                       \
+                                  __macro(rocblas_zsyr2)                       \
+                                  __macro(rocblas_cher2)                       \
+                                  __macro(rocblas_zher2)                       \
+                                  __macro(rocblas_sspr2)                       \
+                                  __macro(rocblas_dspr2)                       \
+                                  __macro(rocblas_chpr2)                       \
+                                  __macro(rocblas_zhpr2)                    */ \
+      __macro(rocblas_sgemm) __macro(rocblas_dgemm)                            \
+          __macro(rocblas_hgemm) /*  __macro(rocblas_cgemm)                    \
+                                   __macro(rocblas_zgemm)                      \
+                                   __macro(rocblas_ssyrk)                      \
+                                   __macro(rocblas_dsyrk)                      \
+                                   __macro(rocblas_csyrk)                      \
+                                   __macro(rocblas_zsyrk)                      \
+                                   __macro(rocblas_cherk)                      \
+                                   __macro(rocblas_zherk)                      \
+                                   __macro(rocblas_ssyr2k)                     \
+                                   __macro(rocblas_dsyr2k)                     \
+                                   __macro(rocblas_csyr2k)                     \
+                                   __macro(rocblas_zsyr2k)                     \
+                                   __macro(rocblas_cher2k)                     \
+                                   __macro(rocblas_zher2k)                     \
+                                   __macro(rocblas_ssyrkx)                     \
+                                   __macro(rocblas_dsyrkx)                     \
+                                   __macro(rocblas_csyrkx)                     \
+                                   __macro(rocblas_zsyrkx)                     \
+                                   __macro(rocblas_cherkx)                     \
+                                   __macro(rocblas_zherkx)                     \
+                                   __macro(rocblas_ssymm)                      \
+                                   __macro(rocblas_dsymm)                      \
+                                   __macro(rocblas_csymm)                      \
+                                   __macro(rocblas_zsymm)                      \
+                                   __macro(rocblas_chemm)                      \
+                                   __macro(rocblas_zhemm) */                   \
+      __macro(rocblas_strsm)                                                   \
+          __macro(rocblas_dtrsm) /*  __macro(rocblas_ctrsm)                    \
+                                   __macro(rocblas_ztrsm)                      \
+                                   __macro(rocblas_strmm)                      \
+                                   __macro(rocblas_dtrmm)                      \
+                                   __macro(rocblas_ctrmm)                      \
+                                   __macro(rocblas_ztrmm) */                   \
+      __macro(rocblas_sgeam)                                                   \
+          __macro(rocblas_dgeam) /*  __macro(rocblas_cgeam)                    \
+                                   __macro(rocblas_zgeam)                      \
+                                   __macro(rocblas_sdgmm)                      \
+                                   __macro(rocblas_ddgmm)                      \
+                                   __macro(rocblas_cdgmm)                      \
+                                   __macro(rocblas_zdgmm) */
+
+STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_create_handle)
+STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_destroy_handle)
+STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_stream)
+// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode)
+// STREAM_EXECUTOR_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode)
+// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_batched)
+STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
+STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
+// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
+STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
+// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
+// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
+ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)
+
+}  // namespace wrap
+
+static string ToString(rocblas_status status) {
+  switch (status) {
+    case rocblas_status_success:
+      return "rocblas_status_success";
+    case rocblas_status_invalid_handle:
+      return "rocblas_status_invalid_handle";
+    case rocblas_status_not_implemented:
+      return "rocblas_status_not_implemented";
+    case rocblas_status_invalid_pointer:
+      return "rocblas_status_invalid_pointer";
+    case rocblas_status_invalid_size:
+      return "rocblas_status_invalid_size";
+    case rocblas_status_memory_error:
+      return "rocblas_status_memory_error";
+    case rocblas_status_internal_error:
+      return "rocblas_status_internal_error";
+    default:
+      return absl::StrCat("<invalid rocBLAS status: ", status, ">");
+  }
+}
+
+bool ROCMBlas::Init() {
+  rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_);
+  if (ret != rocblas_status_success) {
+    LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret);
+    return false;
+  }
+
+  return true;
+}
+
+ROCMBlas::ROCMBlas(gpu::GpuExecutor* parent)
+    : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
+
+ROCMBlas::~ROCMBlas() {
+  if (blas_ != nullptr) {
+    wrap::rocblas_destroy_handle(parent_, blas_);
+  }
+}
+
+bool ROCMBlas::SetStream(Stream *stream) {
+  CHECK(stream != nullptr);
+  CHECK(AsGpuStreamValue(stream) != nullptr);
+  CHECK(blas_ != nullptr);
+  rocblas_status ret =
+      wrap::rocblas_set_stream(parent_, blas_, AsGpuStreamValue(stream));
+  if (ret != rocblas_status_success) {
+    LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret);
+    return false;
+  }
+
+  return true;
+}
+
+namespace {
+
+// Helper functions transforming blas arguments into rocBLAS arguments.
+
+rocblas_operation ROCMBlasTranspose(blas::Transpose trans) {
+  switch (trans) {
+    case blas::Transpose::kNoTranspose:
+      return rocblas_operation_none;
+    case blas::Transpose::kTranspose:
+      return rocblas_operation_transpose;
+    case blas::Transpose::kConjugateTranspose:
+      return rocblas_operation_conjugate_transpose;
+    default:
+      LOG(FATAL) << "Invalid value of blas::Transpose.";
+  }
+}
+
+rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) {
+  switch (uplo) {
+    case blas::UpperLower::kUpper:
+      return rocblas_fill_upper;
+    case blas::UpperLower::kLower:
+      return rocblas_fill_lower;
+    default:
+      LOG(FATAL) << "Invalid value of blas::UpperLower.";
+  }
+}
+
+rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) {
+  switch (diag) {
+    case blas::Diagonal::kUnit:
+      return rocblas_diagonal_unit;
+    case blas::Diagonal::kNonUnit:
+      return rocblas_diagonal_non_unit;
+    default:
+      LOG(FATAL) << "Invalid value of blas::Diagonal.";
+  }
+}
+
+rocblas_side ROCMBlasSide(blas::Side side) {
+  switch (side) {
+    case blas::Side::kLeft:
+      return rocblas_side_left;
+    case blas::Side::kRight:
+      return rocblas_side_right;
+    default:
+      LOG(FATAL) << "Invalid value of blas::Side.";
+  }
+}
+
+}  // namespace
+
+template <typename FuncT, typename... Args>
+bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
+                                  bool pointer_mode_host, bool err_on_failure,
+                                  Args... args) {
+  mutex_lock lock{mu_};
+
+  CHECK(blas_ != nullptr);
+  if (!SetStream(stream)) {
+    return false;
+  }
+
+  rocblas_status ret = rocblas_func(parent_, blas_, args...);
+  if (err_on_failure && ret != rocblas_status_success) {
+    LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": "
+               << ToString(ret);
+  }
+  return ret == rocblas_status_success;
+}
+
+bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<float> &x, int incx,
+                          DeviceMemory<float> *result) {
+  return DoBlasInternal(wrap::rocblas_sasum, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<double> &x, int incx,
+                          DeviceMemory<double> *result) {
+  return DoBlasInternal(wrap::rocblas_dasum, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          DeviceMemory<float> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          DeviceMemory<double> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the ASUM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
+                          const DeviceMemory<float> &x, int incx,
+                          DeviceMemory<float> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_saxpy, stream,
+                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        GpuMemory(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
+                          const DeviceMemory<double> &x, int incx,
+                          DeviceMemory<double> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_daxpy, stream,
+                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        GpuMemory(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the AXPY operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<float> &x, int incx,
+                          DeviceMemory<float> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_scopy, stream,
+                        true /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<double> &x, int incx,
+                          DeviceMemory<double> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_dcopy, stream,
+                        true /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the COPY operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
+                         const DeviceMemory<float> &x, int incx,
+                         const DeviceMemory<float> &y, int incy,
+                         DeviceMemory<float> *result) {
+  return DoBlasInternal(
+      wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
+      GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
+                         const DeviceMemory<double> &x, int incx,
+                         const DeviceMemory<double> &y, int incy,
+                         DeviceMemory<double> *result) {
+  return DoBlasInternal(
+      wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
+      GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the DOT operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<float> &x, int incx,
+                          DeviceMemory<float> *result) {
+  return DoBlasInternal(wrap::rocblas_snrm2, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<double> &x, int incx,
+                          DeviceMemory<double> *result) {
+  return DoBlasInternal(wrap::rocblas_dnrm2, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          DeviceMemory<float> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          DeviceMemory<double> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the NRM2 operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
+                         DeviceMemory<float> *x, int incx,
+                         DeviceMemory<float> *y, int incy, float c, float s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
+                         DeviceMemory<double> *x, int incx,
+                         DeviceMemory<double> *y, int incy, double c,
+                         double s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
+                         DeviceMemory<std::complex<float>> *x, int incx,
+                         DeviceMemory<std::complex<float>> *y, int incy,
+                         float c, float s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count,
+                         DeviceMemory<std::complex<double>> *x, int incx,
+                         DeviceMemory<std::complex<double>> *y, int incy,
+                         double c, double s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROT operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
+                          DeviceMemory<float> *b, DeviceMemory<float> *c,
+                          DeviceMemory<float> *s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
+                          DeviceMemory<double> *b, DeviceMemory<double> *c,
+                          DeviceMemory<double> *s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
+                          DeviceMemory<std::complex<float>> *b,
+                          DeviceMemory<float> *c,
+                          DeviceMemory<std::complex<float>> *s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
+                          DeviceMemory<std::complex<double>> *b,
+                          DeviceMemory<double> *c,
+                          DeviceMemory<std::complex<double>> *s) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTG operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
+                          DeviceMemory<float> *x, int incx,
+                          DeviceMemory<float> *y, int incy,
+                          const DeviceMemory<float> &param) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count,
+                          DeviceMemory<double> *x, int incx,
+                          DeviceMemory<double> *y, int incy,
+                          const DeviceMemory<double> &param) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTM operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
+                           DeviceMemory<float> *d2, DeviceMemory<float> *x1,
+                           const DeviceMemory<float> &y1,
+                           DeviceMemory<float> *param) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
+                           DeviceMemory<double> *d2, DeviceMemory<double> *x1,
+                           const DeviceMemory<double> &y1,
+                           DeviceMemory<double> *param) {
+  LOG(ERROR) << "rocBLAS does not currently support the ROTMG operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+                          DeviceMemory<float> *x, int incx) {
+  return DoBlasInternal(wrap::rocblas_sscal, stream,
+                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        GpuMemoryMutable(x), incx);
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+                          DeviceMemory<double> *x, int incx) {
+  return DoBlasInternal(wrap::rocblas_dscal, stream,
+                        true /* = pointer_mode_host */, elem_count, &alpha,
+                        GpuMemoryMutable(x), incx);
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
+                          std::complex<float> alpha,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
+                          std::complex<double> alpha,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+                          DeviceMemory<float> *x, int incx,
+                          DeviceMemory<float> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_sswap, stream,
+                        true /* = pointer_mode_host */, elem_count,
+                        GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+                          DeviceMemory<double> *x, int incx,
+                          DeviceMemory<double> *y, int incy) {
+  return DoBlasInternal(wrap::rocblas_dswap, stream,
+                        true /* = pointer_mode_host */, elem_count,
+                        GpuMemoryMutable(x), incx, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+                          DeviceMemory<std::complex<float>> *x, int incx,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+                          DeviceMemory<std::complex<double>> *x, int incx,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SWAP operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<float> &x, int incx,
+                           DeviceMemory<int> *result) {
+  return DoBlasInternal(wrap::rocblas_isamax, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<double> &x, int incx,
+                           DeviceMemory<int> *result) {
+  return DoBlasInternal(wrap::rocblas_idamax, stream,
+                        false /* = pointer_mode_host */, elem_count,
+                        GpuMemory(x), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<std::complex<float>> &x, int incx,
+                           DeviceMemory<int> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<std::complex<double>> &x,
+                           int incx, DeviceMemory<int> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the AMAX operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<float> &x, int incx,
+                           DeviceMemory<int> *result) {
+  return DoBlasInternal(
+      wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count,
+      GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<double> &x, int incx,
+                           DeviceMemory<int> *result) {
+  return DoBlasInternal(
+      wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count,
+      GpuComplex(GpuMemory(x)), incx, GpuMemoryMutable(result));
+}
+
+bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<std::complex<float>> &x, int incx,
+                           DeviceMemory<int> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+                           const DeviceMemory<std::complex<double>> &x,
+                           int incx, DeviceMemory<int> *result) {
+  LOG(ERROR) << "rocBLAS does not currently support the AMIN operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, uint64 kl, uint64 ku, float alpha,
+                          const DeviceMemory<float> &a, int lda,
+                          const DeviceMemory<float> &x, int incx, float beta,
+                          DeviceMemory<float> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, uint64 kl, uint64 ku, double alpha,
+                          const DeviceMemory<double> &a, int lda,
+                          const DeviceMemory<double> &x, int incx, double beta,
+                          DeviceMemory<double> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, uint64 kl, uint64 ku,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, uint64 kl, uint64 ku,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GBMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, float alpha, const DeviceMemory<float> &a,
+                          int lda, const DeviceMemory<float> &x, int incx,
+                          float beta, DeviceMemory<float> *y, int incy) {
+  return DoBlasInternal(
+      wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
+      incx, &beta, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, double alpha, const DeviceMemory<double> &a,
+                          int lda, const DeviceMemory<double> &x, int incx,
+                          double beta, DeviceMemory<double> *y, int incy) {
+  return DoBlasInternal(
+      wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
+      incx, &beta, GpuMemoryMutable(y), incy);
+}
+
+bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+                          uint64 n, std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
+                         const DeviceMemory<float> &x, int incx,
+                         const DeviceMemory<float> &y, int incy,
+                         DeviceMemory<float> *a, int lda) {
+  return DoBlasInternal(
+      wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+      GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
+}
+
+bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
+                         const DeviceMemory<double> &x, int incx,
+                         const DeviceMemory<double> &y, int incy,
+                         DeviceMemory<double> *a, int lda) {
+  return DoBlasInternal(
+      wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+      GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(a), lda);
+}
+
+bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the GERC operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the GERC operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the GERU operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          uint64 k, std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          uint64 k, std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HBMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HEMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         float alpha,
+                         const DeviceMemory<std::complex<float>> &x, int incx,
+                         DeviceMemory<std::complex<float>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         double alpha,
+                         const DeviceMemory<std::complex<double>> &x, int incx,
+                         DeviceMemory<std::complex<double>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER2 operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &ap,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &ap,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         float alpha,
+                         const DeviceMemory<std::complex<float>> &x, int incx,
+                         DeviceMemory<std::complex<float>> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         double alpha,
+                         const DeviceMemory<std::complex<double>> &x, int incx,
+                         DeviceMemory<std::complex<double>> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPR operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &x, int incx,
+                          const DeviceMemory<std::complex<float>> &y, int incy,
+                          DeviceMemory<std::complex<float>> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &x, int incx,
+                          const DeviceMemory<std::complex<double>> &y, int incy,
+                          DeviceMemory<std::complex<double>> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the HPR2 operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          uint64 k, float alpha, const DeviceMemory<float> &a,
+                          int lda, const DeviceMemory<float> &x, int incx,
+                          float beta, DeviceMemory<float> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          uint64 k, double alpha, const DeviceMemory<double> &a,
+                          int lda, const DeviceMemory<double> &x, int incx,
+                          double beta, DeviceMemory<double> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SBMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          float alpha, const DeviceMemory<float> &ap,
+                          const DeviceMemory<float> &x, int incx, float beta,
+                          DeviceMemory<float> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          double alpha, const DeviceMemory<double> &ap,
+                          const DeviceMemory<double> &x, int incx, double beta,
+                          DeviceMemory<double> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         float alpha, const DeviceMemory<float> &x, int incx,
+                         DeviceMemory<float> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         double alpha, const DeviceMemory<double> &x, int incx,
+                         DeviceMemory<double> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPR operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          float alpha, const DeviceMemory<float> &x, int incx,
+                          const DeviceMemory<float> &y, int incy,
+                          DeviceMemory<float> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          double alpha, const DeviceMemory<double> &x, int incx,
+                          const DeviceMemory<double> &y, int incy,
+                          DeviceMemory<double> *ap) {
+  LOG(ERROR) << "rocBLAS does not currently support the SPR2 operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          float alpha, const DeviceMemory<float> &a, int lda,
+                          const DeviceMemory<float> &x, int incx, float beta,
+                          DeviceMemory<float> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          double alpha, const DeviceMemory<double> &a, int lda,
+                          const DeviceMemory<double> &x, int incx, double beta,
+                          DeviceMemory<double> *y, int incy) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         float alpha, const DeviceMemory<float> &x, int incx,
+                         DeviceMemory<float> *a, int lda) {
+  return DoBlasInternal(wrap::rocblas_ssyr, stream,
+                        true /* = pointer_mode_host */,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemoryMutable(a), lda);
+}
+
+bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+                         double alpha, const DeviceMemory<double> &x, int incx,
+                         DeviceMemory<double> *a, int lda) {
+  return DoBlasInternal(wrap::rocblas_dsyr, stream,
+                        true /* = pointer_mode_host */,
+                        ROCMBlasUpperLower(uplo), n, &alpha, GpuMemory(x), incx,
+                        GpuMemoryMutable(a), lda);
+}
+
+bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          float alpha, const DeviceMemory<float> &x, int incx,
+                          const DeviceMemory<float> &y, int incy,
+                          DeviceMemory<float> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+                          double alpha, const DeviceMemory<double> &x, int incx,
+                          const DeviceMemory<double> &y, int incy,
+                          DeviceMemory<double> *a, int lda) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2 operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<std::complex<float>> &a,
+                          int lda, DeviceMemory<std::complex<float>> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<std::complex<double>> &a,
+                          int lda, DeviceMemory<std::complex<double>> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<std::complex<float>> &a,
+                          int lda, DeviceMemory<std::complex<float>> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          uint64 k, const DeviceMemory<std::complex<double>> &a,
+                          int lda, DeviceMemory<std::complex<double>> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TBSV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<double> &ap,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<float>> &ap,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<double>> &ap,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+                          int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<double> &ap,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<float>> &ap,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<double>> &ap,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TPSV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          DeviceMemory<std::complex<float>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          DeviceMemory<std::complex<double>> *x, int incx) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSV operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          float alpha, const DeviceMemory<Eigen::half> &a,
+                          int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+                          float beta, DeviceMemory<Eigen::half> *c, int ldc) {
+  VLOG(1) << port::Printf(
+      "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
+      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
+      "c=%p ldc=%d",
+      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
+      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
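+  // BLAS assumes column-major storage: a matrix's leading dimension must be
+  // at least the number of rows it is stored with, i.e. lda >= m when A is
+  // not transposed and lda >= k when it is (analogously for B). The checks
+  // below only warn on violations rather than failing the call.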
+  if (transa == blas::Transpose::kNoTranspose) {
+    if (lda < static_cast<int64>(m)) {
+      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
+                      "precondition violation";
+    }
+  } else {
+    if (lda < static_cast<int64>(k)) {
+      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
+                   << ") (transpose case); precondition violation";
+    }
+  }
+  if (transb == blas::Transpose::kNoTranspose) {
+    if (ldb < static_cast<int64>(k)) {
+      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
+                   << ") (no transpose case); precondition violation";
+    }
+  } else {
+    if (ldb < static_cast<int64>(n)) {
+      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
+                      "precondition violation";
+    }
+  }
+  const Eigen::half alpha_half(alpha);
+  const Eigen::half beta_half(beta);
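+  // Eigen::half and rocblas_half are assumed to share the same 16-bit
+  // half-precision bit layout, so the reinterpret_casts below only change
+  // the pointer type; no value conversion happens on the host.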
+  return DoBlasInternal(
+      wrap::rocblas_hgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+      reinterpret_cast<const rocblas_half*>(&alpha_half),
+      reinterpret_cast<const rocblas_half*>(GpuMemory(a)), lda,
+      reinterpret_cast<const rocblas_half*>(GpuMemory(b)), ldb,
+      reinterpret_cast<const rocblas_half*>(&beta_half),
+      reinterpret_cast<rocblas_half*>(GpuMemoryMutable(c)), ldc);
+}
+
+bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          float alpha, const DeviceMemory<float> &a, int lda,
+                          const DeviceMemory<float> &b, int ldb, float beta,
+                          DeviceMemory<float> *c, int ldc) {
+  VLOG(1) << port::Printf(
+      "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
+      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
+      "c=%p ldc=%d",
+      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
+      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
+  if (transa == blas::Transpose::kNoTranspose) {
+    if (lda < static_cast<int64>(m)) {
+      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
+                      "precondition violation";
+    }
+  } else {
+    if (lda < static_cast<int64>(k)) {
+      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
+                   << ") (transpose case); precondition violation";
+    }
+  }
+  if (transb == blas::Transpose::kNoTranspose) {
+    if (ldb < static_cast<int64>(k)) {
+      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
+                   << ") (no transpose case); precondition violation";
+    }
+  } else {
+    if (ldb < static_cast<int64>(n)) {
+      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
+                      "precondition violation";
+    }
+  }
+  return DoBlasInternal(
+      wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
+      GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
+}
+
+bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          double alpha, const DeviceMemory<double> &a, int lda,
+                          const DeviceMemory<double> &b, int ldb, double beta,
+                          DeviceMemory<double> *c, int ldc) {
+  return DoBlasInternal(
+      wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
+      GpuMemory(a), lda, GpuMemory(b), ldb, &beta, GpuMemoryMutable(c), ldc);
+}
+
+bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &b, int ldb,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &b, int ldb,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemvWithProfiling(
+    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
+    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
+    int incx, float beta, DeviceMemory<float> *y, int incy,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+                                     incx, beta, y, incy,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemvWithProfiling(
+    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
+    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
+    int incx, double beta, DeviceMemory<double> *y, int incy,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+                                     incx, beta, y, incy,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemvWithProfiling(
+    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+    std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
+    int lda, const DeviceMemory<std::complex<float>> &x, int incx,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+                                     incx, beta, y, incy,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemvWithProfiling(
+    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+    std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
+    int lda, const DeviceMemory<std::complex<double>> &x, int incx,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+                                     incx, beta, y, incy,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemmWithProfiling(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
+    DeviceMemory<Eigen::half> *c, int ldc,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+                                     lda, b, ldb, beta, c, ldc,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemmWithProfiling(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+    int ldc, blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+                                     lda, b, ldb, beta, c, ldc,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemmWithProfiling(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    const DeviceMemory<double> &b, int ldb, double beta,
+    DeviceMemory<double> *c, int ldc,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+                                     lda, b, ldb, beta, c, ldc,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemmWithProfiling(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda,
+    const DeviceMemory<std::complex<float>> &b, int ldb,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+                                     lda, b, ldb, beta, c, ldc,
+                                     output_profile_result);
+}
+
+bool ROCMBlas::DoBlasGemmWithProfiling(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda,
+    const DeviceMemory<std::complex<double>> &b, int ldb,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+                                     lda, b, ldb, beta, c, ldc,
+                                     output_profile_result);
+}
+
+template <typename T>
+bool ROCMBlas::DoBlasGemvWithProfilingImpl(
+    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
+    const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
+    const T &beta, DeviceMemory<T> *y, int incy,
+    blas::ProfileResult *output_profile_result) {
+  // ROCM TODO: properly implement the interface
+  return false;
+}
+
+template <typename T, typename ParamType>
+bool ROCMBlas::DoBlasGemmWithProfilingImpl(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
+    int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
+    DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
+  // ROCM TODO: properly implement the interface
+  return false;
+}
+
+template <typename InT, typename OutT, typename CompT>
+bool ROCMBlas::DoBlasGemmWithAlgorithmImpl(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
+    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  // ROCM TODO: properly implement the interface
+  return false;
+}
+
+bool ROCMBlas::GetBlasGemmAlgorithms(
+    std::vector<blas::AlgorithmType> *out_algorithms) {
+  // ROCM TODO: properly implement the interface
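+  // out_algorithms is deliberately left empty: no tunable algorithms are
+  // advertised, and returning true signals that the query itself succeeded,
+  // which presumably keeps callers on the default GEMM path.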
+  return true;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
+    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
+    int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"int8\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
+    const DeviceMemory<Eigen::half> &a, int lda,
+    const DeviceMemory<Eigen::half> &b, int ldb,
+    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"half\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
+    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
+    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda,
+    const DeviceMemory<std::complex<float>> &b, int ldb,
+    const HostOrDeviceScalar<std::complex<float>> &beta,
+    DeviceMemory<std::complex<float>> *c, int ldc,
+    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda,
+    const DeviceMemory<std::complex<double>> &b, int ldb,
+    const HostOrDeviceScalar<std::complex<double>> &beta,
+    DeviceMemory<std::complex<double>> *c, int ldc,
+    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+    blas::ProfileResult *output_profile_result) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMwithAlgorithm "
+             << "operation for the \"complex<double>\" datatype";
+  return false;
+}
+
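+// Maps the element type used by StreamExecutor to the one expected by the
+// rocBLAS API: Eigen::half maps to rocblas_half, every other type maps to
+// itself.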
+template <typename T>
+struct EigenHalfToRocBlasHalf {
+  using type = T;
+};
+
+template <>
+struct EigenHalfToRocBlasHalf<Eigen::half> {
+  using type = rocblas_half;
+};
+
+template <typename T, typename FuncT>
+port::Status ROCMBlas::DoBlasGemmBatchedInternal(
+    FuncT rocblas_func, Stream *stream, blas::Transpose transa,
+    blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+    const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
+    const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
+    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  // MAPPED_T is the same as T for all types except Eigen::half, for which
+  // it is rocblas_half.
+  using MAPPED_T = typename EigenHalfToRocBlasHalf<T>::type;
+
+  // Allocate local vectors to hold the device pointers to the matrices.
+  std::vector<MAPPED_T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
+  for (int i = 0; i < batch_count; ++i) {
+    // static_cast does not work when converting Eigen::half* to
+    // rocblas_half*, hence the use of reinterpret_cast.
+    a_raw_ptrs.push_back(
+        reinterpret_cast<MAPPED_T *>(a_ptrs_to_wrappers[i]->opaque()));
+    b_raw_ptrs.push_back(
+        reinterpret_cast<MAPPED_T *>(b_ptrs_to_wrappers[i]->opaque()));
+    c_raw_ptrs.push_back(
+        reinterpret_cast<MAPPED_T *>(c_ptrs_to_wrappers[i]->opaque()));
+  }
+
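+  // The *_strided_batched rocBLAS kernels address matrix i in a batch as
+  // base_ptr + i * stride. As a hypothetical example, three 4x4 float A
+  // matrices allocated back to back with lda == 4 yield bsa == 16 elements
+  // between consecutive matrices. The code below recovers the strides from
+  // the pointer arrays and verifies that they are uniform across the batch.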
+  // With batch_count <= 1 there is no meaningful matrix stride, so default
+  // the strides to the corresponding leading dimensions.
+  long long bsa = lda;
+  long long bsb = ldb;
+  long long bsc = ldc;
+  bool bsa_is_constant = true;
+  bool bsb_is_constant = true;
+  bool bsc_is_constant = true;
+
+  if (batch_count > 1) {
+    // Remember the first stride; if any subsequent stride differs from it,
+    // the strided-batched rocBLAS call cannot be used.
+    bsa = a_raw_ptrs[1] - a_raw_ptrs[0];
+    bsb = b_raw_ptrs[1] - b_raw_ptrs[0];
+    bsc = c_raw_ptrs[1] - c_raw_ptrs[0];
+
+    // Verify that the batch strides are constant. All the test cases from
+    // batch_matmul_op_test.py seem to satisfy this requirement; if it can be
+    // proven to hold globally, this check can be safely removed.
+    for (int i = 1; i < batch_count - 1; ++i) {
+      long long iterative_bsa = a_raw_ptrs[i + 1] - a_raw_ptrs[i];
+      if (iterative_bsa != bsa) {
+        bsa_is_constant = false;
+        break;
+      }
+
+      long long iterative_bsb = b_raw_ptrs[i + 1] - b_raw_ptrs[i];
+      if (iterative_bsb != bsb) {
+        bsb_is_constant = false;
+        break;
+      }
+
+      long long iterative_bsc = c_raw_ptrs[i + 1] - c_raw_ptrs[i];
+      if (iterative_bsc != bsc) {
+        bsc_is_constant = false;
+        break;
+      }
+    }
+  }
+
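+  // Sanity checks: each leading dimension must cover the stored rows of its
+  // matrix, and each batch stride must span at least one full matrix, so
+  // consecutive matrices in a batch cannot overlap.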
+  assert(!(ldc < m || bsc < ldc * n));
+
+  if (ROCMBlasTranspose(transa) == rocblas_operation_none)
+    assert(!(lda < m || bsa < lda * k));
+  else
+    assert(!(lda < k || bsa < lda * m));
+
+  if (ROCMBlasTranspose(transb) == rocblas_operation_none)
+    assert(!(ldb < k || bsb < ldb * n));
+  else
+    assert(!(ldb < n || bsb < ldb * k));
+
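+  // Same bit-layout assumption as for the fp16 GEMM scalars above: the casts
+  // reinterpret &alpha / &beta as pointers to the rocBLAS parameter type
+  // without any value conversion.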
+  MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
+  MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
+
+  if (bsa_is_constant && bsb_is_constant && bsc_is_constant) {
+    bool ok = DoBlasInternal(
+        rocblas_func, stream, true /* = pointer_mode_host */,
+        ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+        GpuComplex(alpha_ptr), a_raw_ptrs[0], lda, bsa, b_raw_ptrs[0], ldb,
+        bsb, GpuComplex(beta_ptr), c_raw_ptrs[0], ldc, bsc, batch_count);
+    if (ok) {
+      return port::Status::OK();
+    }
+  }
+
+  return port::Status(port::error::INTERNAL,
+                      "failed BLAS call, see log for details");
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
+    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  const Eigen::half alpha_half(alpha);
+  const Eigen::half beta_half(beta);
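+  // Narrow the float scaling factors to half on the host up front;
+  // DoBlasGemmBatchedInternal then reinterprets them as rocblas_half.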
+
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_hgemm_strided_batched, stream, transa, transb, m, n, k,
+      alpha_half, a, lda, b, ldb, beta_half, c, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+
+  return status.ok();
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha,
+    const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
+    const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
+    const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
+    int batch_count, ScratchAllocator *scratch_allocator) {
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, alpha,
+      a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha,
+    const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
+    const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
+    double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, alpha,
+      a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
+    int lda,
+    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
+    int ldb, std::complex<float> beta,
+    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasGemmBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
+    int lda,
+    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
+    int ldb, std::complex<double> beta,
+    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
+    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &b, int ldb,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &b, int ldb,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HEMM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          float alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          float beta, DeviceMemory<std::complex<float>> *c,
+                          int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          double alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          double beta, DeviceMemory<std::complex<double>> *c,
+                          int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HERK operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           std::complex<float> alpha,
+                           const DeviceMemory<std::complex<float>> &a, int lda,
+                           const DeviceMemory<std::complex<float>> &b, int ldb,
+                           float beta, DeviceMemory<std::complex<float>> *c,
+                           int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           std::complex<double> alpha,
+                           const DeviceMemory<std::complex<double>> &a, int lda,
+                           const DeviceMemory<std::complex<double>> &b, int ldb,
+                           double beta, DeviceMemory<std::complex<double>> *c,
+                           int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the HER2K operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          float alpha, const DeviceMemory<float> &a, int lda,
+                          const DeviceMemory<float> &b, int ldb, float beta,
+                          DeviceMemory<float> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          double alpha, const DeviceMemory<double> &a, int lda,
+                          const DeviceMemory<double> &b, int ldb, double beta,
+                          DeviceMemory<double> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          const DeviceMemory<std::complex<float>> &b, int ldb,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          const DeviceMemory<std::complex<double>> &b, int ldb,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYMM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          float alpha, const DeviceMemory<float> &a, int lda,
+                          float beta, DeviceMemory<float> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          double alpha, const DeviceMemory<double> &a, int lda,
+                          double beta, DeviceMemory<double> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          std::complex<float> beta,
+                          DeviceMemory<std::complex<float>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+                          blas::Transpose trans, uint64 n, uint64 k,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          std::complex<double> beta,
+                          DeviceMemory<std::complex<double>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYRK operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           float alpha, const DeviceMemory<float> &a, int lda,
+                           const DeviceMemory<float> &b, int ldb, float beta,
+                           DeviceMemory<float> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           double alpha, const DeviceMemory<double> &a, int lda,
+                           const DeviceMemory<double> &b, int ldb, double beta,
+                           DeviceMemory<double> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           std::complex<float> alpha,
+                           const DeviceMemory<std::complex<float>> &a, int lda,
+                           const DeviceMemory<std::complex<float>> &b, int ldb,
+                           std::complex<float> beta,
+                           DeviceMemory<std::complex<float>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+                           blas::Transpose trans, uint64 n, uint64 k,
+                           std::complex<double> alpha,
+                           const DeviceMemory<std::complex<double>> &a, int lda,
+                           const DeviceMemory<std::complex<double>> &b, int ldb,
+                           std::complex<double> beta,
+                           DeviceMemory<std::complex<double>> *c, int ldc) {
+  LOG(ERROR) << "rocBLAS does not currently support the SYR2K operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+                          const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
+             << "for the \"float\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+                          const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
+             << "for the \"double\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          DeviceMemory<std::complex<float>> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          DeviceMemory<std::complex<double>> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRMM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+                          const DeviceMemory<float> &a, int lda,
+                          DeviceMemory<float> *b, int ldb) {
+  return DoBlasInternal(
+      wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
+      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<float*>(GpuMemory(a)),
+      lda, GpuMemoryMutable(b), ldb);
+}
+
+bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+                          const DeviceMemory<double> &a, int lda,
+                          DeviceMemory<double> *b, int ldb) {
+  return DoBlasInternal(
+      wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
+      ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
+      ROCMBlasDiagonal(diag), m, n, &alpha, const_cast<double*>(GpuMemory(a)),
+      lda, GpuMemoryMutable(b), ldb);
+}
+
+bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n,
+                          std::complex<float> alpha,
+                          const DeviceMemory<std::complex<float>> &a, int lda,
+                          DeviceMemory<std::complex<float>> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
+             << "for the \"complex<float>\" datatype";
+  return false;
+}
+
+bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
+                          blas::UpperLower uplo, blas::Transpose transa,
+                          blas::Diagonal diag, uint64 m, uint64 n,
+                          std::complex<double> alpha,
+                          const DeviceMemory<std::complex<double>> &a, int lda,
+                          DeviceMemory<std::complex<double>> *b, int ldb) {
+  LOG(ERROR) << "rocBLAS does not currently support the TRSM operation "
+             << "for the \"complex<double>\" datatype";
+  return false;
+}
+bool ROCMBlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation "
+	     << "for the \"Eigen::half\" dataype" ;
+  return false;
+}
+bool ROCMBlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation "
+	     << "for the \"float\" dataype" ;
+  return false;
+}
+bool ROCMBlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+    int batch_count) {
+  LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation "
+	     << "for the \"double\" dataype" ;
+  return false;
+}
+bool ROCMBlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<float> alpha,
+    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation "
+	     << "for the \"complex<float>\" dataype" ;
+  return false;
+}
+bool ROCMBlas::DoBlasGemmStridedBatched(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, std::complex<double> alpha,
+    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+    int64 stride_c, int batch_count) {
+  LOG(ERROR) << "rocBLAS does not currently support the DoBlasGemmStridedBatched operation "
+	     << "for the \"complex<double>\" dataype" ;
+  return false;
+}
+}  // namespace gpu
+
+void initialize_rocblas() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
+          rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
+          [](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* {
+            gpu::GpuExecutor* rocm_executor =
+                dynamic_cast<gpu::GpuExecutor*>(parent);
+            if (rocm_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the rocBLAS "
+                  << "support library with a non-ROCM StreamExecutor";
+              return nullptr;
+            }
+
+            gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
+            if (!blas->Init()) {
+              // Note: Init() will log a more specific error.
+              delete blas;
+              return nullptr;
+            }
+            return blas;
+          });
+
+  if (!status.ok()) {
+    LOG(ERROR) << "Unable to register rocBLAS factory: "
+               << status.error_message();
+  }
+
+  PluginRegistry::Instance()->SetDefaultFactory(
+      rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
+}
+
+}  // namespace stream_executor
+
+REGISTER_MODULE_INITIALIZER(register_rocblas,
+                            { stream_executor::initialize_rocblas(); });
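
The initialize_rocblas() function above follows the StreamExecutor plugin
pattern: register a factory keyed by platform id and plugin id, and have the
factory refuse to build a plugin for an executor of the wrong platform. A
minimal, self-contained sketch of that pattern (illustrative types only, not
the real stream_executor API):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Executor { virtual ~Executor() = default; };
struct RocmExecutor : Executor {};

struct BlasSupport { virtual ~BlasSupport() = default; };
struct RocmBlas : BlasSupport {
  explicit RocmBlas(RocmExecutor*) {}
  bool Init() { return true; }  // stands in for creating the rocblas handle
};

using BlasFactory = std::function<BlasSupport*(Executor*)>;
std::map<std::string, BlasFactory>& Registry() {
  static auto* registry = new std::map<std::string, BlasFactory>;
  return *registry;
}

int main() {
  // Registration, mirroring the factory lambda in initialize_rocblas().
  Registry()["rocBLAS"] = [](Executor* parent) -> BlasSupport* {
    auto* rocm = dynamic_cast<RocmExecutor*>(parent);
    if (rocm == nullptr) return nullptr;  // non-ROCM executor: refuse
    auto blas = std::make_unique<RocmBlas>(rocm);
    if (!blas->Init()) return nullptr;    // Init() failure: no plugin
    return blas.release();
  };

  RocmExecutor executor;
  std::unique_ptr<BlasSupport> blas(Registry()["rocBLAS"](&executor));
  std::cout << (blas ? "rocBLAS plugin created\n" : "factory refused\n");
}
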
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h
new file mode 100644
index 00000000000..8e577127450
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_blas.h
@@ -0,0 +1,159 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ROCM-specific support for BLAS functionality -- this wraps the rocFFT
+// library capabilities, and is only included into ROCM implementation code
+// -- it will not introduce rocm headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+namespace stream_executor {
+
+class Stream;
+
+namespace gpu {
+
+// Opaque and unique identifier for the rocBLAS plugin.
+extern const PluginId kRocBlasPlugin;
+
+class GpuExecutor;
+
+// BLAS plugin for ROCM platform via rocBLAS library.
+//
+// This satisfies the platform-agnostic BlasSupport interface.
+//
+// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
+// context (and, as a result, the device) that the parent GpuExecutor is tied
+// to. This simply happens as an artifact of creating the rocBLAS handle when a
+// ROCM context is active.
+//
+// Thread-safe post-initialization.
+class ROCMBlas : public blas::BlasSupport {
+ public:
+  explicit ROCMBlas(GpuExecutor* parent);
+
+  // Allocates a rocBLAS handle.
+  bool Init();
+
+  // Releases the rocBLAS handle, if present.
+  ~ROCMBlas() override;
+
+  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
+
+ private:
+  // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
+  //
+  // rocBLAS is stateful, and can only be associated with one stream (in
+  // order to enqueue dispatch) at a given time. As a result, this generally
+  // must be invoked before calling into rocBLAS.
+  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // A helper function that calls the real rocBLAS function together with error
+  // handling.
+  //
+  // rocblas_func:       rocBLAS function pointer.
+  // rocblas_name:       rocBLAS function name.
+  // stream:             Stream to enqueue the BLAS operation onto.
+  // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
+  //                     (true) or device (false).
+  // err_on_failure:     Whether to print an error if the rocBLAS call fails.
+  // args:               Arguments of rocBLAS function.
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
+                          bool pointer_mode_host, bool err_on_failure,
+                          Args... args);
+
+  // Convenience functions that call DoBlasInternalImpl with different values
+  // for err_on_failure.
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
+                      bool pointer_mode_host, Args... args) {
+    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
+                              /*err_on_failure=*/true, args...);
+  }
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
+                               bool pointer_mode_host, Args... args) {
+    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
+                              /*err_on_failure=*/false, args...);
+  }
+
+  // A helper function to implement DoBlasGemmBatched interfaces for generic
+  // types.
+  template <typename T, typename FuncT>
+  port::Status DoBlasGemmBatchedInternal(
+      FuncT rocblas_func, Stream *stream, blas::Transpose transa,
+      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
+      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
+      int batch_count, ScratchAllocator *scratch_allocator);
+
+  // Helper function for implementing DoBlasGemmWithAlgorithm.
+  //
+  // We take alpha and beta by const reference because T might be Eigen::half,
+  // and we want to avoid pulling in a dependency on Eigen.  When we pass the
+  // references to rocBLAS, we essentially reinterpret_cast to __half, which is
+  // safe because Eigen::half inherits from __half.
+  template <typename InT, typename OutT, typename CompT>
+  bool DoBlasGemmWithAlgorithmImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
+      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
+      blas::AlgorithmType algorithm,
+      blas::ProfileResult *output_profile_result);
+
+  // Helper function for implementing DoBlasGemmWithProfiling.
+  template <typename T, typename ParamType>
+  bool DoBlasGemmWithProfilingImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
+      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
+      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
+
+  // Helper function for implementing DoBlasGemvWithProfiling.
+  template <typename T>
+  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
+                                   uint64 m, uint64 n, const T &alpha,
+                                   const DeviceMemory<T> &a, int lda,
+                                   const DeviceMemory<T> &x, int incx,
+                                   const T &beta, DeviceMemory<T> *y, int incy,
+                                   blas::ProfileResult *output_profile_result);
+
+  // mutex that guards the rocBLAS handle for this device.
+  mutex mu_;
+
+  // GpuExecutor which instantiated this ROCMBlas.
+  // Immutable post-initialization.
+  GpuExecutor* parent_;
+
+  // rocBLAS library handle on the device.
+  rocblas_handle blas_ GUARDED_BY(mu_);
+
+  SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
+};
+
+}  // namespace gpu
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_

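The DoBlasInternal/DoBlasInternalImpl layering declared above boils down to
one variadic template that owns the error handling and forwards the remaining
arguments verbatim to the wrapped rocBLAS entry point, with thin wrappers
that pin err_on_failure. A compilable sketch of the forwarding pattern; the
rocblas-style names below are stand-ins, not the real library signatures:

#include <iostream>

enum rocblas_status { rocblas_status_success, rocblas_status_error };

// Stand-in for a wrapped rocBLAS function (scales a vector by alpha).
rocblas_status fake_rocblas_sscal(int n, const float* alpha, float* x) {
  for (int i = 0; i < n; ++i) x[i] *= *alpha;
  return rocblas_status_success;
}

template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT rocblas_func, bool err_on_failure,
                        Args... args) {
  // The real implementation binds the rocblas handle to the stream first.
  rocblas_status ret = rocblas_func(args...);
  if (ret != rocblas_status_success && err_on_failure) {
    std::cerr << "rocBLAS call failed: " << ret << "\n";
  }
  return ret == rocblas_status_success;
}

template <typename FuncT, typename... Args>
bool DoBlasInternal(FuncT rocblas_func, Args... args) {
  return DoBlasInternalImpl(rocblas_func, /*err_on_failure=*/true, args...);
}

int main() {
  float x[3] = {1, 2, 3};
  float alpha = 2.0f;
  bool ok = DoBlasInternal(fake_rocblas_sscal, 3, &alpha, x);
  std::cout << "ok=" << ok << " x[2]=" << x[2] << "\n";  // ok=1 x[2]=6
}
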
From 19d3ff647d2174184d49ebfe2e505ffd266d0f09 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Thu, 31 Jan 2019 19:49:46 +0000
Subject: [PATCH 3/7] adding code for rocfft plugin

---
 tensorflow/stream_executor/rocm/BUILD       |  51 +-
 tensorflow/stream_executor/rocm/rocm_fft.cc | 588 ++++++++++++++++++++
 tensorflow/stream_executor/rocm/rocm_fft.h  | 132 +++++
 3 files changed, 750 insertions(+), 21 deletions(-)
 create mode 100644 tensorflow/stream_executor/rocm/rocm_fft.cc
 create mode 100644 tensorflow/stream_executor/rocm/rocm_fft.h

diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index 737a4429469..c0da35121f2 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -175,26 +175,35 @@ cc_library(
    alwayslink = True,
 )
 
-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocfft_plugin",
-#    srcs = ["rocm_fft.cc"],
-#    hdrs = [],
-#    visibility = ["//visibility:public"],
-#    deps = [
-#        ":rocm_platform_id",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:fft",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:scratch_allocator",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#        "@local_config_rocm//rocm:rocm_headers",
-#    ] + if_static(["@local_config_rocm//rocm:rocfft"]),
-#    alwayslink = True,
-#)
+cc_library(
+   name = "rocfft_plugin",
+   srcs = if_rocm_is_configured(["rocm_fft.cc"]),
+   hdrs = if_rocm_is_configured(["rocm_fft.h"]),
+   visibility = ["//visibility:public"],
+   deps = if_rocm_is_configured([
+       ":rocm_platform_id",
+       "//tensorflow/stream_executor:event",
+       "//tensorflow/stream_executor:fft",
+       "//tensorflow/stream_executor:plugin_registry",
+       "//tensorflow/stream_executor:scratch_allocator",
+       "//tensorflow/stream_executor/gpu:gpu_activation",
+       "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+       "//tensorflow/stream_executor/gpu:gpu_executor_header",
+       "//tensorflow/stream_executor/gpu:gpu_stream_header",
+       "//tensorflow/stream_executor/gpu:gpu_kernel_header",
+       "//tensorflow/stream_executor/lib",
+       "//tensorflow/stream_executor/platform",
+       "//tensorflow/stream_executor/platform:dso_loader",
+       "@local_config_rocm//rocm:rocm_headers",
+   ] + if_static(
+       ["@local_config_rocm//rocm:rocfft"],
+       # Delete the following line once we switch the rocfft library from
+       # being dynamically linked (current behaviour) to being dynamically
+       # loaded (future behaviour).
+       ["@local_config_rocm//rocm:rocfft"],
+   )),
+   alwayslink = True,
+)
 
 # FIXME: enable in future PRs
 #cc_library(
@@ -263,7 +272,7 @@ cc_library(
     deps = if_rocm_is_configured([
         # FIXME: enable in future PRs
         #":miopen_plugin",
-        #":rocfft_plugin",
+        ":rocfft_plugin",
         ":rocblas_plugin",
         #":rocrand_plugin",
         ":rocm_driver",
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
new file mode 100644
index 00000000000..dd30911eadd
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -0,0 +1,588 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/rocm/rocm_fft.h"
+
+#include <complex>
+
+#include "tensorflow/stream_executor/gpu/gpu_activation.h"
+#include "tensorflow/stream_executor/gpu/gpu_executor.h"
+#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
+#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
+#include "tensorflow/stream_executor/gpu/gpu_stream.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace stream_executor {
+namespace gpu {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
+
+namespace wrap {
+
+// This macro wraps a global identifier, given by __name, in a callable
+// structure that loads the DLL symbol out of the DSO handle in a thread-safe
+// manner on first use. This dynamic loading technique is used to avoid DSO
+// dependencies on vendor libraries which may or may not be available in the
+// deployed binary environment.
+#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                      \
+  struct WrapperShim__##__name {                                 \
+    template <typename... Args>                                  \
+    hipfftResult operator()(GpuExecutor* parent, Args... args) { \
+      gpu::ScopedActivateExecutorContext sac{parent};            \
+      return ::__name(args...);                                  \
+    }                                                            \
+  } __name;
+
+#define ROCFFT_ROUTINE_EACH(__macro) \
+  __macro(hipfftDestroy)             \
+  __macro(hipfftSetStream)           \
+  __macro(hipfftPlan1d)              \
+  __macro(hipfftPlan2d)              \
+  __macro(hipfftPlan3d)              \
+  __macro(hipfftPlanMany)            \
+  __macro(hipfftCreate)              \
+  __macro(hipfftSetAutoAllocation)   \
+  __macro(hipfftSetWorkArea)         \
+  __macro(hipfftGetSize1d)           \
+  __macro(hipfftMakePlan1d)          \
+  __macro(hipfftGetSize2d)           \
+  __macro(hipfftMakePlan2d)          \
+  __macro(hipfftGetSize3d)           \
+  __macro(hipfftMakePlan3d)          \
+  __macro(hipfftGetSizeMany)         \
+  __macro(hipfftMakePlanMany)        \
+  __macro(hipfftExecD2Z)             \
+  __macro(hipfftExecZ2D)             \
+  __macro(hipfftExecC2C)             \
+  __macro(hipfftExecC2R)             \
+  __macro(hipfftExecZ2Z)             \
+  __macro(hipfftExecR2C)
+
+ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
+
+}  // namespace wrap
+
+namespace {
+
+// A helper function transforming gpu_fft arguments into rocFFT arguments.
+hipfftType ROCMFftType(fft::Type type) {
+  switch (type) {
+    case fft::Type::kC2CForward:
+    case fft::Type::kC2CInverse:
+      return HIPFFT_C2C;
+    case fft::Type::kC2R:
+      return HIPFFT_C2R;
+    case fft::Type::kR2C:
+      return HIPFFT_R2C;
+    case fft::Type::kZ2ZForward:
+    case fft::Type::kZ2ZInverse:
+      return HIPFFT_Z2Z;
+    case fft::Type::kZ2D:
+      return HIPFFT_Z2D;
+    case fft::Type::kD2Z:
+      return HIPFFT_D2Z;
+    default:
+      LOG(FATAL) << "Invalid value of fft::Type.";
+  }
+}
+
+// Associates the given stream with the given rocFFT plan.
+bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
+  auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
+  if (ret != HIPFFT_SUCCESS) {
+    LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace
+
+port::Status ROCMFftPlan::Initialize(
+    GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
+    uint64 *input_embed, uint64 input_stride, uint64 input_distance,
+    uint64 *output_embed, uint64 output_stride, uint64 output_distance,
+    fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
+  if (IsInitialized()) {
+    LOG(FATAL) << "Try to repeatedly initialize.";
+  }
+  is_initialized_ = true;
+  int elem_count_[3], input_embed_[3], output_embed_[3];
+  for (int i = 0; i < rank; ++i) {
+    elem_count_[i] = elem_count[i];
+    if (input_embed) {
+      input_embed_[i] = input_embed[i];
+    }
+    if (output_embed) {
+      output_embed_[i] = output_embed[i];
+    }
+  }
+  parent_ = parent;
+  fft_type_ = type;
+  if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
+    hipfftResult_t ret;
+    if (scratch_allocator == nullptr) {
+      switch (rank) {
+        case 1:
+          // hipfftPlan1d
+          ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
+                                  ROCMFftType(type), 1 /* = batch */);
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to create rocFFT 1d plan."};
+          }
+          return port::Status::OK();
+        case 2:
+          // hipfftPlan2d
+          ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
+                                  elem_count_[1], ROCMFftType(type));
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to create rocFFT 2d plan."};
+          }
+          return port::Status::OK();
+        case 3:
+          // hipfftPlan3d
+          ret =
+              wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
+                                elem_count_[2], ROCMFftType(type));
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to create rocFFT 3d plan."};
+          }
+          return port::Status::OK();
+        default:
+          LOG(ERROR) << "Invalid rank value for hipfftPlan. "
+                        "Requested 1, 2, or 3, given: "
+                     << rank;
+          return port::Status{port::error::INVALID_ARGUMENT,
+                              "hipfftPlan only takes rank 1, 2, or 3."};
+      }
+    } else {
+      ret = wrap::hipfftCreate(parent, &plan_);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to create rocFFT plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to create rocFFT plan."};
+      }
+      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to set auto allocation for rocFFT plan."};
+      }
+      size_t size_in_bytes;
+      switch (rank) {
+        case 1:
+          ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
+                                      ROCMFftType(type), /*batch=*/1,
+                                      &size_in_bytes);
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to make rocFFT 1d plan."};
+          }
+          break;
+        case 2:
+          ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
+                                      elem_count_[1], ROCMFftType(type),
+                                      &size_in_bytes);
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to make rocFFT 2d plan."};
+          }
+          break;
+        case 3:
+          ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
+                                      elem_count_[1], elem_count_[2],
+                                      ROCMFftType(type), &size_in_bytes);
+          if (ret != HIPFFT_SUCCESS) {
+            LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
+            return port::Status{port::error::INTERNAL,
+                                "Failed to make rocFFT 3d plan."};
+          }
+          break;
+        default:
+          LOG(ERROR) << "Invalid rank value for hipfftPlan. "
+                        "Requested 1, 2, or 3, given: "
+                     << rank;
+          return port::Status{port::error::INVALID_ARGUMENT,
+                              "hipfftPlan only takes rank 1, 2, or 3."};
+      }
+      // TODO(yangzihao): refactor this code and the one with the same function
+      // in the batch mode.
+      if (size_in_bytes != 0) {
+        auto allocated =
+            scratch_allocator->AllocateBytes(stream, size_in_bytes);
+        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
+          LOG(ERROR) << "failed to allocate work area.";
+          return allocated.status();
+        }
+      }
+      // Connect work area with allocated space.
+      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to set work area for rocFFT plan."};
+      }
+      return port::Status::OK();
+    }
+  } else {
+    // For either multiple batches or rank higher than 3, use hipfftPlanMany().
+    if (scratch_allocator == nullptr) {
+      auto ret = wrap::hipfftPlanMany(
+          parent, &plan_, rank, elem_count_,
+          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
+          output_embed ? output_embed_ : nullptr, output_stride,
+          output_distance, ROCMFftType(type), batch_count);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to create rocFFT batched plan."};
+      }
+    } else {
+      auto ret = wrap::hipfftCreate(parent, &plan_);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to create rocFFT batched plan."};
+      }
+      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
+                   << ret;
+        return port::Status{
+            port::error::INTERNAL,
+            "Failed to set auto allocation for rocFFT batched plan."};
+      }
+      size_t size_in_bytes;
+      ret = wrap::hipfftMakePlanMany(
+          parent, plan_, rank, elem_count_,
+          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
+          output_embed ? output_embed_ : nullptr, output_stride,
+          output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to make rocFFT batched plan."};
+      }
+      if (size_in_bytes != 0) {
+        auto allocated =
+            scratch_allocator->AllocateBytes(stream, size_in_bytes);
+        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
+          LOG(ERROR) << "failed to allocate work area.";
+          return allocated.status();
+        }
+      }
+      // Connect work area with allocated space.
+      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
+      if (ret != HIPFFT_SUCCESS) {
+        LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
+        return port::Status{port::error::INTERNAL,
+                            "Failed to set work area for rocFFT batched plan."};
+      }
+    }
+  }
+  return port::Status::OK();
+}
+
+port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
+                                     int rank, uint64 *elem_count,
+                                     fft::Type type,
+                                     ScratchAllocator *scratch_allocator) {
+  return Initialize(parent, stream, rank, elem_count,
+                    /*input_embed=*/nullptr, /*input_stride=*/0,
+                    /*input_distance=*/0,
+                    /*output_embed=*/nullptr, /*output_stride=*/0,
+                    /*output_distance=*/0, type, 1, scratch_allocator);
+}
+
+ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
+
+int ROCMFftPlan::GetFftDirection() const {
+  if (!IsInitialized()) {
+    LOG(FATAL) << "Try to get fft direction before initialization.";
+  } else {
+    switch (fft_type_) {
+      case fft::Type::kC2CForward:
+      case fft::Type::kZ2ZForward:
+      case fft::Type::kR2C:
+      case fft::Type::kD2Z:
+        return HIPFFT_FORWARD;
+      case fft::Type::kC2CInverse:
+      case fft::Type::kZ2ZInverse:
+      case fft::Type::kC2R:
+      case fft::Type::kZ2D:
+        return HIPFFT_BACKWARD;
+      default:
+        LOG(FATAL) << "Invalid value of fft::Type.";
+    }
+  }
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
+                                                 fft::Type type,
+                                                 bool in_place_fft) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[1] = {num_x};
+  port::Status status = fft_plan_ptr->Initialize(
+      parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
+  // TODO(yangzihao): In the future, send error msg back to TensorFlow
+  // so it can fail gracefully.
+  if (!status.ok()) {
+    LOG(FATAL) << "failed to initialize hipfft 1d plan: "
+               << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
+    Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
+    ScratchAllocator *scratch_allocator) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[1] = {num_x};
+  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
+                                                 type, scratch_allocator);
+  if (!status.ok()) {
+    LOG(FATAL)
+        << "failed to initialize hipfft 1d plan with customized allocator: "
+        << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
+                                                 uint64 num_y, fft::Type type,
+                                                 bool in_place_fft) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[2] = {num_x, num_y};
+  port::Status status = fft_plan_ptr->Initialize(
+      parent_, stream, 2, elem_count, type, /*scratch_allocator=*/nullptr);
+  if (!status.ok()) {
+    LOG(FATAL) << "failed to initialize hipfft 2d plan: "
+               << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
+    Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
+    bool in_place_fft, ScratchAllocator *scratch_allocator) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[2] = {num_x, num_y};
+  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
+                                                 type, scratch_allocator);
+  if (!status.ok()) {
+    LOG(FATAL)
+        << "failed to initialize hipfft 2d plan with customized allocator: "
+        << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
+                                                 uint64 num_y, uint64 num_z,
+                                                 fft::Type type,
+                                                 bool in_place_fft) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[3] = {num_x, num_y, num_z};
+  port::Status status = fft_plan_ptr->Initialize(
+      parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
+  if (!status.ok()) {
+    LOG(FATAL) << "failed to initialize hipfft 3d plan: "
+               << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
+    Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
+    bool in_place_fft, ScratchAllocator *scratch_allocator) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  uint64 elem_count[3] = {num_x, num_y, num_z};
+  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
+                                                 type, scratch_allocator);
+  if (!status.ok()) {
+    LOG(FATAL)
+        << "failed to initialize hipfft 3d plan with customized allocator: "
+        << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
+    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+    uint64 output_stride, uint64 output_distance, fft::Type type,
+    bool in_place_fft, int batch_count) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  port::Status status = fft_plan_ptr->Initialize(
+      parent_, stream, rank, elem_count, input_embed, input_stride,
+      input_distance, output_embed, output_stride, output_distance, type,
+      batch_count, /*scratch_allocator=*/nullptr);
+  if (!status.ok()) {
+    LOG(FATAL) << "failed to initialize batched hipfft plan: "
+               << status.error_message();
+  }
+
+  return std::move(fft_plan_ptr);
+}
+
+std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
+    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+    uint64 output_stride, uint64 output_distance, fft::Type type,
+    bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
+  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
+  port::Status status = fft_plan_ptr->Initialize(
+      parent_, stream, rank, elem_count, input_embed, input_stride,
+      input_distance, output_embed, output_stride, output_distance, type,
+      batch_count, scratch_allocator);
+  if (!status.ok()) {
+    LOG(FATAL)
+        << "failed to initialize batched hipfft plan with customized allocator: "
+        << status.error_message();
+  }
+  return std::move(fft_plan_ptr);
+}
+
+void ROCMFft::UpdatePlanWithScratchAllocator(
+    Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
+  LOG(ERROR) << "update plan with scratch allocator not implemented";
+}
+
+template <typename FuncT, typename InputT, typename OutputT>
+bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
+                            const DeviceMemory<InputT> &input,
+                            DeviceMemory<OutputT> *output) {
+  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
+  if (rocm_fft_plan == nullptr) {
+    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
+    return false;
+  }
+
+  if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
+    return false;
+  }
+
+  auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
+                        GpuComplex(const_cast<InputT*>(GpuMemory(input))),
+                        GpuComplex(GpuMemoryMutable(output)));
+
+  if (ret != HIPFFT_SUCCESS) {
+    LOG(ERROR) << "failed to run rocFFT routine: " << ret;
+    return false;
+  }
+
+  return true;
+}
+
+template <typename FuncT, typename InputT, typename OutputT>
+bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
+                                         FuncT hipfftExec,
+                                         const DeviceMemory<InputT> &input,
+                                         DeviceMemory<OutputT> *output) {
+  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
+  if (rocm_fft_plan == nullptr) {
+    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
+    return false;
+  }
+
+  if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
+    return false;
+  }
+
+  auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
+                        GpuComplex(const_cast<InputT*>(GpuMemory(input))),
+                        GpuComplex(GpuMemoryMutable(output)),
+                        rocm_fft_plan->GetFftDirection());
+
+  if (ret != HIPFFT_SUCCESS) {
+    LOG(ERROR) << "failed to run rocFFT routine: " << ret;
+    return false;
+  }
+
+  return true;
+}
+
+#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2,   \
+                                        __fft_type3)                        \
+  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
+                      const DeviceMemory<std::complex<__type>> &input,       \
+                      DeviceMemory<std::complex<__type>> *output) {          \
+    return DoFftWithDirectionInternal(                                       \
+         stream, plan, wrap::hipfftExec##__fft_type1, input, output);        \
+  }                                                                          \
+  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
+                      const DeviceMemory<__type> &input,                     \
+                      DeviceMemory<std::complex<__type>> *output) {          \
+    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
+                         output);                                            \
+  }                                                                          \
+  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
+                      const DeviceMemory<std::complex<__type>> &input,       \
+                      DeviceMemory<__type> *output) {                        \
+    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
+                         output);                                            \
+  }
+
+STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
+STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
+
+#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT
+
+}  // namespace gpu
+
+void initialize_rocfft() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
+          rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
+          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
+            gpu::GpuExecutor *rocm_executor =
+                dynamic_cast<gpu::GpuExecutor *>(parent);
+            if (rocm_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the rocFFT "
+                  << "support library with a non-ROCM StreamExecutor";
+              return nullptr;
+            }
+
+            return new gpu::ROCMFft(rocm_executor);
+          });
+  if (!status.ok()) {
+    LOG(ERROR) << "Unable to register rocFFT factory: "
+               << status.error_message();
+  }
+
+  PluginRegistry::Instance()->SetDefaultFactory(
+      rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
+}
+
+}  // namespace stream_executor
+
+REGISTER_MODULE_INITIALIZER(register_rocfft,
+                            { stream_executor::initialize_rocfft(); });
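
Both ROCFFT_ROUTINE_EACH and STREAM_EXECUTOR_ROCM_DEFINE_FFT above are
applications of the X-macro technique: keep a single list of names, then
expand that list several times to stamp out the per-routine boilerplate. A
minimal sketch with hypothetical names:

#include <iostream>

// The single source of truth: one macro holding the list of routines.
#define ROUTINE_EACH(X) \
  X(ExecC2C)            \
  X(ExecR2C)            \
  X(ExecC2R)

// Expansion one: define a stub function for every routine in the list.
#define DECLARE_STUB(name) \
  int name() { std::cout << #name << " called\n"; return 0; }
ROUTINE_EACH(DECLARE_STUB)

int main() {
  // Expansion two: call every routine in the list.
#define CALL_STUB(name) name();
  ROUTINE_EACH(CALL_STUB)
}
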
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.h b/tensorflow/stream_executor/rocm/rocm_fft.h
new file mode 100644
index 00000000000..3dbe5800b74
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_fft.h
@@ -0,0 +1,132 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ROCM-specific support for FFT functionality -- this wraps the rocFFT
+// library capabilities, and is only included into ROCM implementation code
+// -- it will not introduce rocm headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
+
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/scratch_allocator.h"
+#include "rocm/include/rocfft/hipfft.h"
+
+namespace stream_executor {
+
+class Stream;
+
+namespace gpu {
+
+class GpuExecutor;
+
+// Opaque and unique identifier for the rocFFT plugin.
+extern const PluginId kRocFftPlugin;
+
+// ROCMFftPlan uses deferred initialization. Only a single call of
+// Initialize() is allowed to properly create hipfft plan and set member
+// variable is_initialized_ to true. Newly added interface that uses member
+// variables should first check is_initialized_ to make sure that the values of
+// member variables are valid.
+class ROCMFftPlan : public fft::Plan {
+ public:
+  ROCMFftPlan()
+      : parent_(nullptr),
+        plan_(),
+        fft_type_(fft::Type::kInvalid),
+        scratch_(nullptr),
+        is_initialized_(false) {}
+  ~ROCMFftPlan() override;
+
+  // Get FFT direction in hipFFT based on FFT type.
+  int GetFftDirection() const;
+  hipfftHandle GetPlan() const {
+    if (IsInitialized()) {
+      return plan_;
+    } else {
+      LOG(FATAL) << "Attempted to get hipfftHandle before initialization.";
+    }
+  }
+
+  // Initialize function for batched plan
+  port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
+                          uint64 *elem_count, uint64 *input_embed,
+                          uint64 input_stride, uint64 input_distance,
+                          uint64 *output_embed, uint64 output_stride,
+                          uint64 output_distance, fft::Type type,
+                          int batch_count, ScratchAllocator *scratch_allocator);
+
+  // Initialize function for 1d, 2d, and 3d plans
+  port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
+                          uint64 *elem_count, fft::Type type,
+                          ScratchAllocator *scratch_allocator);
+
+ protected:
+  bool IsInitialized() const { return is_initialized_; }
+
+ private:
+  GpuExecutor *parent_;
+  hipfftHandle plan_;
+  fft::Type fft_type_;
+  DeviceMemory<uint8> scratch_;
+  bool is_initialized_;
+};
+
+// FFT support for ROCM platform via rocFFT library.
+//
+// This satisfies the platform-agnostic FftSupport interface.
+//
+// Note that the hipFFT handle that this encapsulates is implicitly tied to the
+// context (and, as a result, the device) that the parent GpuExecutor is tied
+// to. This simply happens as an artifact of creating the hipFFT handle when a
+// ROCM context is active.
+//
+// Thread-safe. The ROCM context associated with all operations is the ROCM
+// context of parent_, so all context is explicit.
+class ROCMFft : public fft::FftSupport {
+ public:
+  explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
+  ~ROCMFft() override {}
+
+  TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
+
+ private:
+  GpuExecutor *parent_;
+
+  // Two helper functions that execute wrap::hipfftExec?2?.
+
+  // This is for complex to complex FFT, when the direction is required.
+  template <typename FuncT, typename InputT, typename OutputT>
+  bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
+                                  FuncT hipfft_exec,
+                                  const DeviceMemory<InputT> &input,
+                                  DeviceMemory<OutputT> *output);
+
+  // This is for complex to real or real to complex FFT, when the direction
+  // is implied.
+  template <typename FuncT, typename InputT, typename OutputT>
+  bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
+                     const DeviceMemory<InputT> &input,
+                     DeviceMemory<OutputT> *output);
+
+  SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
+};
+
+}  // namespace gpu
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
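Note for reviewers (not part of the patch): a minimal sketch of how the deferred-initialization contract above is meant to be consumed -- Initialize() exactly once, GetPlan() only afterwards. The helper name and caller scaffolding are assumed:

    std::unique_ptr<fft::Plan> MakeForward1dPlan(gpu::GpuExecutor* parent,
                                                 Stream* stream,
                                                 ScratchAllocator* allocator) {
      auto plan = absl::make_unique<gpu::ROCMFftPlan>();
      uint64 elem_count[1] = {1024};  // one 1d, 1024-point transform
      port::Status status = plan->Initialize(parent, stream, /*rank=*/1,
                                             elem_count, fft::Type::kC2CForward,
                                             allocator);
      if (!status.ok()) {
        // GetPlan() would LOG(FATAL) on an uninitialized plan, so bail here.
        LOG(ERROR) << "failed to initialize hipFFT plan: "
                   << status.error_message();
        return nullptr;
      }
      return plan;  // the ROCMFft methods may now call GetPlan() on it safely
    }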

From 234f47031bdd5346b6a3dc29670c215c94311cc8 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Thu, 31 Jan 2019 20:09:53 +0000
Subject: [PATCH 4/7] adding code for hiprand plugin

---
 tensorflow/stream_executor/rocm/BUILD       | 50 +++++++-----
 tensorflow/stream_executor/rocm/rocm_rng.cc | 89 +++++++++++----------
 2 files changed, 74 insertions(+), 65 deletions(-)

diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index c0da35121f2..85ab38d4242 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -244,26 +244,34 @@ cc_library(
 #    alwayslink = True,
 #)
 
-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocrand_plugin",
-#    srcs = ["rocm_rng.cc"],
-#    hdrs = [],
-#    deps = [
-#        ":rocm_gpu_executor",
-#        ":rocm_platform_id",
-#        "@local_config_rocm//rocm:rocm_headers",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:rng",
-#        "//tenosrflow/stream_executor/gpu:gpu_activation_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#    ] + if_static(["@local_config_rocm//rocm:curand"]),
-#    alwayslink = True,
-#)
+cc_library(
+   name = "rocrand_plugin",
+   srcs = if_rocm_is_configured(["rocm_rng.cc"]),
+   hdrs = if_rocm_is_configured([]),
+   deps = if_rocm_is_configured([
+       ":rocm_gpu_executor",
+       ":rocm_platform_id",
+       "@local_config_rocm//rocm:rocm_headers",
+       "//tensorflow/stream_executor:event",
+       "//tensorflow/stream_executor:plugin_registry",
+       "//tensorflow/stream_executor:rng",
+       "//tensorflow/stream_executor/gpu:gpu_activation_header",
+       "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+       "//tensorflow/stream_executor/gpu:gpu_executor_header",
+       "//tensorflow/stream_executor/gpu:gpu_rng_header",
+       "//tensorflow/stream_executor/gpu:gpu_stream_header",
+       "//tensorflow/stream_executor/lib",
+       "//tensorflow/stream_executor/platform",
+       "//tensorflow/stream_executor/platform:dso_loader",
+   ] + if_static([
+       "@local_config_rocm//rocm:hiprand"
+       # Delete the following line once we switch the hiprand library from
+       # being dynamically linked (current behaviour) to being dynamically
+       # loaded (future behaviour)
+       ], ["@local_config_rocm//rocm:hiprand"
+   ])),
+   alwayslink = True,
+)
 
 cc_library(
     name = "all_runtime",
@@ -274,7 +282,7 @@ cc_library(
         #":miopen_plugin",
         ":rocfft_plugin",
         ":rocblas_plugin",
-        #":rocrand_plugin",
+        ":rocrand_plugin",
         ":rocm_driver",
         ":rocm_platform",
     ]),
diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc
index 65acd03c92b..79250579087 100644
--- a/tensorflow/stream_executor/rocm/rocm_rng.cc
+++ b/tensorflow/stream_executor/rocm/rocm_rng.cc
@@ -14,21 +14,22 @@ limitations under the License.
 ==============================================================================*/
 
 #include "rocm/include/hiprand/hiprand.h"
-#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/gpu/gpu_rng.h"
+
 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
-#include "tensorflow/stream_executor/gpu/gpu_rng.h"
+#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
+#include "tensorflow/stream_executor/device_memory.h"
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/rng.h"
-#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
 
 // Formats hiprandStatus_t to output prettified values into a log stream.
-std::ostream& operator<<(std::ostream& in, const hiprandStatus_t& status) {
+std::ostream &operator<<(std::ostream &in, const hiprandStatus_t &status) {
 #define OSTREAM_HIPRAND_STATUS(__name) \
   case HIPRAND_STATUS_##__name:        \
     in << "HIPRAND_STATUS_" #__name;   \
@@ -60,7 +61,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
 
 namespace wrap {
 
-#define PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(__name)                     \
+#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                        \
   struct WrapperShim__##__name {                                    \
     template <typename... Args>                                     \
     hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
@@ -69,15 +70,15 @@ namespace wrap {
     }                                                               \
   } __name;
 
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandCreateGenerator);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandDestroyGenerator);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetStream);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniform);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniformDouble);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetGeneratorOffset);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormal);
-PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormalDouble);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
+STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);
 
 }  // namespace wrap
 
@@ -245,40 +246,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
 }
 
 }  // namespace gpu
-}  // namespace stream_executor
 
-namespace se = ::stream_executor;
+void initialize_rocrand() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
+          rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
+          [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
+            gpu::GpuExecutor* rocm_executor =
+                dynamic_cast<gpu::GpuExecutor*>(parent);
+            if (rocm_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the hipRAND "
+                  << "support library with a non-ROCM StreamExecutor";
+              return nullptr;
+            }
 
-REGISTER_MODULE_INITIALIZER(register_hiprand, {
-  se::port::Status status =
-      se::PluginRegistry::Instance()
-          ->RegisterFactory<se::PluginRegistry::RngFactory>(
-              se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
-              [](se::internal::StreamExecutorInterface* parent)
-                  -> se::rng::RngSupport* {
-                se::gpu::GpuExecutor* rocm_executor =
-                    dynamic_cast<se::gpu::GpuExecutor*>(parent);
-                if (rocm_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the hipRAND "
-                      << "support library with a non-ROCM StreamExecutor";
-                  return nullptr;
-                }
-
-                se::gpu::GpuRng* rng = new se::gpu::GpuRng(rocm_executor);
-                if (!rng->Init()) {
-                  // Note: Init() will log a more specific error.
-                  delete rng;
-                  return nullptr;
-                }
-                return rng;
-              });
+            gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
+            if (!rng->Init()) {
+              // Note: Init() will log a more specific error.
+              delete rng;
+              return nullptr;
+            }
+            return rng;
+          });
 
   if (!status.ok()) {
-    LOG(ERROR) << "Unable to register hipRAND factory: "
+    LOG(ERROR) << "Unable to register rocRAND factory: "
                << status.error_message();
   }
 
-  se::PluginRegistry::Instance()->SetDefaultFactory(
-      se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
-});
+  PluginRegistry::Instance()->SetDefaultFactory(
+      rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
+}
+
+}  // namespace stream_executor
+
+REGISTER_MODULE_INITIALIZER(register_rocrand,
+                            { stream_executor::initialize_rocrand(); });
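The factory registration now lives inside namespace stream_executor and is triggered through a module initializer. Conceptually, REGISTER_MODULE_INITIALIZER reduces to a static object whose constructor runs the body during static initialization (a hand-written sketch below, not the macro's real expansion from lib/initialize.h); this is also why the plugin rules above need alwayslink = True -- otherwise the linker could drop the object file and the factory would never be registered:

    namespace {
    // Sketch only: invokes the given function when this translation unit's
    // static initializers run, i.e. when the object file is linked in.
    struct ModuleInitializer {
      explicit ModuleInitializer(void (*body)()) { body(); }
    };
    static ModuleInitializer rocrand_module_init(
        &stream_executor::initialize_rocrand);
    }  // namespace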

From 07b3f341995ff0b012bf9297f500a8e75c682f31 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Fri, 1 Feb 2019 02:47:14 +0000
Subject: [PATCH 5/7] changing rocblas, rocfft and rocrand from being
 dynamically linked to being dynamically loaded

---
 .../platform/default/dso_loader.cc            | 52 ++++++++++++++++++-
 .../platform/default/dso_loader.h             | 13 +++++
 tensorflow/stream_executor/rocm/BUILD         | 12 -----
 tensorflow/stream_executor/rocm/rocm_blas.cc  | 37 +++++++++++++
 tensorflow/stream_executor/rocm/rocm_fft.cc   | 34 ++++++++++++
 tensorflow/stream_executor/rocm/rocm_rng.cc   | 35 +++++++++++++
 6 files changed, 169 insertions(+), 14 deletions(-)

diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc
index ad8112b831a..3f2bb5e4dd0 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.cc
+++ b/tensorflow/stream_executor/platform/default/dso_loader.cc
@@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
   port::Status status =
       port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
   if (status.ok()) {
-    LOG(INFO) << "Successfully opened CUDA library " << filename;
+    LOG(INFO) << "Successfully opened dynamic library " << filename;
     return dso_handle;
   }
 
@@ -54,6 +54,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
   return port::Status(port::error::FAILED_PRECONDITION, message);
 }
 }  // namespace
+
 
 namespace DsoLoader {
 port::StatusOr<void*> GetCudaDriverDsoHandle() {
@@ -99,6 +100,27 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
 port::StatusOr<void*> GetCudnnDsoHandle() {
   return GetDsoHandle("cudnn", GetCudnnVersion());
 }
+
+port::StatusOr<void*> GetRocblasDsoHandle() {
+  return GetDsoHandle("rocblas", "");
+}
+
+port::StatusOr<void*> GetMiopenDsoHandle() {
+  return GetDsoHandle("MIOpen", "");
+}
+
+port::StatusOr<void*> GetRocfftDsoHandle() {
+  return GetDsoHandle("rocfft", "");
+}
+
+port::StatusOr<void*> GetRocrandDsoHandle() {
+  return GetDsoHandle("rocrand", "");
+}
+
+port::StatusOr<void*> GetHipDsoHandle() {
+  return GetDsoHandle("hip_hcc", "");
+}
+
 }  // namespace DsoLoader
 
 namespace CachedDsoLoader {
@@ -131,11 +153,37 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
   static auto result = new auto(DsoLoader::GetCuptiDsoHandle());
   return *result;
 }
 
 port::StatusOr<void*> GetCudnnDsoHandle() {
   static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
   return *result;
 }
+
+port::StatusOr<void*> GetRocblasDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetMiopenDsoHandle() {
+  static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetRocfftDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetRocrandDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetHipDsoHandle() {
+  static auto result = new auto(DsoLoader::GetHipDsoHandle());
+  return *result;
+}
+
 }  // namespace CachedDsoLoader
 }  // namespace internal
 }  // namespace stream_executor
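A note on the CachedDsoLoader idiom above: `static auto result = new auto(DsoLoader::Get...DsoHandle())` resolves the library exactly once (C++11 static-local initialization is thread-safe), intentionally leaks the heap-allocated StatusOr so it survives until process exit, and every call returns a copy of the cached value. The dereference in `return *result;` matters: returning the raw pointer would wrap the pointer itself as the handle. A reduced sketch, with a hypothetical wrapper name:

    port::StatusOr<void*> GetExampleDsoHandle() {  // hypothetical name
      // Resolved once on first call; deliberately never freed.
      static auto result = new auto(DsoLoader::GetHipDsoHandle());
      return *result;  // copy of the cached StatusOr<void*>
    }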
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h
index 45a8315b436..8da8ea7be66 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.h
+++ b/tensorflow/stream_executor/platform/default/dso_loader.h
@@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
 port::StatusOr<void*> GetCurandDsoHandle();
 port::StatusOr<void*> GetCuptiDsoHandle();
 port::StatusOr<void*> GetCudnnDsoHandle();
+
+port::StatusOr<void*> GetRocblasDsoHandle();
+port::StatusOr<void*> GetMiopenDsoHandle();
+port::StatusOr<void*> GetRocfftDsoHandle();
+port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipDsoHandle();
 }  // namespace DsoLoader
 
 // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
@@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
 port::StatusOr<void*> GetCurandDsoHandle();
 port::StatusOr<void*> GetCuptiDsoHandle();
 port::StatusOr<void*> GetCudnnDsoHandle();
+
+port::StatusOr<void*> GetRocblasDsoHandle();
+port::StatusOr<void*> GetMiopenDsoHandle();
+port::StatusOr<void*> GetRocfftDsoHandle();
+port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipDsoHandle();
 }  // namespace CachedDsoLoader
+
 }  // namespace internal
 }  // namespace stream_executor
 
diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index 85ab38d4242..f0b05822703 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -167,10 +167,6 @@ cc_library(
        "@local_config_rocm//rocm:rocm_headers",
    ] + if_static([
        "@local_config_rocm//rocm:rocblas"
-       # Delete the following line once we switch the rocblas library from
-       # being dynamically linked (current behaviour) to being dynamically
-       # loaded (future behaviour)
-       ], ["@local_config_rocm//rocm:rocblas"
    ])),
    alwayslink = True,
 )
@@ -197,10 +193,6 @@ cc_library(
        "@local_config_rocm//rocm:rocm_headers",
    ] + if_static([
        "@local_config_rocm//rocm:rocfft"
-       # Delete the following line once we switch the rocfft library from
-       # being dynamically linked (current behaviour) to being dynamically
-       # loaded (future behaviour)
-       ], ["@local_config_rocm//rocm:rocfft"
    ])),
    alwayslink = True,
 )
@@ -265,10 +257,6 @@ cc_library(
        "//tensorflow/stream_executor/platform:dso_loader",
    ] + if_static([
        "@local_config_rocm//rocm:hiprand"
-       # Delete the following line once we switch the hiprand library from
-       # being dynamically linked (current behaviour) to being dynamically
-       # loaded (future behaviour)
-       ], ["@local_config_rocm//rocm:hiprand"
    ])),
    alwayslink = True,
 )
diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index b2e225433e5..a626d168c26 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/status_macros.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -49,6 +50,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin);
 
 namespace wrap {
 
+#ifdef PLATFORM_GOOGLE
 #define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                       \
   struct WrapperShim__##__name {                                   \
     static const char* kName;                                      \
@@ -63,6 +65,41 @@ namespace wrap {
 #define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
   STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
 
+#else
+
+#define STREAM_EXECUTOR_ROCBLAS_WRAP(__name)                              \
+  struct DynLoadShim__##__name {                                          \
+    static const char* kName;                                             \
+    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
+    static void* GetDsoHandle() {                                         \
+      auto s = internal::CachedDsoLoader::GetRocblasDsoHandle();          \
+      return s.ValueOrDie();                                              \
+    }                                                                     \
+    static FuncPtrT LoadOrDie() {                                         \
+      void* f;                                                            \
+      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
+                                                          kName, &f);     \
+      CHECK(s.ok()) << "could not find " << kName                         \
+                    << " in rocblas DSO; dlerror: " << s.error_message(); \
+      return reinterpret_cast<FuncPtrT>(f);                               \
+    }                                                                     \
+    static FuncPtrT DynLoad() {                                           \
+      static FuncPtrT f = LoadOrDie();                                    \
+      return f;                                                           \
+    }                                                                     \
+    template <typename... Args>                                           \
+    rocblas_status operator()(GpuExecutor* parent, Args... args) {         \
+      gpu::ScopedActivateExecutorContext sac{parent};                      \
+      return DynLoad()(args...);                                           \
+    }                                                                     \
+  } __name;                                                               \
+  const char* DynLoadShim__##__name::kName = #__name;
+
+#define STREAM_EXECUTOR_ROCBLAS_V2_WRAP(__name) \
+  STREAM_EXECUTOR_ROCBLAS_WRAP(__name)
+
+#endif
+
 #define ROCBLAS_BLAS_ROUTINE_EACH(__macro)                                     \
   __macro(rocblas_snrm2) __macro(rocblas_dnrm2) /*  __macro(rocblas_scnrm2)    \
                                                   __macro(rocblas_dznrm2) */   \
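The DynLoadShim branch added above is, in effect, a thread-safe lazy dlsym: the first invocation resolves the symbol from the rocblas DSO and caches the function pointer in a static local; later calls reuse it. A generic sketch of the resolution step, detached from the macro (the helper name is assumed):

    template <typename FuncPtrT>
    FuncPtrT LoadSymbolOrDie(void* dso_handle, const char* name) {
      void* sym = nullptr;
      port::Status status =
          port::Env::Default()->GetSymbolFromLibrary(dso_handle, name, &sym);
      CHECK(status.ok()) << "could not find " << name
                         << "; dlerror: " << status.error_message();
      return reinterpret_cast<FuncPtrT>(sym);
    }

The macro stamps out one shim struct per entry point, so each symbol gets its own `static FuncPtrT f`, and C++11 guarantees that static is initialized exactly once even under concurrent first calls.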
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
index dd30911eadd..e8a72f61d3f 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.cc
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
 #include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/plugin_registry.h"
@@ -38,6 +39,7 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
 
 namespace wrap {
 
+#ifdef PLATFORM_GOOGLE
 // This macro wraps a global identifier, given by __name, in a callable
 // structure that loads the DLL symbol out of the DSO handle in a thread-safe
 // manner on first use. This dynamic loading technique is used to avoid DSO
@@ -52,6 +54,38 @@ namespace wrap {
     }                                                            \
   } __name;
 
+#else
+
+#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                               \
+  struct DynLoadShim__##__name {                                          \
+    static const char *kName;                                             \
+    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
+    static void *GetDsoHandle() {                                         \
+      auto s = internal::CachedDsoLoader::GetRocfftDsoHandle();           \
+      return s.ValueOrDie();                                              \
+    }                                                                     \
+    static FuncPtrT LoadOrDie() {                                         \
+      void *f;                                                            \
+      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
+                                                          kName, &f);     \
+      CHECK(s.ok()) << "could not find " << kName                         \
+                    << " in rocfft DSO; dlerror: " << s.error_message();  \
+      return reinterpret_cast<FuncPtrT>(f);                               \
+    }                                                                     \
+    static FuncPtrT DynLoad() {                                           \
+      static FuncPtrT f = LoadOrDie();                                    \
+      return f;                                                           \
+    }                                                                     \
+    template <typename... Args>                                           \
+    hipfftResult operator()(GpuExecutor *parent, Args... args) {           \
+      gpu::ScopedActivateExecutorContext sac{parent};                      \
+      return DynLoad()(args...);                                           \
+    }                                                                     \
+  } __name;                                                               \
+  const char *DynLoadShim__##__name::kName = #__name;
+
+#endif
+
 #define ROCFFT_ROUTINE_EACH(__macro) \
   __macro(hipfftDestroy)             \
   __macro(hipfftSetStream)           \
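ROCFFT_ROUTINE_EACH is an X-macro table: applying it to STREAM_EXECUTOR_ROCFFT_WRAP stamps out one shim per hipFFT entry point from a single list. A toy illustration of the pattern, with hypothetical names:

    #define MY_ROUTINE_EACH(__macro) \
      __macro(open)                  \
      __macro(close)

    #define DECLARE_SHIM(__name) void shim_##__name();
    MY_ROUTINE_EACH(DECLARE_SHIM)  // declares shim_open(); shim_close();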
diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc
index 79250579087..27797845700 100644
--- a/tensorflow/stream_executor/rocm/rocm_rng.cc
+++ b/tensorflow/stream_executor/rocm/rocm_rng.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/lib/initialize.h"
 #include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/rng.h"
 
@@ -61,6 +62,8 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
 
 namespace wrap {
 
+#ifdef PLATFORM_GOOGLE
+
 #define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                        \
   struct WrapperShim__##__name {                                    \
     template <typename... Args>                                     \
@@ -70,6 +73,38 @@ namespace wrap {
     }                                                               \
   } __name;
 
+#else
+
+#define STREAM_EXECUTOR_HIPRAND_WRAP(__name)                              \
+  struct DynLoadShim__##__name {                                          \
+    static const char *kName;                                             \
+    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
+    static void *GetDsoHandle() {                                         \
+      auto s = internal::CachedDsoLoader::GetRocrandDsoHandle();          \
+      return s.ValueOrDie();                                              \
+    }                                                                     \
+    static FuncPtrT LoadOrDie() {                                         \
+      void *f;                                                            \
+      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
+                                                          kName, &f);     \
+      CHECK(s.ok()) << "could not find " << kName                         \
+                    << " in rocrand DSO; dlerror: " << s.error_message(); \
+      return reinterpret_cast<FuncPtrT>(f);                               \
+    }                                                                     \
+    static FuncPtrT DynLoad() {                                           \
+      static FuncPtrT f = LoadOrDie();                                    \
+      return f;                                                           \
+    }                                                                     \
+    template <typename... Args>                                           \
+    hiprandStatus_t operator()(GpuExecutor *parent, Args... args) {        \
+      gpu::ScopedActivateExecutorContext sac{parent};                      \
+      return DynLoad()(args...);                                           \
+    }                                                                     \
+  } __name;                                                               \
+  const char *DynLoadShim__##__name::kName = #__name;
+
+#endif
+
 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
 STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);

From 834a3f7395a9db748349e0bf9dfff7af558cb4fb Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Fri, 1 Feb 2019 17:59:02 +0000
Subject: [PATCH 6/7] changing the rocm_driver API from being dynamically
 linked to being dynamically loaded

---
 tensorflow/stream_executor/rocm/BUILD         | 149 +++++++++---------
 .../stream_executor/rocm/rocm_driver.cc       | 125 +++++++--------
 .../rocm/rocm_driver_wrapper.h                | 147 +++++++++++++++++
 3 files changed, 285 insertions(+), 136 deletions(-)
 create mode 100644 tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
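The tensorflow::wrap::hip* names used throughout the hunks below come from the new rocm_driver_wrapper.h, whose contents appear later in this patch. A plausible reduced sketch of the shape of such a wrapper (assumed here, not the file's actual contents):

    // Hypothetical reduction of the wrapper pattern; the real header wraps
    // every HIP driver entry point used by rocm_driver.cc.
    namespace tensorflow {
    namespace wrap {
    template <typename... Args>
    hipError_t hipInit(Args... args) {
      // In the dynamically loaded build, this body would instead resolve
      // hipInit out of the HIP DSO (via CachedDsoLoader::GetHipDsoHandle)
      // before forwarding the call.
      return ::hipInit(args...);
    }
    }  // namespace wrap
    }  // namespace tensorflow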

diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index f0b05822703..5190b551f80 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -47,7 +47,7 @@ cc_library(
 cc_library(
     name = "rocm_driver",
     srcs = if_rocm_is_configured(["rocm_driver.cc"]),
-    hdrs = [],
+    hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]),
     deps = if_rocm_is_configured([
         ":rocm_diagnostics",
         "@com_google_absl//absl/base",
@@ -57,6 +57,7 @@ cc_library(
         "//tensorflow/stream_executor/gpu:gpu_driver_header",
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
         "@local_config_rocm//rocm:rocm_headers",
     ]),
 )
@@ -141,60 +142,60 @@ cc_library(
 )
 
 cc_library(
-   name = "rocblas_plugin",
-   srcs = if_rocm_is_configured(["rocm_blas.cc"]),
-   hdrs = if_rocm_is_configured(["rocm_blas.h"]),
-   visibility = ["//visibility:public"],
-   deps = if_rocm_is_configured([
-       ":rocm_gpu_executor",
-       ":rocm_platform_id",
-       "//third_party/eigen3",
-       "//tensorflow/core:lib_internal",
-       "//tensorflow/stream_executor",
-       "//tensorflow/stream_executor:event",
-       "//tensorflow/stream_executor:host_or_device_scalar",
-       "//tensorflow/stream_executor:plugin_registry",
-       "//tensorflow/stream_executor:scratch_allocator",
-       "//tensorflow/stream_executor:timer",
-       "//tensorflow/stream_executor/gpu:gpu_activation",
-       "//tensorflow/stream_executor/gpu:gpu_helpers_header",
-       "//tensorflow/stream_executor/gpu:gpu_stream_header",
-       "//tensorflow/stream_executor/gpu:gpu_timer_header",
-       "//tensorflow/stream_executor/lib",
-       "//tensorflow/stream_executor/platform",
-       "//tensorflow/stream_executor/platform:dso_loader",
-       "@com_google_absl//absl/strings",
-       "@local_config_rocm//rocm:rocm_headers",
-   ] + if_static([
-       "@local_config_rocm//rocm:rocblas"
-   ])),
-   alwayslink = True,
+    name = "rocblas_plugin",
+    srcs = if_rocm_is_configured(["rocm_blas.cc"]),
+    hdrs = if_rocm_is_configured(["rocm_blas.h"]),
+    visibility = ["//visibility:public"],
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "//third_party/eigen3",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:host_or_device_scalar",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:scratch_allocator",
+        "//tensorflow/stream_executor:timer",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/gpu:gpu_timer_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "@com_google_absl//absl/strings",
+        "@local_config_rocm//rocm:rocm_headers",
+    ] + if_static([
+        "@local_config_rocm//rocm:rocblas",
+    ])),
+    alwayslink = True,
 )
 
 cc_library(
-   name = "rocfft_plugin",
-   srcs = if_rocm_is_configured(["rocm_fft.cc"]),
-   hdrs = if_rocm_is_configured(["rocm_fft.h"]),
-   visibility = ["//visibility:public"],
-   deps = if_rocm_is_configured([
-       ":rocm_platform_id",
-       "//tensorflow/stream_executor:event",
-       "//tensorflow/stream_executor:fft",
-       "//tensorflow/stream_executor:plugin_registry",
-       "//tensorflow/stream_executor:scratch_allocator",
-       "//tensorflow/stream_executor/gpu:gpu_activation",
-       "//tensorflow/stream_executor/gpu:gpu_helpers_header",
-       "//tensorflow/stream_executor/gpu:gpu_executor_header",
-       "//tensorflow/stream_executor/gpu:gpu_stream_header",
-       "//tensorflow/stream_executor/gpu:gpu_kernel_header",
-       "//tensorflow/stream_executor/lib",
-       "//tensorflow/stream_executor/platform",
-       "//tensorflow/stream_executor/platform:dso_loader",
-       "@local_config_rocm//rocm:rocm_headers",
-   ] + if_static([
-       "@local_config_rocm//rocm:rocfft"
-   ])),
-   alwayslink = True,
+    name = "rocfft_plugin",
+    srcs = if_rocm_is_configured(["rocm_fft.cc"]),
+    hdrs = if_rocm_is_configured(["rocm_fft.h"]),
+    visibility = ["//visibility:public"],
+    deps = if_rocm_is_configured([
+        ":rocm_platform_id",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:fft",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:scratch_allocator",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_executor_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/gpu:gpu_kernel_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "@local_config_rocm//rocm:rocm_headers",
+    ] + if_static([
+        "@local_config_rocm//rocm:rocfft",
+    ])),
+    alwayslink = True,
 )
 
 # FIXME: enable in future PRs
@@ -237,28 +238,28 @@ cc_library(
 #)
 
 cc_library(
-   name = "rocrand_plugin",
-   srcs = if_rocm_is_configured(["rocm_rng.cc"]),
-   hdrs = if_rocm_is_configured([]),
-   deps = if_rocm_is_configured([
-       ":rocm_gpu_executor",
-       ":rocm_platform_id",
-       "@local_config_rocm//rocm:rocm_headers",
-       "//tensorflow/stream_executor:event",
-       "//tensorflow/stream_executor:plugin_registry",
-       "//tensorflow/stream_executor:rng",
-       "//tensorflow/stream_executor/gpu:gpu_activation_header",
-       "//tensorflow/stream_executor/gpu:gpu_helpers_header",
-       "//tensorflow/stream_executor/gpu:gpu_executor_header",
-       "//tensorflow/stream_executor/gpu:gpu_rng_header",
-       "//tensorflow/stream_executor/gpu:gpu_stream_header",
-       "//tensorflow/stream_executor/lib",
-       "//tensorflow/stream_executor/platform",
-       "//tensorflow/stream_executor/platform:dso_loader",
-   ] + if_static([
-       "@local_config_rocm//rocm:hiprand"
-   ])),
-   alwayslink = True,
+    name = "rocrand_plugin",
+    srcs = if_rocm_is_configured(["rocm_rng.cc"]),
+    hdrs = if_rocm_is_configured([]),
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "@local_config_rocm//rocm:rocm_headers",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:rng",
+        "//tensorflow/stream_executor/gpu:gpu_activation_header",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_executor_header",
+        "//tensorflow/stream_executor/gpu:gpu_rng_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+    ] + if_static([
+        "@local_config_rocm//rocm:hiprand",
+    ])),
+    alwayslink = True,
 )
 
 cc_library(
diff --git a/tensorflow/stream_executor/rocm/rocm_driver.cc b/tensorflow/stream_executor/rocm/rocm_driver.cc
index 39d52d28304..73b1b350f71 100644
--- a/tensorflow/stream_executor/rocm/rocm_driver.cc
+++ b/tensorflow/stream_executor/rocm/rocm_driver.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"
 
 bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
 bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
@@ -143,7 +144,7 @@ string MemorySpaceString(MemorySpace memory_space) {
 // HIP driver (e.g., this value is not our cached view of the current device).
 static int CurrentDeviceOrDie() {
   int current = -1;
-  hipError_t result = hipGetDevice(&current);
+  hipError_t result = tensorflow::wrap::hipGetDevice(&current);
   if (result != hipSuccess) {
     LOG(FATAL) << "failed to query current device: " << ToString(result);
   }
@@ -154,7 +155,7 @@ namespace {
 
 // Call hipDeviceSynchronize and crash if it doesn't succeed.
 void SynchronizeOrDie() {
-  auto res = hipDeviceSynchronize();
+  auto res = tensorflow::wrap::hipDeviceSynchronize();
   if (res != hipSuccess) {
     LOG(FATAL) << "Synchronize found " << ToString(res)
                << " :: " << port::CurrentStackTrace();
@@ -197,7 +198,7 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* context) {
           << tls->current_device_ordinal << " to " << context->device_ordinal();
 
   // Set the device and update thread local.
-  CHECK_EQ(hipSuccess, hipSetDevice(context->device_ordinal()));
+  CHECK_EQ(hipSuccess, tensorflow::wrap::hipSetDevice(context->device_ordinal()));
   tls->current_device_ordinal = context->device_ordinal();
 }
 
@@ -225,7 +226,7 @@ ScopedActivateContext::~ScopedActivateContext() {
           << to_restore_->device_ordinal();
 
   // Set context and update thread local.
-  CHECK_EQ(hipSuccess, hipSetDevice(to_restore_->device_ordinal()));
+  CHECK_EQ(hipSuccess, tensorflow::wrap::hipSetDevice(to_restore_->device_ordinal()));
   tls->current_device_ordinal = to_restore_->device_ordinal();
 }
 
@@ -261,7 +262,7 @@ string ROCMPointerToMemorySpaceString(hipDeviceptr_t pointer) {
 // in the process of querying.
 string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
   hipPointerAttribute_t from_pointerAttributes;
-  hipError_t result = hipPointerGetAttributes(&from_pointerAttributes, from);
+  hipError_t result = tensorflow::wrap::hipPointerGetAttributes(&from_pointerAttributes, from);
   if (result != hipSuccess) {
     LOG(ERROR) << "could not retrieve source pointer's device: "
                << ToString(result);
@@ -269,7 +270,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
   }
 
   hipPointerAttribute_t to_pointerAttributes;
-  result = hipPointerGetAttributes(&to_pointerAttributes, to);
+  result = tensorflow::wrap::hipPointerGetAttributes(&to_pointerAttributes, to);
   if (result != hipSuccess) {
     LOG(ERROR) << "could not retrieve destination pointer's device: "
                << ToString(result);
@@ -289,7 +290,7 @@ static port::Status InternalInit() {
   if (FLAGS_gpuexec_rocm_driver_inject_init_error) {
     LOG(ERROR) << "injecting ROCM init error; initialization will fail";
   } else {
-    res = hipInit(0 /* = flags */);
+    res = tensorflow::wrap::hipInit(0 /* = flags */);
   }
 
   if (res == hipSuccess) {
@@ -322,7 +323,7 @@ static port::Status InternalInit() {
 
 /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
                                                hipDevice_t* device) {
-  hipError_t res = hipDeviceGet(device, device_ordinal);
+  hipError_t res = tensorflow::wrap::hipDeviceGet(device, device_ordinal);
   if (res == hipSuccess) {
     return port::Status::OK();
   }
@@ -336,7 +337,7 @@ static port::Status InternalInit() {
                                            string* device_name) {
   static const size_t kCharLimit = 64;
   absl::InlinedVector<char, 4> chars(kCharLimit);
-  hipError_t res = hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
+  hipError_t res = tensorflow::wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get device name for " << device << ": "
                << ToString(res);
@@ -382,7 +383,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
 
 /* static */ bool GpuDriver::FuncSetCacheConfig(hipFunction_t function,
                                                 hipFuncCache_t cache_config) {
-  hipError_t res = hipFuncSetCacheConfig(function, cache_config);
+  hipError_t res = tensorflow::wrap::hipFuncSetCacheConfig(function, cache_config);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function
                << ", config: " << cache_config << ", result: " << ToString(res);
@@ -396,7 +397,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
 GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   hipSharedMemConfig shared_mem_config;
   ScopedActivateContext activation{context};
-  hipError_t result = hipDeviceGetSharedMemConfig(&shared_mem_config);
+  hipError_t result = tensorflow::wrap::hipDeviceGetSharedMemConfig(&shared_mem_config);
   if (result != hipSuccess) {
     LOG(ERROR) << "failed to get ROCM device shared memory config. "
                << "Context device ID: " << context->device_ordinal()
@@ -411,7 +412,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::ContextSetSharedMemConfig(
     GpuContext* context, hipSharedMemConfig shared_mem_config) {
   ScopedActivateContext activation{context};
-  hipError_t result = hipDeviceSetSharedMemConfig(shared_mem_config);
+  hipError_t result = tensorflow::wrap::hipDeviceSetSharedMemConfig(shared_mem_config);
   if (result != hipSuccess) {
     LOG(ERROR) << "failed to set ROCM device shared memory config. "
                << "Context device ID: " << context->device_ordinal()
@@ -435,7 +436,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
           << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
           << " bdx: " << block_dim_x << " bdy: " << block_dim_y
           << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
-  hipError_t res = hipModuleLaunchKernel(
+  hipError_t res = tensorflow::wrap::hipModuleLaunchKernel(
       function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
       block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
   if (res != hipSuccess) {
@@ -471,7 +472,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
         ScopedActivateContext activation{context};
         void* hsaco_data = const_cast<char*>(hsaco_contents);
 
-        hipError_t res = hipModuleLoadData(module, hsaco_data);
+        hipError_t res = tensorflow::wrap::hipModuleLoadData(module, hsaco_data);
 
         if (res != hipSuccess) {
           LOG(ERROR) << "failed to load HSACO: " << ToString(res);
@@ -491,7 +492,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                     hipDeviceptr_t location,
                                                     uint8 value, size_t size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemset(location, value, size);
+  hipError_t res = tensorflow::wrap::hipMemset(location, value, size);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to memset memory: " << ToString(res);
     return false;
@@ -513,7 +514,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     return false;
   }
   hipError_t res =
-      hipMemset(pointer, static_cast<int>(value), uint32_count * 4);
+      tensorflow::wrap::hipMemset(pointer, static_cast<int>(value), uint32_count * 4);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to memset memory: " << ToString(res);
     return false;
@@ -527,7 +528,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                      size_t uint32_count,
                                                      GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemsetAsync(location, value, uint32_count, stream);
+  hipError_t res = tensorflow::wrap::hipMemsetAsync(location, value, uint32_count, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
     return false;
@@ -552,7 +553,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     LOG(ERROR) << "failed to memset memory";
     return false;
   }
-  hipError_t res = hipMemsetAsync(pointer, value, uint32_count * 4, stream);
+  hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value, uint32_count * 4, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
     return false;
@@ -565,7 +566,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                GpuStreamHandle stream,
                                                StreamCallback callback,
                                                void* data) {
-  hipError_t res = hipStreamAddCallback(stream, (hipStreamCallback_t)callback,
+  hipError_t res = tensorflow::wrap::hipStreamAddCallback(stream, (hipStreamCallback_t)callback,
                                         data, 0 /* = flags */);
   if (res != hipSuccess) {
     LOG(ERROR) << "unable to add host callback: " << ToString(res);
@@ -580,7 +581,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                hipFunction_t* function) {
   ScopedActivateContext activated{context};
   CHECK(module != nullptr && kernel_name != nullptr);
-  hipError_t res = hipModuleGetFunction(function, module, kernel_name);
+  hipError_t res = tensorflow::wrap::hipModuleGetFunction(function, module, kernel_name);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get kernel \"" << kernel_name
                << "\" from module: " << ToString(res);
@@ -598,7 +599,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activated{context};
   CHECK(module != nullptr && symbol_name != nullptr &&
         (dptr != nullptr || bytes != nullptr));
-  hipError_t res = hipModuleGetGlobal(dptr, bytes, module, symbol_name);
+  hipError_t res = tensorflow::wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name);
   if (res != hipSuccess) {
     // symbol may not be found in the current module, but it may reside in
     // another module.
@@ -613,7 +614,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ void GpuDriver::UnloadModule(GpuContext* context,
                                           hipModule_t module) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipModuleUnload(module);
+  hipError_t res = tensorflow::wrap::hipModuleUnload(module);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to unload module " << module
                << "; leaking: " << ToString(res);
@@ -623,7 +624,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ bool GpuDriver::CreateStream(GpuContext* context,
                                           GpuStreamHandle* stream) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipStreamCreateWithFlags(
+  hipError_t res = tensorflow::wrap::hipStreamCreateWithFlags(
       stream, hipStreamDefault);  // switch to hipStreamNonBlocking?
   if (res != hipSuccess) {
     LOG(ERROR) << "could not allocate ROCM stream for device "
@@ -643,7 +644,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }
 
   ScopedActivateContext activated{context};
-  hipError_t res = hipStreamDestroy(*stream);
+  hipError_t res = tensorflow::wrap::hipStreamDestroy(*stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to destroy ROCM stream for device "
                << context->device_ordinal() << ": " << ToString(res);
@@ -658,7 +659,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                              uint64 bytes) {
   ScopedActivateContext activated{context};
   hipDeviceptr_t result = 0;
-  hipError_t res = hipMalloc(&result, bytes);
+  hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to allocate "
                << port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
@@ -675,7 +676,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                               void* location) {
   ScopedActivateContext activation{context};
   hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
-  hipError_t res = hipFree(pointer);
+  hipError_t res = tensorflow::wrap::hipFree(pointer);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to free device memory at " << location
                << "; result: " << ToString(res);
@@ -704,7 +705,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activation{context};
   void* host_mem = nullptr;
   // "Portable" memory is visible to all ROCM contexts. Safe for our use model.
-  hipError_t res = hipHostMalloc(&host_mem, bytes, hipHostMallocPortable);
+  hipError_t res = tensorflow::wrap::hipHostMallocVanilla(&host_mem, bytes, hipHostMallocPortable);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to alloc " << bytes
                << " bytes on host: " << ToString(res);
@@ -715,7 +716,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ void GpuDriver::HostDeallocate(GpuContext* context,
                                             void* location) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipHostFree(location);
+  hipError_t res = tensorflow::wrap::hipHostFree(location);
   if (res != hipSuccess) {
     LOG(ERROR) << "error deallocating host memory at " << location << ": "
                << ToString(res);
@@ -726,7 +727,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                           uint64 bytes) {
   ScopedActivateContext activation{context};
   // "Portable" memory is visible to all ROCM contexts. Safe for our use model.
-  hipError_t res = hipHostRegister(location, bytes, hipHostRegisterPortable);
+  hipError_t res = tensorflow::wrap::hipHostRegister(location, bytes, hipHostRegisterPortable);
   if (res != hipSuccess) {
     LOG(ERROR) << "error registering host memory at " << location << ": "
                << ToString(res);
@@ -738,7 +739,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ bool GpuDriver::HostUnregister(GpuContext* context,
                                             void* location) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipHostUnregister(location);
+  hipError_t res = tensorflow::wrap::hipHostUnregister(location);
   if (res != hipSuccess) {
     LOG(ERROR) << "error unregistering host memory at " << location << ": "
                << ToString(res);
@@ -755,7 +756,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }
 
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventDestroy(*event);
+  hipError_t res = tensorflow::wrap::hipEventDestroy(*event);
   *event = nullptr;
 
   switch (res) {
@@ -779,7 +780,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                  GpuEventHandle event,
                                                  GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventRecord(event, stream);
+  hipError_t res = tensorflow::wrap::hipEventRecord(event, stream);
   switch (res) {
     case hipSuccess:
       return port::Status::OK();
@@ -800,7 +801,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent(
     GpuContext* context, GpuEventHandle event) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventQuery(event);
+  hipError_t res = tensorflow::wrap::hipEventQuery(event);
   if (res != hipSuccess && res != hipErrorNotReady) {
     return port::Status{
         port::error::INTERNAL,
@@ -817,12 +818,12 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activated{context};
   // The stop event must have completed in order for hipEventElapsedTime to
   // work.
-  hipError_t res = hipEventSynchronize(stop);
+  hipError_t res = tensorflow::wrap::hipEventSynchronize(stop);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
     return false;
   }
-  res = hipEventElapsedTime(elapsed_milliseconds, start, stop);
+  res = tensorflow::wrap::hipEventElapsedTime(elapsed_milliseconds, start, stop);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get elapsed time between events: "
                << ToString(res);
@@ -836,7 +837,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                GpuStreamHandle stream,
                                                GpuEventHandle event) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipStreamWaitEvent(stream, event, 0 /* = flags */);
+  hipError_t res = tensorflow::wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */);
   if (res != hipSuccess) {
     LOG(ERROR) << "could not wait stream on event: " << ToString(res);
     return false;
@@ -847,7 +848,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 
 /* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipDeviceSynchronize();
+  hipError_t res = tensorflow::wrap::hipDeviceSynchronize();
   if (res != hipSuccess) {
     LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
                << " :: " << port::CurrentStackTrace();
@@ -861,7 +862,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                        GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
   CHECK(stream != nullptr);
-  hipError_t res = hipStreamSynchronize(stream);
+  hipError_t res = tensorflow::wrap::hipStreamSynchronize(stream);
   if (res != hipSuccess) {
     port::Status status = port::InternalError(
         absl::StrCat("could not synchronize on ROCM stream: ", ToString(res)));
@@ -877,7 +878,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                           GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
   CHECK(stream != nullptr);
-  hipError_t res = hipStreamQuery(stream);
+  hipError_t res = tensorflow::wrap::hipStreamQuery(stream);
   if (res == hipSuccess) {
     return true;
   }
@@ -891,7 +892,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::SynchronousMemcpyD2H(
     GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoH(host_dst, gpu_src, size);
+  hipError_t res = tensorflow::wrap::hipMemcpyDtoH(host_dst, gpu_src, size);
   if (res != hipSuccess) {
     return port::InternalError(
         absl::StrFormat("failed to synchronous memcpy from device to host: %s; "
@@ -908,7 +909,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src,
     uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size);
+  hipError_t res = tensorflow::wrap::hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size);
   if (res != hipSuccess) {
     return port::InternalError(absl::StrFormat(
         "failed to synchronous memcpy from host to device: %s; Gpu dst: %p;"
@@ -924,7 +925,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src,
     uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoD(gpu_dst, gpu_src, size);
+  hipError_t res = tensorflow::wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size);
   if (res != hipSuccess) {
     return port::InternalError(absl::StrFormat(
         "failed to synchronous memcpy from host to device: %s; Gpu dst: %p; "
@@ -942,7 +943,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                    uint64 size,
                                                    GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
+  hipError_t res = tensorflow::wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from device to host: %s; host dst: %p; "
@@ -964,7 +965,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                    GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
   hipError_t res =
-      hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream);
+      tensorflow::wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
@@ -984,7 +985,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                    uint64 size,
                                                    GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t result = hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
+  hipError_t result = tensorflow::wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
   if (result != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from device to device: %s"
@@ -1021,7 +1022,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }
 
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventCreateWithFlags(event, hipflags);
+  hipError_t res = tensorflow::wrap::hipEventCreateWithFlags(event, hipflags);
 
   if (res == hipSuccess) {
     return port::Status::OK();
@@ -1037,7 +1038,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 
 /* static */ int GpuDriver::GetDeviceCount() {
   int device_count = 0;
-  hipError_t res = hipGetDeviceCount(&device_count);
+  hipError_t res = tensorflow::wrap::hipGetDeviceCount(&device_count);
   if (res != hipSuccess) {
     LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
     return 0;
@@ -1061,7 +1062,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 
 /* static */ port::Status GpuDriver::GetPointerAddressRange(
     hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) {
-  hipError_t result = hipMemGetAddressRange(base, size, dptr);
+  hipError_t result = tensorflow::wrap::hipMemGetAddressRange(base, size, dptr);
   if (result == hipSuccess) {
     return port::Status::OK();
   } else if (result == hipErrorNotFound) {
@@ -1106,7 +1107,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice(
     hipDeviceptr_t pointer) {
   hipPointerAttribute_t pointerAttributes;
-  hipError_t result = hipPointerGetAttributes(&pointerAttributes, pointer);
+  hipError_t result = tensorflow::wrap::hipPointerGetAttributes(&pointerAttributes, pointer);
   if (result != hipSuccess) {
     return port::Status{
         port::error::INTERNAL,
@@ -1114,7 +1115,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }
 
   hipDevice_t device;
-  result = hipDeviceGet(&device, pointerAttributes.device);
+  result = tensorflow::wrap::hipDeviceGet(&device, pointerAttributes.device);
   if (result != hipSuccess) {
     return port::Status{
         port::error::INTERNAL,
@@ -1127,7 +1128,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::GetGpuISAVersion(int* version,
                                                       hipDevice_t device) {
   hipDeviceProp_t props;
-  hipError_t result = hipGetDeviceProperties(&props, device);
+  hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device);
   if (result == hipSuccess) {
     *version = props.gcnArch;
     return port::Status::OK();
@@ -1145,7 +1146,7 @@ template <typename T>
 static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
                                             hipDeviceAttribute_t attribute) {
   int value = -1;
-  hipError_t result = hipDeviceGetAttribute(&value, attribute, device);
+  hipError_t result = tensorflow::wrap::hipDeviceGetAttribute(&value, attribute, device);
   if (result != hipSuccess) {
     return port::Status{
         port::error::NOT_FOUND,
@@ -1200,21 +1201,21 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
                                            hipDevice_t device) {
   int value;
   hipError_t res =
-      hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device);
+      tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
     return false;
   }
   *x = value;
 
-  res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device);
+  res = tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
     return false;
   }
   *y = value;
 
-  res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device);
+  res = tensorflow::wrap::hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
     return false;
@@ -1224,7 +1225,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 }
 
 /* static */ bool GpuDriver::GetDriverVersion(int* driver_version) {
-  hipError_t res = hipDriverGetVersion(driver_version);
+  hipError_t res = tensorflow::wrap::hipDriverGetVersion(driver_version);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query driver version: " << ToString(res);
     return false;
@@ -1235,7 +1236,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 
 /* static */ bool GpuDriver::GetDeviceProperties(
     hipDeviceProp_t* device_properties, int device_ordinal) {
-  hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal);
+  hipError_t res = tensorflow::wrap::hipGetDeviceProperties(device_properties, device_ordinal);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query device properties: " << ToString(res);
     return false;
@@ -1268,7 +1269,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   ScopedActivateContext activation{context};
   size_t free = 0;
   size_t total = 0;
-  hipError_t res = hipMemGetInfo(&free, &total);
+  hipError_t res = tensorflow::wrap::hipMemGetInfo(&free, &total);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query device memory info: " << ToString(res);
     return false;
@@ -1282,7 +1283,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 /* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device,
                                                   uint64* result) {
   size_t value = -1;
-  hipError_t res = hipDeviceTotalMem(&value, device);
+  hipError_t res = tensorflow::wrap::hipDeviceTotalMem(&value, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query total available memory: " << ToString(res);
     return false;
@@ -1297,7 +1298,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   static const int kBufferSize = 64;
   absl::InlinedVector<char, 4> chars(kBufferSize);
   chars[kBufferSize - 1] = '\0';
-  hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
+  hipError_t res = tensorflow::wrap::hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
     return pci_bus_id;
@@ -1313,7 +1314,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   }
 
   int can_access_peer = -1;
-  hipError_t res = hipDeviceCanAccessPeer(
+  hipError_t res = tensorflow::wrap::hipDeviceCanAccessPeer(
       &can_access_peer, from->device_ordinal(), to->device_ordinal());
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
@@ -1331,7 +1332,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 
   ScopedActivateContext activated{from};
   hipError_t result =
-      hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
+      tensorflow::wrap::hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
   if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
     return port::Status{
         port::error::INTERNAL,
diff --git a/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
new file mode 100644
index 00000000000..0a0ab3ae745
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
@@ -0,0 +1,149 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file wraps ROCm driver calls with the dso loader so that we don't
+// need to link explicitly against librocm. All TF ROCm driver usage should
+// route through this wrapper.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
+
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "rocm/include/hip/hip_runtime.h"
+
+
+#if defined(TENSORFLOW_USE_ROCM)
+
+#endif
+
+namespace tensorflow {
+namespace wrap {
+#ifdef PLATFORM_GOOGLE
+// Use static linked library
+#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName)                           \
+  template <typename... Args>                                             \
+  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) {  \
+    return ::hipSymbolName(args...);                                      \
+  }
+
+#else
+#define TO_STR_(x) #x
+#define TO_STR(x) TO_STR_(x)
+
+// hipMalloc and hipHostMalloc are defined as function templates in the
+// HIP header files, and hence their names get mangled, so the attempt
+// to resolve their names when trying to dynamically load them will fail.
+// Updating the HIP header files to make them C functions is underway.
+// Until that change flows through, we will work around the issue by
+// creating dummy wrappers for them here.
+
+inline hipError_t hipMallocVanilla(void** ptr, size_t size) {
+  return hipErrorNotInitialized;
+}
+
+inline hipError_t hipHostMallocVanilla(void** ptr, size_t size,
+                                       unsigned int flags) {
+  return hipErrorNotInitialized;
+}
+
+// This macro wraps a global identifier, given by hipSymbolName, in a callable
+// structure that loads the DLL symbol out of the DSO handle in a thread-safe
+// manner on first use. This dynamic loading technique is used to avoid DSO
+// dependencies on vendor libraries which may or may not be available in the
+// deployed binary environment.
+#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName)                             \
+  template <typename... Args>                                               \
+  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) {    \
+    using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type;     \
+    static FuncPtrT loaded = []() -> FuncPtrT {                             \
+      static const char *kName = TO_STR(hipSymbolName);                     \
+      void *f;                                                              \
+      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
+          stream_executor::internal::CachedDsoLoader::GetHipDsoHandle()     \
+              .ValueOrDie(),                                                \
+          kName, &f);                                                       \
+      CHECK(s.ok()) << "could not find " << kName                           \
+                    << " in HIP DSO; dlerror: " << s.error_message();       \
+      return reinterpret_cast<FuncPtrT>(f);                                 \
+    }();                                                                    \
+    return loaded(args...);                                                 \
+  }
+#endif
+
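+// Note: when loaded dynamically, every symbol listed below must resolve to
+// an unmangled (extern "C") name in the HIP DSO; the templated hipMalloc /
+// hipHostMalloc overloads described above are therefore stubbed out instead.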
+// clang-format off
+#define HIP_ROUTINE_EACH(__macro)                   \
+  __macro(hipDeviceCanAccessPeer)                   \
+  __macro(hipDeviceEnablePeerAccess)                \
+  __macro(hipDeviceGet)                             \
+  __macro(hipDeviceGetAttribute)                    \
+  __macro(hipDeviceGetName)                         \
+  __macro(hipDeviceGetPCIBusId)                     \
+  __macro(hipDeviceGetSharedMemConfig)              \
+  __macro(hipDeviceSetSharedMemConfig)              \
+  __macro(hipDeviceSynchronize)                     \
+  __macro(hipDeviceTotalMem)                        \
+  __macro(hipDriverGetVersion)                      \
+  __macro(hipEventCreateWithFlags)                  \
+  __macro(hipEventElapsedTime)                      \
+  __macro(hipEventDestroy)                          \
+  __macro(hipEventQuery)                            \
+  __macro(hipEventRecord)                           \
+  __macro(hipEventSynchronize)                      \
+  __macro(hipFree)                                  \
+  __macro(hipFuncSetCacheConfig)                    \
+  __macro(hipGetDevice)                             \
+  __macro(hipGetDeviceCount)                        \
+  __macro(hipGetDeviceProperties)                   \
+  __macro(hipHostFree)                              \
+  __macro(hipHostRegister)                          \
+  __macro(hipHostUnregister)                        \
+  __macro(hipInit)                                  \
+  __macro(hipMemGetAddressRange)                    \
+  __macro(hipMemGetInfo)                            \
+  __macro(hipMemcpyDtoD)                            \
+  __macro(hipMemcpyDtoDAsync)                       \
+  __macro(hipMemcpyDtoH)                            \
+  __macro(hipMemcpyDtoHAsync)                       \
+  __macro(hipMemcpyHtoD)                            \
+  __macro(hipMemcpyHtoDAsync)                       \
+  __macro(hipMemset)                                \
+  __macro(hipMemsetAsync)                           \
+  __macro(hipModuleGetFunction)                     \
+  __macro(hipModuleGetGlobal)                       \
+  __macro(hipModuleLaunchKernel)                    \
+  __macro(hipModuleLoadData)                        \
+  __macro(hipModuleUnload)                          \
+  __macro(hipPointerGetAttributes)                  \
+  __macro(hipSetDevice)                             \
+  __macro(hipStreamAddCallback)                     \
+  __macro(hipStreamCreateWithFlags)                 \
+  __macro(hipStreamDestroy)                         \
+  __macro(hipStreamQuery)                           \
+  __macro(hipStreamSynchronize)                     \
+  __macro(hipStreamWaitEvent)                       \
+// clang-format on
+
+HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
+#undef HIP_ROUTINE_EACH
+#undef STREAM_EXECUTOR_HIP_WRAP
+#undef TO_STR
+#undef TO_STR_
+}  // namespace wrap
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
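
A note on the wrapper above: in the non-PLATFORM_GOOGLE branch, each
STREAM_EXECUTOR_HIP_WRAP(hipFoo) expands to a template whose function pointer
is resolved from the HIP DSO exactly once, on first call, via a C++11
thread-safe static initializer. Below is a standalone sketch of that
dlsym-on-first-use pattern for a single symbol; it substitutes POSIX
dlopen/dlsym for stream_executor's Env/CachedDsoLoader, and both the DSO name
"libamdhip64.so" and the hand-written prototype are assumptions made only so
the example compiles on its own (build with -ldl).

#include <dlfcn.h>

#include <cstdio>
#include <cstdlib>
#include <type_traits>

// Declaration only, used for the signature; the definition is resolved at
// runtime via dlsym. (Assumed prototype; hipError_t simplified to int.)
extern "C" int hipGetDeviceCount(int* count);

namespace wrap {
template <typename... Args>
auto hipGetDeviceCount(Args... args)
    -> decltype(::hipGetDeviceCount(args...)) {
  using FuncPtrT = std::add_pointer<decltype(::hipGetDeviceCount)>::type;
  // C++11 guarantees this initializer runs exactly once, thread-safely.
  static FuncPtrT loaded = []() -> FuncPtrT {
    void* dso = dlopen("libamdhip64.so", RTLD_LAZY);  // assumed DSO name
    if (dso == nullptr) {
      std::fprintf(stderr, "could not open HIP DSO: %s\n", dlerror());
      std::abort();
    }
    void* f = dlsym(dso, "hipGetDeviceCount");
    if (f == nullptr) {
      std::fprintf(stderr, "could not find symbol: %s\n", dlerror());
      std::abort();
    }
    return reinterpret_cast<FuncPtrT>(f);
  }();
  return loaded(args...);  // forward to the real HIP entry point
}
}  // namespace wrap

int main() {
  int n = 0;
  if (wrap::hipGetDeviceCount(&n) == 0) {  // 0 == hipSuccess
    std::printf("%d ROCm device(s)\n", n);
  }
  return 0;
}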

From 7066b1c04dab63bd44c1450fc93032e58727e3e6 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Fri, 1 Feb 2019 18:32:40 +0000
Subject: [PATCH 7/7] adding a check in the initialization routines for
 rocblas, rocfft, rocrand to avoid duplicate registrations

---
 tensorflow/stream_executor/rocm/rocm_blas.cc | 64 +++++++++++---------
 tensorflow/stream_executor/rocm/rocm_fft.cc  | 47 +++++++-------
 tensorflow/stream_executor/rocm/rocm_rng.cc  | 61 ++++++++++---------
 3 files changed, 95 insertions(+), 77 deletions(-)
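
The change is identical in all three files below: consult
PluginRegistry::HasFactory before registering, so that a second invocation of
the initialize_* routine becomes a no-op instead of logging a
duplicate-registration error. Here is a reduced, self-contained sketch of that
guard pattern; Registry, HasFactory, and RegisterFactory are simplified
stand-ins for PluginRegistry, not its real interface.

#include <functional>
#include <iostream>
#include <map>
#include <string>

class Registry {
 public:
  static Registry* Instance() {
    static Registry instance;  // one registry per process, like PluginRegistry
    return &instance;
  }
  bool HasFactory(const std::string& plugin) const {
    return factories_.count(plugin) != 0;
  }
  bool RegisterFactory(const std::string& plugin,
                       std::function<void*()> factory) {
    if (HasFactory(plugin)) return false;  // duplicate registration rejected
    factories_[plugin] = std::move(factory);
    return true;
  }

 private:
  std::map<std::string, std::function<void*()>> factories_;
};

void initialize_rocblas_sketch() {
  // Mirrors the patch: query the registry first, so that running this
  // routine twice (e.g. from two static initializers) is a silent no-op
  // instead of an "Unable to register" error.
  if (!Registry::Instance()->HasFactory("rocBLAS")) {
    if (!Registry::Instance()->RegisterFactory(
            "rocBLAS", []() -> void* { return nullptr; })) {
      std::cerr << "Unable to register rocBLAS factory\n";
    }
  }
}

int main() {
  initialize_rocblas_sketch();
  initialize_rocblas_sketch();  // second call: no-op, nothing logged
  return 0;
}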

diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc
index a626d168c26..2137bc00275 100644
--- a/tensorflow/stream_executor/rocm/rocm_blas.cc
+++ b/tensorflow/stream_executor/rocm/rocm_blas.cc
@@ -2324,35 +2324,43 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
 }  // namespace gpu
 
 void initialize_rocblas() {
-  port::Status status =
-      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
-          rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
-          [](internal::StreamExecutorInterface* parent) -> blas::BlasSupport* {
-            gpu::GpuExecutor* rocm_executor =
-                dynamic_cast<gpu::GpuExecutor*>(parent);
-            if (rocm_executor == nullptr) {
-              LOG(ERROR)
-                  << "Attempting to initialize an instance of the rocBLAS "
-                  << "support library with a non-ROCM StreamExecutor";
-              return nullptr;
-            }
-
-            gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
-            if (!blas->Init()) {
-              // Note: Init() will log a more specific error.
-              delete blas;
-              return nullptr;
-            }
-            return blas;
-          });
-
-  if (!status.ok()) {
-    LOG(ERROR) << "Unable to register rocBLAS factory: "
-               << status.error_message();
-  }
-
-  PluginRegistry::Instance()->SetDefaultFactory(
+  auto rocBlasAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
       rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
+
+  if (!rocBlasAlreadyRegistered) {
+    port::Status status =
+        PluginRegistry::Instance()
+            ->RegisterFactory<PluginRegistry::BlasFactory>(
+                rocm::kROCmPlatformId, gpu::kRocBlasPlugin, "rocBLAS",
+                [](internal::StreamExecutorInterface* parent)
+                    -> blas::BlasSupport* {
+                  gpu::GpuExecutor* rocm_executor =
+                      dynamic_cast<gpu::GpuExecutor*>(parent);
+                  if (rocm_executor == nullptr) {
+                    LOG(ERROR)
+                        << "Attempting to initialize an instance of the "
+                           "rocBLAS "
+                        << "support library with a non-ROCM StreamExecutor";
+                    return nullptr;
+                  }
+
+                  gpu::ROCMBlas* blas = new gpu::ROCMBlas(rocm_executor);
+                  if (!blas->Init()) {
+                    // Note: Init() will log a more specific error.
+                    delete blas;
+                    return nullptr;
+                  }
+                  return blas;
+                });
+
+    if (!status.ok()) {
+      LOG(ERROR) << "Unable to register rocBLAS factory: "
+                 << status.error_message();
+    }
+
+    PluginRegistry::Instance()->SetDefaultFactory(
+        rocm::kROCmPlatformId, PluginKind::kBlas, gpu::kRocBlasPlugin);
+  }
 }
 
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/rocm/rocm_fft.cc b/tensorflow/stream_executor/rocm/rocm_fft.cc
index e8a72f61d3f..b23e05d9dde 100644
--- a/tensorflow/stream_executor/rocm/rocm_fft.cc
+++ b/tensorflow/stream_executor/rocm/rocm_fft.cc
@@ -592,28 +592,33 @@ STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
 }  // namespace gpu
 
 void initialize_rocfft() {
-  port::Status status =
-      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
-          rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
-          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
-            gpu::GpuExecutor *rocm_executor =
-                dynamic_cast<gpu::GpuExecutor *>(parent);
-            if (rocm_executor == nullptr) {
-              LOG(ERROR)
-                  << "Attempting to initialize an instance of the rocFFT "
-                  << "support library with a non-ROCM StreamExecutor";
-              return nullptr;
-            }
-
-            return new gpu::ROCMFft(rocm_executor);
-          });
-  if (!status.ok()) {
-    LOG(ERROR) << "Unable to register rocFFT factory: "
-               << status.error_message();
-  }
-
-  PluginRegistry::Instance()->SetDefaultFactory(
+  auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
       rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
+
+  if (!rocFftAlreadyRegistered) {
+    port::Status status =
+        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
+            rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
+            [](internal::StreamExecutorInterface* parent) -> fft::FftSupport* {
+              gpu::GpuExecutor* rocm_executor =
+                  dynamic_cast<gpu::GpuExecutor*>(parent);
+              if (rocm_executor == nullptr) {
+                LOG(ERROR)
+                    << "Attempting to initialize an instance of the rocFFT "
+                    << "support library with a non-ROCM StreamExecutor";
+                return nullptr;
+              }
+
+              return new gpu::ROCMFft(rocm_executor);
+            });
+    if (!status.ok()) {
+      LOG(ERROR) << "Unable to register rocFFT factory: "
+                 << status.error_message();
+    }
+
+    PluginRegistry::Instance()->SetDefaultFactory(
+        rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
+  }
 }
 
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/rocm/rocm_rng.cc b/tensorflow/stream_executor/rocm/rocm_rng.cc
index 27797845700..545bfc6d943 100644
--- a/tensorflow/stream_executor/rocm/rocm_rng.cc
+++ b/tensorflow/stream_executor/rocm/rocm_rng.cc
@@ -283,35 +283,40 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
 }  // namespace gpu
 
 void initialize_rocrand() {
-  port::Status status =
-      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
-          rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
-          [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
-            gpu::GpuExecutor* rocm_executor =
-                dynamic_cast<gpu::GpuExecutor*>(parent);
-            if (rocm_executor == nullptr) {
-              LOG(ERROR)
-                  << "Attempting to initialize an instance of the hipRAND "
-                  << "support library with a non-ROCM StreamExecutor";
-              return nullptr;
-            }
-
-            gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
-            if (!rng->Init()) {
-              // Note: Init() will log a more specific error.
-              delete rng;
-              return nullptr;
-            }
-            return rng;
-          });
-
-  if (!status.ok()) {
-    LOG(ERROR) << "Unable to register rocRAND factory: "
-               << status.error_message();
-  }
-
-  PluginRegistry::Instance()->SetDefaultFactory(
+  auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
       rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
+
+  if (!rocRandAlreadyRegistered) {
+    port::Status status =
+        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
+            rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
+            [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
+              gpu::GpuExecutor* rocm_executor =
+                  dynamic_cast<gpu::GpuExecutor*>(parent);
+              if (rocm_executor == nullptr) {
+                LOG(ERROR)
+                    << "Attempting to initialize an instance of the hipRAND "
+                    << "support library with a non-ROCM StreamExecutor";
+                return nullptr;
+              }
+
+              gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
+              if (!rng->Init()) {
+                // Note: Init() will log a more specific error.
+                delete rng;
+                return nullptr;
+              }
+              return rng;
+            });
+
+    if (!status.ok()) {
+      LOG(ERROR) << "Unable to register rocRAND factory: "
+                 << status.error_message();
+    }
+
+    PluginRegistry::Instance()->SetDefaultFactory(
+        rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
+  }
 }
 
 }  // namespace stream_executor