Adding ROCm support for the GpuSparse API (TF wrapper for cuSPARSE/hipSPARSE)

Deven Desai 2019-11-20 15:43:03 +00:00
parent f725b46454
commit 7e8ccbd22b
3 changed files with 418 additions and 7 deletions
tensorflow/core/kernels

@@ -3459,14 +3459,18 @@ tf_kernel_library(
tf_kernel_library(
name = "cuda_sparse",
srcs = ["cuda_sparse.cc"],
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
hdrs = ["cuda_sparse.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:cuda_solvers",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cusparse_lib",
] + if_cuda(["@cub_archive//:cub"]),
"@cub_archive//:cub",
]) + if_rocm([
"@local_config_rocm//rocm:hipsparse",
]),
)
LINALG_DEPS = [

tensorflow/core/kernels/cuda_sparse.h

@@ -19,11 +19,13 @@ limitations under the License.
// This header declares the class GpuSparse, which contains wrappers of
// the cuSPARSE/hipSPARSE libraries for use in TensorFlow kernels.
#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <functional>
#include <vector>
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cusparse.h"
using gpusparseStatus_t = cusparseStatus_t;
@@ -33,6 +35,19 @@ using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipsparse/hipsparse.h"
using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -55,6 +70,8 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
case err: \
return STRINGIZE(err);
#if GOOGLE_CUDA
RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
@@ -65,14 +82,34 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
#undef RETURN_IF_STATUS
#undef STRINGIZE
default:
return strings::StrCat("Unknown CUSPARSE error: ",
static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM
RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)
default:
return strings::StrCat("Unknown hipSPARSE error: ",
static_cast<int>(status));
#endif
#undef RETURN_IF_STATUS
#undef STRINGIZE
}
}
#if GOOGLE_CUDA
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
do { \
auto status = (expr); \
@@ -83,9 +120,24 @@ inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
} \
} while (0)
#elif TENSORFLOW_USE_ROCM
#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \
do { \
auto status = (expr); \
if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \
return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
"): hipSPARSE call failed with status ", \
ConvertGPUSparseErrorToString(status)); \
} \
} while (0)
#endif
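Both branches of the macro share one contract: evaluate the library call once and, on any non-success status, return errors::Internal carrying the file, line, expression text, and the decoded status string. A hedged usage sketch (the helper below is hypothetical, not from this commit):

```cpp
// Hypothetical wrapper: any Status-returning function can guard a raw
// cuSPARSE/hipSPARSE call with TF_RETURN_IF_GPUSPARSE_ERROR and let the
// macro translate failures into a tensorflow::Status.
Status CreateSparseHandle(gpusparseHandle_t* handle) {
#if GOOGLE_CUDA
  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreate(handle));
#elif TENSORFLOW_USE_ROCM
  TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(handle));
#endif
  return Status::OK();
}
```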
inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
bool conjugate,
Status* status) {
#if GOOGLE_CUDA
if (transpose) {
return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
: CUSPARSE_OPERATION_TRANSPOSE;
@@ -97,6 +149,19 @@ inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
}
return CUSPARSE_OPERATION_NON_TRANSPOSE;
}
#elif TENSORFLOW_USE_ROCM
if (transpose) {
return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
: HIPSPARSE_OPERATION_TRANSPOSE;
} else {
if (conjugate) {
DCHECK(status != nullptr);
*status = errors::InvalidArgument(
"Conjugate == True and transpose == False is not supported.");
}
return HIPSPARSE_OPERATION_NON_TRANSPOSE;
}
#endif
}
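A short usage sketch (hypothetical caller): the helper reports the unsupported conjugate-without-transpose combination through the output Status rather than through its return value, so callers are expected to check it:

```cpp
// Hypothetical caller: Status default-constructs to OK, so it only carries an
// error for the conjugate-without-transpose case flagged above.
Status PickOperation(bool transpose, bool conjugate,
                     gpusparseOperation_t* op) {
  Status status;
  *op = TransposeAndConjugateToGpuSparseOp(transpose, conjugate, &status);
  return status;
}
```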
// The GpuSparse class provides a simplified templated API for cuSparse
@@ -353,7 +418,11 @@ class GpuSparseMatrixDescriptor {
// called more than once.
Status Initialize() {
DCHECK(!initialized_);
#if GOOGLE_CUDA
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descr_));
#endif
initialized_ = true;
return Status::OK();
}
@@ -371,7 +440,11 @@ class GpuSparseMatrixDescriptor {
private:
void Release() {
if (initialized_) {
#if GOOGLE_CUDA
cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
hipsparseDestroyMatDescr(descr_);
#endif
initialized_ = false;
}
}
@@ -382,6 +455,8 @@ class GpuSparseMatrixDescriptor {
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};
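Usage follows the usual RAII pattern: Initialize() creates the backend descriptor once and the destructor's Release() destroys it. A minimal sketch (hypothetical, and it assumes the class exposes the wrapped gpusparseMatDescr_t through an accessor, which these hunks do not show):

```cpp
// Hypothetical sketch of the descriptor lifecycle.
Status MakeDescriptor() {
  GpuSparseMatrixDescriptor descr_a;
  TF_RETURN_IF_ERROR(descr_a.Initialize());
  // The wrapped gpusparseMatDescr_t would now be passed to the Csrmm/Csrmv
  // wrappers via an accessor (assumed, not shown here); it is destroyed when
  // descr_a goes out of scope.
  return Status::OK();
}
```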
#if GOOGLE_CUDA
// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
@@ -439,8 +514,10 @@ class GpuSparseCsrSortingConversionInfo {
TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};
} // namespace tensorflow
#endif // GOOGLE_CUDA
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_

tensorflow/core/kernels/rocm_sparse.cc

@@ -0,0 +1,330 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if TENSORFLOW_USE_ROCM
#include <complex>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A set of initialized handles to the underlying ROCm libraries used by
// GpuSparse. We maintain one such set of handles per unique stream.
class HipSparseHandles {
public:
explicit HipSparseHandles(hipStream_t stream)
: initialized_(false), stream_(stream) {}
HipSparseHandles(HipSparseHandles&& rhs)
: initialized_(rhs.initialized_),
stream_(std::move(rhs.stream_)),
hipsparse_handle_(rhs.hipsparse_handle_) {
rhs.initialized_ = false;
}
HipSparseHandles& operator=(HipSparseHandles&& rhs) {
if (this == &rhs) return *this;
Release();
stream_ = std::move(rhs.stream_);
hipsparse_handle_ = std::move(rhs.hipsparse_handle_);
initialized_ = rhs.initialized_;
rhs.initialized_ = false;
return *this;
}
~HipSparseHandles() { Release(); }
Status Initialize() {
if (initialized_) return Status::OK();
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreate(&hipsparse_handle_));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetStream(hipsparse_handle_, stream_));
initialized_ = true;
return Status::OK();
}
hipsparseHandle_t& handle() {
DCHECK(initialized_);
return hipsparse_handle_;
}
const hipsparseHandle_t& handle() const {
DCHECK(initialized_);
return hipsparse_handle_;
}
private:
void Release() {
if (initialized_) {
// This should never return anything other than success
auto err = hipsparseDestroy(hipsparse_handle_);
DCHECK(err == HIPSPARSE_STATUS_SUCCESS)
<< "Failed to destroy hipSPARSE instance.";
initialized_ = false;
}
}
bool initialized_;
hipStream_t stream_;
hipsparseHandle_t hipsparse_handle_;
TF_DISALLOW_COPY_AND_ASSIGN(HipSparseHandles);
};
// TODO(ebrevdo): Replace global mutex guarding HipSparseHandles
// lookup with one of:
// 1. Adding the handle to the GpuStream structure; do the lookup there.
// 2. Add a thread-local hipsparse handle, set it to the current stream
// upon each call.
// #1 seems like the cleanest option but will need to wait until this
// is moved into TF core.
static mutex handle_map_mutex(LINKER_INITIALIZED);
using HandleMap = std::unordered_map<hipStream_t, HipSparseHandles>;
// Returns a singleton map used for storing initialized handles for each unique
// HIP stream.
HandleMap* GetHandleMapSingleton() {
static HandleMap* cm = new HandleMap;
return cm;
}
} // namespace
GpuSparse::GpuSparse(OpKernelContext* context)
: initialized_(false), context_(context) {
auto hip_stream_ptr =
reinterpret_cast<const hipStream_t*>(context->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack());
DCHECK(hip_stream_ptr);
gpu_stream_ = *hip_stream_ptr;
}
Status GpuSparse::Initialize() {
HandleMap* handle_map = GetHandleMapSingleton();
DCHECK(handle_map);
mutex_lock lock(handle_map_mutex);
auto it = handle_map->find(gpu_stream_);
if (it == handle_map->end()) {
LOG(INFO) << "Creating GpuSparse handles for stream " << gpu_stream_;
// Previously unseen ROCm stream. Initialize a set of ROCm sparse library
// handles for it.
HipSparseHandles new_handles(gpu_stream_);
TF_RETURN_IF_ERROR(new_handles.Initialize());
it = handle_map->insert(std::make_pair(gpu_stream_, std::move(new_handles)))
.first;
}
gpusparse_handle_ = &it->second.handle();
initialized_ = true;
return Status::OK();
}
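Putting the pieces together, a kernel on the ROCm path would construct one GpuSparse per OpKernelContext and call Initialize() before any wrapper; the per-stream hipSPARSE handle is created on first use and reused for later calls on the same stream. A hedged sketch (hypothetical helper, not part of this commit):

```cpp
// Hypothetical kernel-side helper: Initialize() either reuses the cached
// hipSPARSE handle for this context's stream or creates and caches a new one,
// after which any of the wrappers below can be called.
Status RunCoo2Csr(OpKernelContext* ctx, const int* coo_row_ind, int nnz, int m,
                  int* csr_row_ptr) {
  GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  return gpu_sparse.Coo2csr(coo_row_ind, nnz, m, csr_row_ptr);
}
```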
// Macro that specializes a sparse method for the supported numeric
// types (currently float and double).
#define TF_CALL_HIP_LAPACK_TYPES(m) m(float, S) m(double, D)
// Macro to construct hipsparse method names.
#define SPARSE_FN(method, sparse_prefix) hipsparse##sparse_prefix##method
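The macro relies on the hipSPARSE naming convention in which every routine carries a type-prefix letter; for reference, the token pasting expands as follows:

```cpp
// Illustrative expansions of the name-pasting macro:
//   SPARSE_FN(csrmv, S)   ->  hipsparseScsrmv    (float)
//   SPARSE_FN(csrmm2, D)  ->  hipsparseDcsrmm2   (double)
```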
Status GpuSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
int* csrRowPtr) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcoo2csr(*gpusparse_handle_, cooRowInd,
nnz, m, csrRowPtr,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
int* cooRowInd) const {
DCHECK(initialized_);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsr2coo(*gpusparse_handle_, csrRowPtr,
nnz, m, cooRowInd,
HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
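For readers unfamiliar with the conversion, a host-side reference of what the coo2csr step computes may help (illustration only; the library routine runs on the device): given zero-based, row-sorted COO row indices, the CSR row-pointer array is the prefix sum of the per-row counts.

```cpp
#include <vector>

// Host-side reference of the coo2csr transform (illustration only):
// csr_row_ptr has m + 1 entries, and csr_row_ptr[r]..csr_row_ptr[r+1]
// delimit the nonzeros of row r.
std::vector<int> Coo2CsrReference(const std::vector<int>& coo_row_ind, int m) {
  std::vector<int> csr_row_ptr(m + 1, 0);
  for (int row : coo_row_ind) ++csr_row_ptr[row + 1];  // per-row counts
  for (int r = 0; r < m; ++r) csr_row_ptr[r + 1] += csr_row_ptr[r];  // prefix sum
  return csr_row_ptr;
}
```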
template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, int nnz, const Scalar* alpha_host, const hipsparseMatDescr_t descrA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* B, int ldb,
const Scalar* beta_host, Scalar* C, int ldc) {
TF_RETURN_IF_GPUSPARSE_ERROR(op(hipsparse_handle, transA, transB, m, n, k,
nnz, alpha_host, descrA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, B, ldb,
beta_host, C, ldc));
return Status::OK();
}
#define CSRMM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrmm<Scalar>( \
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
int k, int nnz, const Scalar* alpha_host, \
const hipsparseMatDescr_t descrA, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) \
const { \
DCHECK(initialized_); \
return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_, \
*gpusparse_handle_, transA, transB, m, n, k, nnz, \
alpha_host, descrA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, B, ldb, beta_host, C, ldc); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRMM_INSTANCE);
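The Csrmm wrapper follows the usual &lt;t&gt;csrmm2 convention: C = alpha * op(A) * op(B) + beta * C, where A is the m x k CSR matrix and B, C are dense, column-major device buffers with leading dimensions ldb and ldc. A hedged sketch of a plain non-transposed product (hypothetical helper, not from this commit):

```cpp
// Hypothetical helper: C = alpha * A * B + beta * C with A an m x k CSR
// matrix, B a k x n dense matrix (column-major, ldb = k), and C an m x n
// dense matrix (ldc = m); alpha/beta are host-side scalars.
Status CsrTimesDense(const GpuSparse& gpu_sparse, int m, int n, int k, int nnz,
                     const hipsparseMatDescr_t descr_a, const float* csr_val_a,
                     const int* csr_row_ptr_a, const int* csr_col_ind_a,
                     const float* b, float* c) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return gpu_sparse.Csrmm<float>(
      HIPSPARSE_OPERATION_NON_TRANSPOSE, HIPSPARSE_OPERATION_NON_TRANSPOSE, m,
      n, k, nnz, &alpha, descr_a, csr_val_a, csr_row_ptr_a, csr_col_ind_a, b,
      /*ldb=*/k, &beta, c, /*ldc=*/m);
}
```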
template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(SparseFnT op, OpKernelContext* context,
hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, int m, int n,
int nnz, const Scalar* alpha_host,
const hipsparseMatDescr_t descrA,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, transA, m, n, nnz, alpha_host, descrA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y));
return Status::OK();
}
// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
hipsparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const hipsparseMatDescr_t descrA, \
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, \
Scalar* y) const { \
DCHECK(initialized_); \
return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_, \
*gpusparse_handle_, transA, m, n, nnz, alpha_host, \
descrA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, x, beta_host, y); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRMV_INSTANCE);
Status GpuSparse::CsrgemmNnz(
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, const hipsparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const hipsparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseXcsrgemmNnz(
*gpusparse_handle_, transA, transB, m, n, k, descrA, nnzA,
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
SparseFnT op, OpKernelContext* context, hipsparseHandle_t hipsparse_handle,
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n,
int k, const hipsparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, transA, transB, m, n, k, descrA, nnzA, csrSortedValA,
csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedValB,
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,
csrSortedRowPtrC, csrSortedColIndC));
return Status::OK();
}
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgemm<Scalar>( \
hipsparseOperation_t transA, hipsparseOperation_t transB, int m, int n, \
int k, const hipsparseMatDescr_t descrA, int nnzA, \
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const hipsparseMatDescr_t descrB, int nnzB, \
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
const int* csrSortedColIndB, const hipsparseMatDescr_t descrC, \
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) { \
DCHECK(initialized_); \
return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_, \
*gpusparse_handle_, transA, transB, m, n, k, descrA, \
nnzA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, \
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); \
}
TF_CALL_HIP_LAPACK_TYPES(CSRGEMM_INSTANCE);
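CsrgemmNnz and Csrgemm are intended to be used as a two-phase pair: the first call produces C's row pointers and nonzero count so the caller can size C's value and column-index buffers, and the second call fills them. A hedged sketch (hypothetical helper; buffer allocation is left to a caller-supplied callback because it is kernel-specific):

```cpp
#include <functional>

// Hypothetical two-phase C = op(A) * op(B) sketch (float). Phase 1
// (CsrgemmNnz) fills csr_row_ptr_c and *nnz_c; the caller-supplied callback
// then allocates C's value/column buffers of that size; phase 2 (Csrgemm)
// fills them.
Status CsrGemmTwoPhase(
    GpuSparse* gpu_sparse, hipsparseOperation_t trans_a,
    hipsparseOperation_t trans_b, int m, int n, int k,
    const hipsparseMatDescr_t descr_a, int nnz_a, const float* csr_val_a,
    const int* csr_row_ptr_a, const int* csr_col_ind_a,
    const hipsparseMatDescr_t descr_b, int nnz_b, const float* csr_val_b,
    const int* csr_row_ptr_b, const int* csr_col_ind_b,
    const hipsparseMatDescr_t descr_c, int* csr_row_ptr_c, int* nnz_c,
    std::function<Status(int nnz, float** val_c, int** col_ind_c)> alloc_c) {
  // Phase 1: row pointers and total nonzero count of C.
  TF_RETURN_IF_ERROR(gpu_sparse->CsrgemmNnz(
      trans_a, trans_b, m, n, k, descr_a, nnz_a, csr_row_ptr_a, csr_col_ind_a,
      descr_b, nnz_b, csr_row_ptr_b, csr_col_ind_b, descr_c, csr_row_ptr_c,
      nnz_c));
  // Phase 2: allocate C's buffers, then compute values and column indices.
  float* csr_val_c = nullptr;
  int* csr_col_ind_c = nullptr;
  TF_RETURN_IF_ERROR(alloc_c(*nnz_c, &csr_val_c, &csr_col_ind_c));
  return gpu_sparse->Csrgemm<float>(
      trans_a, trans_b, m, n, k, descr_a, nnz_a, csr_val_a, csr_row_ptr_a,
      csr_col_ind_a, descr_b, nnz_b, csr_val_b, csr_row_ptr_b, csr_col_ind_b,
      descr_c, csr_val_c, csr_row_ptr_c, csr_col_ind_c);
}
```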
template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
hipsparseHandle_t hipsparse_handle, int m,
int n, int nnz, const Scalar* csrVal,
const int* csrRowPtr, const int* csrColInd,
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
const hipsparseAction_t copyValues) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(hipsparse_handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal,
cscRowInd, cscColPtr, copyValues, HIPSPARSE_INDEX_BASE_ZERO));
return Status::OK();
}
#define CSR2CSC_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const hipsparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_, \
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
}
TF_CALL_HIP_LAPACK_TYPES(CSR2CSC_INSTANCE);
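Csr2csc with HIPSPARSE_ACTION_NUMERIC is also the standard way to materialize the transpose of a CSR matrix, since the CSC arrays of A are exactly the CSR arrays of its transpose. A small hedged sketch (hypothetical helper, not from this commit):

```cpp
// Hypothetical helper: the produced (csc_val, csc_row_ind, csc_col_ptr) can
// be reinterpreted as the CSR representation of the n x m transpose of A.
Status TransposeCsr(GpuSparse* gpu_sparse, int m, int n, int nnz,
                    const float* csr_val, const int* csr_row_ptr,
                    const int* csr_col_ind, float* csc_val, int* csc_row_ind,
                    int* csc_col_ptr) {
  return gpu_sparse->Csr2csc<float>(m, n, nnz, csr_val, csr_row_ptr,
                                    csr_col_ind, csc_val, csc_row_ind,
                                    csc_col_ptr, HIPSPARSE_ACTION_NUMERIC);
}
```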
} // namespace tensorflow
#endif // TENSORFLOW_USE_ROCM