[ROCm] Updates to dynamically load the ROCm "hipsparse" library
This commit is contained in:
parent
a8a50023bb
commit
35ac1e1bfe
@ -886,11 +886,11 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -940,11 +940,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
wrap::hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
const int m = a.dense_shape_host(0);
|
||||
|
@ -626,7 +626,7 @@ tf_kernel_library(
|
||||
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:hipsparse",
|
||||
"//tensorflow/stream_executor/rocm:hipsparse_wrapper",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,7 @@ using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "rocm/include/hipsparse/hipsparse.h"
|
||||
#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"
|
||||
|
||||
using gpusparseStatus_t = hipsparseStatus_t;
|
||||
using gpusparseOperation_t = hipsparseOperation_t;
|
||||
@ -485,7 +485,7 @@ class GpuSparseMatrixDescriptor {
|
||||
#if GOOGLE_CUDA
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
|
||||
#endif
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
@ -507,7 +507,7 @@ class GpuSparseMatrixDescriptor {
|
||||
#if GOOGLE_CUDA
|
||||
cusparseDestroyMatDescr(descr_);
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
hipsparseDestroyMatDescr(descr_);
|
||||
wrap::hipsparseDestroyMatDescr(descr_);
|
||||
#endif
|
||||
initialized_ = false;
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/cuda_sparse.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// A set of initialized handles to the underlying ROCm libraries used by
|
||||
@ -67,9 +68,9 @@ class HipSparseHandles {
|
||||
|
||||
Status Initialize() {
|
||||
if (initialized_) return Status::OK();
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreate(&hipsparse_handle_));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetStream(hipsparse_handle_, stream_));
|
||||
wrap::hipsparseSetStream(hipsparse_handle_, stream_));
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -88,7 +89,7 @@ class HipSparseHandles {
|
||||
void Release() {
|
||||
if (initialized_) {
|
||||
// This should never return anything other than success
|
||||
auto err = hipsparseDestroy(hipsparse_handle_);
|
||||
auto err = wrap::hipsparseDestroy(hipsparse_handle_);
|
||||
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
|
||||
<< "Failed to destroy hipSPARSE instance.";
|
||||
initialized_ = false;
|
||||
@ -156,23 +157,23 @@ Status GpuSparse::Initialize() {
|
||||
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
|
||||
|
||||
// Macros to construct hipsparse method names.
|
||||
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
|
||||
#define SPARSE_FN(method, sparse_prefix) wrap::hipsparse##sparse_prefix##method
|
||||
|
||||
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
|
||||
int* csrRowPtr) const {
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
|
||||
nnz, m, csrRowPtr,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
wrap::hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd, nnz, m, csrRowPtr,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
int* cooRowInd) const {
|
||||
DCHECK(initialized_);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
|
||||
nnz, m, cooRowInd,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
wrap::hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr, nnz, m, cooRowInd,
|
||||
HIPSPARSE_INDEX_BASE_ZERO));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -252,7 +253,7 @@ Status GpuSparse::CsrgemmNnz(
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseXcsrgemmNnz(
|
||||
*gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
|
||||
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
|
||||
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
|
@ -136,6 +136,10 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||
return GetDsoHandle("rocrand", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipsparseDsoHandle() {
|
||||
return GetDsoHandle("hipsparse", "");
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
|
||||
|
||||
} // namespace DsoLoader
|
||||
@ -206,6 +210,11 @@ port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipsparseDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetHipsparseDsoHandle());
|
||||
return *result;
|
||||
}
|
||||
|
||||
port::StatusOr<void*> GetHipDsoHandle() {
|
||||
static auto result = new auto(DsoLoader::GetHipDsoHandle());
|
||||
return *result;
|
||||
|
@ -50,6 +50,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
|
||||
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||
port::StatusOr<void*> GetHipsparseDsoHandle();
|
||||
port::StatusOr<void*> GetHipDsoHandle();
|
||||
|
||||
// The following method tries to dlopen all necessary GPU libraries for the GPU
|
||||
@ -82,6 +83,7 @@ port::StatusOr<void*> GetRocblasDsoHandle();
|
||||
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||
port::StatusOr<void*> GetHipsparseDsoHandle();
|
||||
port::StatusOr<void*> GetHipDsoHandle();
|
||||
} // namespace CachedDsoLoader
|
||||
|
||||
|
@ -277,6 +277,23 @@ cc_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsparse_wrapper",
|
||||
srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
|
||||
hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
|
||||
deps = if_rocm_is_configured([
|
||||
":rocm_gpu_executor",
|
||||
":rocm_platform_id",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:hiprand",
|
||||
])),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "all_runtime",
|
||||
copts = tf_copts(),
|
||||
|
105
tensorflow/stream_executor/rocm/hipsparse_wrapper.h
Normal file
105
tensorflow/stream_executor/rocm/hipsparse_wrapper.h
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file wraps hipsparse API calls with dso loader so that we don't need to
|
||||
// have explicit linking to libhipsparse. All TF hipsarse API usage should route
|
||||
// through this wrapper.
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
||||
|
||||
#include "rocm/include/hipsparse/hipsparse.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace wrap {
|
||||
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
|
||||
#define HIPSPARSE_API_WRAPPER(__name) \
|
||||
struct WrapperShim__##__name { \
|
||||
template <typename... Args> \
|
||||
hipsparseStatus_t operator()(Args... args) { \
|
||||
hipSparseStatus_t retval = ::__name(args...); \
|
||||
return retval; \
|
||||
} \
|
||||
} __name;
|
||||
|
||||
#else
|
||||
|
||||
#define HIPSPARSE_API_WRAPPER(__name) \
|
||||
struct DynLoadShim__##__name { \
|
||||
static const char* kName; \
|
||||
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||
static void* GetDsoHandle() { \
|
||||
auto s = \
|
||||
stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \
|
||||
return s.ValueOrDie(); \
|
||||
} \
|
||||
static FuncPtrT LoadOrDie() { \
|
||||
void* f; \
|
||||
auto s = \
|
||||
Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), kName, &f); \
|
||||
CHECK(s.ok()) << "could not find " << kName \
|
||||
<< " in miopen DSO; dlerror: " << s.error_message(); \
|
||||
return reinterpret_cast<FuncPtrT>(f); \
|
||||
} \
|
||||
static FuncPtrT DynLoad() { \
|
||||
static FuncPtrT f = LoadOrDie(); \
|
||||
return f; \
|
||||
} \
|
||||
template <typename... Args> \
|
||||
hipsparseStatus_t operator()(Args... args) { \
|
||||
return DynLoad()(args...); \
|
||||
} \
|
||||
} __name; \
|
||||
const char* DynLoadShim__##__name::kName = #__name;
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
#define FOREACH_HIPSPARSE_API(__macro) \
|
||||
__macro(hipsparseCreate) \
|
||||
__macro(hipsparseCreateMatDescr) \
|
||||
__macro(hipsparseDcsr2csc) \
|
||||
__macro(hipsparseDcsrgemm) \
|
||||
__macro(hipsparseDcsrmm2) \
|
||||
__macro(hipsparseDcsrmv) \
|
||||
__macro(hipsparseDestroy) \
|
||||
__macro(hipsparseDestroyMatDescr) \
|
||||
__macro(hipsparseScsr2csc) \
|
||||
__macro(hipsparseScsrgemm) \
|
||||
__macro(hipsparseScsrmm2) \
|
||||
__macro(hipsparseScsrmv) \
|
||||
__macro(hipsparseSetStream) \
|
||||
__macro(hipsparseSetMatIndexBase) \
|
||||
__macro(hipsparseSetMatType) \
|
||||
__macro(hipsparseXcoo2csr) \
|
||||
__macro(hipsparseXcsr2coo) \
|
||||
__macro(hipsparseXcsrgemmNnz)
|
||||
|
||||
// clang-format on
|
||||
|
||||
FOREACH_HIPSPARSE_API(HIPSPARSE_API_WRAPPER)
|
||||
|
||||
#undef FOREACH_HIPSPARSE_API
|
||||
#undef HIPSPARSE_API_WRAPPER
|
||||
|
||||
} // namespace wrap
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_HIPSPARSE_WRAPPER_H_
|
7
third_party/gpus/rocm/BUILD.tpl
vendored
7
third_party/gpus/rocm/BUILD.tpl
vendored
@ -108,6 +108,7 @@ cc_library(
|
||||
":rocfft",
|
||||
":hiprand",
|
||||
":miopen",
|
||||
":hipsparse",
|
||||
],
|
||||
)
|
||||
|
||||
@ -137,11 +138,9 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_import(
|
||||
cc_library(
|
||||
name = "hipsparse",
|
||||
hdrs = glob(["rocm/include/hipsparse/**",]),
|
||||
shared_library = "rocm/lib/%{hipsparse_lib}",
|
||||
visibility = ["//visibility:public"],
|
||||
data = ["rocm/lib/%{hipsparse_lib}"],
|
||||
)
|
||||
|
||||
%{copy_rules}
|
||||
|
Loading…
Reference in New Issue
Block a user