From 35ac1e1bfee2c9721a4da105253cbe9f8e475c07 Mon Sep 17 00:00:00 2001
From: Deven Desai <deven.desai.amd@gmail.com>
Date: Fri, 24 Jul 2020 18:23:03 +0000
Subject: [PATCH] [ROCm] Updates to dynamically load the ROCm "hipsparse"
 library

---
 tensorflow/core/kernels/sparse/mat_mul_op.cc  |  12 +-
 tensorflow/core/util/BUILD                    |   2 +-
 tensorflow/core/util/cuda_sparse.h            |   6 +-
 tensorflow/core/util/rocm_sparse.cc           |  23 ++--
 .../platform/default/dso_loader.cc            |   9 ++
 .../platform/default/dso_loader.h             |   2 +
 tensorflow/stream_executor/rocm/BUILD         |  17 +++
 .../stream_executor/rocm/hipsparse_wrapper.h  | 105 ++++++++++++++++++
 third_party/gpus/rocm/BUILD.tpl               |   7 +-
 9 files changed, 158 insertions(+), 25 deletions(-)
 create mode 100644 tensorflow/stream_executor/rocm/hipsparse_wrapper.h

diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc
index bf9de570fbf..799e33000ad 100644
--- a/tensorflow/core/kernels/sparse/mat_mul_op.cc
+++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc
@@ -886,11 +886,11 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
       const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
 
       gpusparseMatDescr_t descrA;
-      TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
       TF_RETURN_IF_GPUSPARSE_ERROR(
-          hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
       TF_RETURN_IF_GPUSPARSE_ERROR(
-          hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
 #endif  // GOOGLE_CUDA
 
       TF_RETURN_IF_ERROR(
@@ -940,11 +940,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
           cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
 #elif TENSORFLOW_USE_ROCM
       gpusparseMatDescr_t descrA;
-      TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+      TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
       TF_RETURN_IF_GPUSPARSE_ERROR(
-          hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+          wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
       TF_RETURN_IF_GPUSPARSE_ERROR(
-          hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+          wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
       const int m = a.dense_shape_host(0);
diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD
index dcb2787e309..0dc8f84aadf 100644
--- a/tensorflow/core/util/BUILD
+++ b/tensorflow/core/util/BUILD
@@ -626,7 +626,7 @@ tf_kernel_library(
         "//tensorflow/stream_executor/cuda:cusparse_lib",
         "@cub_archive//:cub",
     ]) + if_rocm([
-        "@local_config_rocm//rocm:hipsparse",
+        "//tensorflow/stream_executor/rocm:hipsparse_wrapper",
     ]),
 )
 
diff --git a/tensorflow/core/util/cuda_sparse.h b/tensorflow/core/util/cuda_sparse.h
index 76580766d69..cd10ba8d8cb 100644
--- a/tensorflow/core/util/cuda_sparse.h
+++ b/tensorflow/core/util/cuda_sparse.h
@@ -46,7 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
 
 #elif TENSORFLOW_USE_ROCM
 
-#include "rocm/include/hipsparse/hipsparse.h"
+#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"
 
 using gpusparseStatus_t = hipsparseStatus_t;
 using gpusparseOperation_t = hipsparseOperation_t;
@@ -485,7 +485,7 @@ class GpuSparseMatrixDescriptor {
 #if GOOGLE_CUDA
     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
 #elif TENSORFLOW_USE_ROCM
-    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
+    TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
 #endif
     initialized_ = true;
     return Status::OK();
@@ -507,7 +507,7 @@ class GpuSparseMatrixDescriptor {
 #if GOOGLE_CUDA
       cusparseDestroyMatDescr(descr_);
 #elif TENSORFLOW_USE_ROCM
-      hipsparseDestroyMatDescr(descr_);
+      wrap::hipsparseDestroyMatDescr(descr_);
 #endif
       initialized_ = false;
     }
diff --git a/tensorflow/core/util/rocm_sparse.cc b/tensorflow/core/util/rocm_sparse.cc
index cc7b56fdc01..22c2af780c7 100644
--- a/tensorflow/core/util/rocm_sparse.cc
+++ b/tensorflow/core/util/rocm_sparse.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/util/cuda_sparse.h"
 
 namespace tensorflow {
+
 namespace {
 
 // A set of initialized handles to the underlying ROCm libraries used by
@@ -67,9 +68,9 @@ class HipSparseHandles {
 
   Status Initialize() {
     if (initialized_) return Status::OK();
-    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
+    TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreate(&hipsparse_handle_));
     TF_RETURN_IF_GPUSPARSE_ERROR(
-        hipsparseSetStream(hipsparse_handle_, stream_));
+        wrap::hipsparseSetStream(hipsparse_handle_, stream_));
     initialized_ = true;
     return Status::OK();
   }
@@ -88,7 +89,7 @@ class HipSparseHandles {
   void Release() {
     if (initialized_) {
       // This should never return anything other than success
-      auto err = hipsparseDestroy(hipsparse_handle_);
+      auto err = wrap::hipsparseDestroy(hipsparse_handle_);
       DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
           << "Failed to destroy hipSPARSE instance.";
       initialized_ = false;
@@ -156,23 +157,23 @@ Status GpuSparse::Initialize() {
 #define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
 
 // Macros to construct hipsparse method names.
-#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
+#define SPARSE_FN(method, sparse_prefix) wrap::hipsparse##sparse_prefix##method
 
 Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                           int* csrRowPtr) const {
   DCHECK(initialized_);
-  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
-                                                 nnz, m, csrRowPtr,
-                                                 HIPSPARSE_INDEX_BASE_ZERO));
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      wrap::hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, nnz, m, csrRowPtr,
+                              HIPSPARSE_INDEX_BASE_ZERO));
   return Status::OK();
 }
 
 Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                           int* cooRowInd) const {
   DCHECK(initialized_);
-  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
-                                                 nnz, m, cooRowInd,
-                                                 HIPSPARSE_INDEX_BASE_ZERO));
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      wrap::hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, nnz, m, cooRowInd,
+                              HIPSPARSE_INDEX_BASE_ZERO));
   return Status::OK();
 }
 
@@ -252,7 +253,7 @@ Status GpuSparse::CsrgemmNnz(
     int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
   DCHECK(initialized_);
   DCHECK(nnzTotalDevHostPtr != nullptr);
-  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
+  TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseXcsrgemmNnz(
       *gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
       csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
       csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc
index 6e0113ab05a..70b1ebe070a 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.cc
+++ b/tensorflow/stream_executor/platform/default/dso_loader.cc
@@ -136,6 +136,10 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
   return GetDsoHandle("rocrand", "");
 }
 
+port::StatusOr<void*> GetHipsparseDsoHandle() {
+  return GetDsoHandle("hipsparse", "");
+}
+
 port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
 
 }  // namespace DsoLoader
@@ -206,6 +210,11 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
   return *result;
 }
 
+port::StatusOr<void*> GetHipsparseDsoHandle() {
+  static auto result = new auto(DsoLoader::GetHipsparseDsoHandle());
+  return *result;
+}
+
 port::StatusOr<void*> GetHipDsoHandle() {
   static auto result = new auto(DsoLoader::GetHipDsoHandle());
   return *result;
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h
index 7eee2e60785..91138f713bd 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.h
+++ b/tensorflow/stream_executor/platform/default/dso_loader.h
@@ -50,6 +50,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
 port::StatusOr<void*> GetMiopenDsoHandle();
 port::StatusOr<void*> GetRocfftDsoHandle();
 port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipsparseDsoHandle();
 port::StatusOr<void*> GetHipDsoHandle();
 
 // The following method tries to dlopen all necessary GPU libraries for the GPU
@@ -82,6 +83,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
 port::StatusOr<void*> GetMiopenDsoHandle();
 port::StatusOr<void*> GetRocfftDsoHandle();
 port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipsparseDsoHandle();
 port::StatusOr<void*> GetHipDsoHandle();
 }  // namespace CachedDsoLoader
 
diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD
index bd924125d77..bd4c45382f8 100644
--- a/tensorflow/stream_executor/rocm/BUILD
+++ b/tensorflow/stream_executor/rocm/BUILD
@@ -277,6 +277,23 @@ cc_library(
     alwayslink = True,
 )
 
+cc_library(
+    name = "hipsparse_wrapper",
+    srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
+    hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "@local_config_rocm//rocm:rocm_headers",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+    ] + if_static([
+        "@local_config_rocm//rocm:hiprand",
+    ])),
+    alwayslink = True,
+)
+
 cc_library(
     name = "all_runtime",
     copts = tf_copts(),
diff --git a/tensorflow/stream_executor/rocm/hipsparse_wrapper.h b/tensorflow/stream_executor/rocm/hipsparse_wrapper.h
new file mode 100644
index 00000000000..6444f015cf8
--- /dev/null
+++ b/tensorflow/stream_executor/rocm/hipsparse_wrapper.h
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file wraps hipsparse API calls with dso loader so that we don't need to
+// have explicit linking to libhipsparse. All TF hipsarse API usage should route
+// through this wrapper.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
+
+#include "rocm/include/hipsparse/hipsparse.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace tensorflow {
+namespace wrap {
+
+#ifdef PLATFORM_GOOGLE
+
+#define HIPSPARSE_API_WRAPPER(__name)               \
+  struct WrapperShim__##__name {                    \
+    template <typename... Args>                     \
+    hipsparseStatus_t operator()(Args... args) {    \
+      hipSparseStatus_t retval = ::__name(args...); \
+      return retval;                                \
+    }                                               \
+  } __name;
+
+#else
+
+#define HIPSPARSE_API_WRAPPER(__name)                                          \
+  struct DynLoadShim__##__name {                                               \
+    static const char* kName;                                                  \
+    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;               \
+    static void* GetDsoHandle() {                                              \
+      auto s =                                                                 \
+          stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
+      return s.ValueOrDie();                                                   \
+    }                                                                          \
+    static FuncPtrT LoadOrDie() {                                              \
+      void* f;                                                                 \
+      auto s =                                                                 \
+          Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f);     \
+      CHECK(s.ok()) << "could not find " << kName                              \
+                    << " in miopen DSO; dlerror: " << s.error_message();       \
+      return reinterpret_cast<FuncPtrT>(f);                                    \
+    }                                                                          \
+    static FuncPtrT DynLoad() {                                                \
+      static FuncPtrT f = LoadOrDie();                                         \
+      return f;                                                                \
+    }                                                                          \
+    template <typename... Args>                                                \
+    hipsparseStatus_t operator()(Args... args) {                               \
+      return DynLoad()(args...);                                               \
+    }                                                                          \
+  } __name;                                                                    \
+  const char* DynLoadShim__##__name::kName = #__name;
+
+#endif
+
+// clang-format off
+#define FOREACH_HIPSPARSE_API(__macro)		\
+  __macro(hipsparseCreate)			\
+  __macro(hipsparseCreateMatDescr)		\
+  __macro(hipsparseDcsr2csc)			\
+  __macro(hipsparseDcsrgemm)			\
+  __macro(hipsparseDcsrmm2)			\
+  __macro(hipsparseDcsrmv)			\
+  __macro(hipsparseDestroy)			\
+  __macro(hipsparseDestroyMatDescr)		\
+  __macro(hipsparseScsr2csc)			\
+  __macro(hipsparseScsrgemm)			\
+  __macro(hipsparseScsrmm2)			\
+  __macro(hipsparseScsrmv)			\
+  __macro(hipsparseSetStream)			\
+  __macro(hipsparseSetMatIndexBase)		\
+  __macro(hipsparseSetMatType)			\
+  __macro(hipsparseXcoo2csr)			\
+  __macro(hipsparseXcsr2coo)			\
+  __macro(hipsparseXcsrgemmNnz)
+
+// clang-format on
+
+FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER)
+
+#undef FOREACH_HIPSPARSE_API
+#undef HIPSPARSE_API_WRAPPER
+
+}  // namespace wrap
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl
index cf8950b5bc7..3c233b4f5b0 100644
--- a/third_party/gpus/rocm/BUILD.tpl
+++ b/third_party/gpus/rocm/BUILD.tpl
@@ -108,6 +108,7 @@ cc_library(
         ":rocfft",
         ":hiprand",
         ":miopen",
+        ":hipsparse",
     ],
 )
 
@@ -137,11 +138,9 @@ cc_library(
     ],
 )
 
-cc_import(
+cc_library(
     name = "hipsparse",
-    hdrs = glob(["rocm/include/hipsparse/**",]),
-    shared_library = "rocm/lib/%{hipsparse_lib}",
-    visibility = ["//visibility:public"],
+    data = ["rocm/lib/%{hipsparse_lib}"],
 )
 
 %{copy_rules}