[ROCm] Updates to dynamically load the ROCm "hipsparse" library

Deven Desai 2020-07-24 18:23:03 +00:00
parent a8a50023bb
commit 35ac1e1bfe
9 changed files with 158 additions and 25 deletions
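For orientation, every hunk below applies the same pattern: direct hipsparse calls (which require linking libhipsparse at build time) are replaced by wrap::hipsparse* shims that resolve the symbol from the shared library on first use. The following is a minimal, standalone sketch of that idea using plain dlopen/dlsym; the commit itself routes through stream_executor's CachedDsoLoader and Env::Default()->GetSymbolFromLibrary, and the simplified signature here (void** / int instead of hipsparseHandle_t* / hipsparseStatus_t) plus the wrap_hipsparseCreate name are illustrative assumptions, not TensorFlow code.

// Minimal sketch (not from the commit) of the dynamic-loading shim idea:
// resolve the hipsparse symbol at first use instead of linking libhipsparse.
#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>

using hipsparseCreate_fn = int (*)(void**);  // stand-in for the real signature

static void* GetHipsparseDsoHandle() {
  // dlopen once, lazily; the handle is reused for every wrapped symbol.
  static void* handle = dlopen("libhipsparse.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::fprintf(stderr, "could not load libhipsparse.so: %s\n", dlerror());
    std::abort();
  }
  return handle;
}

// Call sites use this shim instead of ::hipsparseCreate, so no link-time
// dependency on libhipsparse is created.
static int wrap_hipsparseCreate(void** handle) {
  static auto fn = reinterpret_cast<hipsparseCreate_fn>(
      dlsym(GetHipsparseDsoHandle(), "hipsparseCreate"));
  if (fn == nullptr) {
    std::fprintf(stderr, "could not find hipsparseCreate: %s\n", dlerror());
    std::abort();
  }
  return fn(handle);  // forward to the dynamically resolved symbol
}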

View File

@@ -886,11 +886,11 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA
TF_RETURN_IF_ERROR(
@@ -940,11 +940,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int m = a.dense_shape_host(0);

View File

@@ -626,7 +626,7 @@ tf_kernel_library(
"//tensorflow/stream_executor/cuda:cusparse_lib",
"@cub_archive//:cub",
]) + if_rocm([
"@local_config_rocm//rocm:hipsparse",
"//tensorflow/stream_executor/rocm:hipsparse_wrapper",
]),
)

View File

@@ -46,7 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipsparse/hipsparse.h"
#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"
using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
@@ -485,7 +485,7 @@ class GpuSparseMatrixDescriptor {
#if GOOGLE_CUDA
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
#endif
initialized_ = true;
return Status::OK();
@@ -507,7 +507,7 @@ class GpuSparseMatrixDescriptor {
#if GOOGLE_CUDA
cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
hipsparseDestroyMatDescr(descr_);
wrap::hipsparseDestroyMatDescr(descr_);
#endif
initialized_ = false;
}

View File

@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/util/cuda_sparse.h"
namespace tensorflow {
namespace {
// A set of initialized handles to the underlying ROCm libraries used by
@@ -67,9 +68,9 @@ class HipSparseHandles {
Status Initialize() {
if (initialized_) return Status::OK();
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreate(&hipsparse_handle_));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetStream(hipsparse_handle_, stream_));
wrap::hipsparseSetStream(hipsparse_handle_, stream_));
initialized_ = true;
return Status::OK();
}
@@ -88,7 +89,7 @@ class HipSparseHandles {
void Release() {
if (initialized_) {
// This should never return anything other than success
auto err = hipsparseDestroy(hipsparse_handle_);
auto err = wrap::hipsparseDestroy(hipsparse_handle_);
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
<< "Failed to destroy hipSPARSE instance.";
initialized_ = false;
@@ -156,23 +157,23 @@ Status GpuSparse::Initialize() {
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
// Macros to construct hipsparse method names.
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
#define SPARSE_FN(method, sparse_prefix) wrap::hipsparse##sparse_prefix##method
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
int* csrRowPtr) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
nnz, m, csrRowPtr,
HIPSPARSE_INDEX_BASE_ZERO));
TF_RETURN_IF_GPUSPARSE_ERROR(
wrap::hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, nnz, m, csrRowPtr,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
int* cooRowInd) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
nnz, m, cooRowInd,
HIPSPARSE_INDEX_BASE_ZERO));
TF_RETURN_IF_GPUSPARSE_ERROR(
wrap::hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, nnz, m, cooRowInd,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
@@ -252,7 +253,7 @@ Status GpuSparse::CsrgemmNnz(
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseXcsrgemmNnz(
*gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
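The SPARSE_FN change above is only a prefix swap: the constructed hipsparse<prefix><method> name is now qualified with wrap::, so the per-type dispatched calls also go through the shims. Below is a small self-contained sketch of the token-pasting technique; the DEMO_FN macro and the demoSfoo/demoDfoo names are assumptions for illustration, not TensorFlow code.

// Illustration only: how a name-pasting macro like SPARSE_FN builds a
// namespace-qualified, per-type function name at preprocessing time.
#include <cstdio>

namespace wrap {
inline void demoSfoo(float x) { std::printf("S variant: %f\n", x); }
inline void demoDfoo(double x) { std::printf("D variant: %f\n", x); }
}  // namespace wrap

// Mirrors: #define SPARSE_FN(method, prefix) wrap::hipsparse##prefix##method
#define DEMO_FN(method, prefix) wrap::demo##prefix##method

int main() {
  DEMO_FN(foo, S)(1.0f);  // expands to wrap::demoSfoo(1.0f)
  DEMO_FN(foo, D)(2.0);   // expands to wrap::demoDfoo(2.0)
  return 0;
}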

View File

@@ -136,6 +136,10 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
return GetDsoHandle("rocrand", "");
}
port::StatusOr<void*> GetHipsparseDsoHandle() {
return GetDsoHandle("hipsparse", "");
}
port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
} // namespace DsoLoader
@@ -206,6 +210,11 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
return *result;
}
port::StatusOr<void*> GetHipsparseDsoHandle() {
static auto result = new auto(DsoLoader::GetHipsparseDsoHandle());
return *result;
}
port::StatusOr<void*> GetHipDsoHandle() {
static auto result = new auto(DsoLoader::GetHipDsoHandle());
return *result;
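The CachedDsoLoader variant added above uses the "static auto result = new auto(...)" idiom: the loader runs once, the result is stored on the heap and intentionally leaked so it is never torn down by static destructors, and every later call returns the cached value. A minimal sketch of just that idiom follows; ExpensiveLookup and CachedLookup are hypothetical stand-ins, not TensorFlow code.

// Sketch of the call-once caching idiom used in CachedDsoLoader.
#include <string>

static std::string ExpensiveLookup() { return "dso-handle"; }

const std::string& CachedLookup() {
  static auto* result = new auto(ExpensiveLookup());  // runs exactly once
  return *result;                                     // reused thereafter
}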

View File

@@ -50,6 +50,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipsparseDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
// The following method tries to dlopen all necessary GPU libraries for the GPU
@@ -82,6 +83,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipsparseDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
} // namespace CachedDsoLoader

View File

@@ -277,6 +277,23 @@ cc_library(
alwayslink = True,
)
cc_library(
name = "hipsparse_wrapper",
srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
deps = if_rocm_is_configured([
":rocm_gpu_executor",
":rocm_platform_id",
"@local_config_rocm//rocm:rocm_headers",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
] + if_static([
"@local_config_rocm//rocm:hiprand",
])),
alwayslink = True,
)
cc_library(
name = "all_runtime",
copts = tf_copts(),

View File

@@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file wraps hipsparse API calls with dso loader so that we don't need to
// have explicit linking to libhipsparse. All TF hipsparse API usage should route
// through this wrapper.
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
#include "rocm/include/hipsparse/hipsparse.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
namespace tensorflow {
namespace wrap {
#ifdef PLATFORM_GOOGLE
#define HIPSPARSE_API_WRAPPER(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
hipsparseStatus_t operator()(Args... args) { \
hipsparseStatus_t retval = ::__name(args...); \
return retval; \
} \
} __name;
#else
#define HIPSPARSE_API_WRAPPER(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = \
stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = \
Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in miopen DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hipsparseStatus_t operator()(Args... args) { \
return DynLoad()(args...); \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#endif
// clang-format off
#define FOREACH_HIPSPARSE_API(__macro) \
__macro(hipsparseCreate) \
__macro(hipsparseCreateMatDescr) \
__macro(hipsparseDcsr2csc) \
__macro(hipsparseDcsrgemm) \
__macro(hipsparseDcsrmm2) \
__macro(hipsparseDcsrmv) \
__macro(hipsparseDestroy) \
__macro(hipsparseDestroyMatDescr) \
__macro(hipsparseScsr2csc) \
__macro(hipsparseScsrgemm) \
__macro(hipsparseScsrmm2) \
__macro(hipsparseScsrmv) \
__macro(hipsparseSetStream) \
__macro(hipsparseSetMatIndexBase) \
__macro(hipsparseSetMatType) \
__macro(hipsparseXcoo2csr) \
__macro(hipsparseXcsr2coo) \
__macro(hipsparseXcsrgemmNnz)
// clang-format on
FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER)
#undef FOREACH_HIPSPARSE_API
#undef HIPSPARSE_API_WRAPPER
} // namespace wrap
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
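
For reference, a hedged usage sketch of the new header (not part of the commit; it assumes a TensorFlow build with ROCm configured so the include path and hipsparse types resolve). Call sites simply name the shim objects in namespace wrap, mirroring the descriptor setup seen in the cuda_sparse/rocm_sparse hunks earlier in this diff; CreateDescriptorExample is a hypothetical helper.

// Illustrative call site (assumption: built inside TensorFlow with ROCm).
#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"

hipsparseStatus_t CreateDescriptorExample(hipsparseMatDescr_t* descr) {
  // Each wrap:: call resolves the hipsparse symbol from the DSO on first use.
  hipsparseStatus_t status = tensorflow::wrap::hipsparseCreateMatDescr(descr);
  if (status != HIPSPARSE_STATUS_SUCCESS) return status;
  status = tensorflow::wrap::hipsparseSetMatType(*descr,
                                                 HIPSPARSE_MATRIX_TYPE_GENERAL);
  if (status != HIPSPARSE_STATUS_SUCCESS) return status;
  return tensorflow::wrap::hipsparseSetMatIndexBase(*descr,
                                                    HIPSPARSE_INDEX_BASE_ZERO);
}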

View File

@@ -108,6 +108,7 @@ cc_library(
":rocfft",
":hiprand",
":miopen",
":hipsparse",
],
)
@@ -137,11 +138,9 @@ cc_library(
],
)
cc_import(
cc_library(
name = "hipsparse",
hdrs = glob(["rocm/include/hipsparse/**",]),
shared_library = "rocm/lib/%{hipsparse_lib}",
visibility = ["//visibility:public"],
data = ["rocm/lib/%{hipsparse_lib}"],
)
%{copy_rules}