Merge pull request #38802 from nluehr:cusparse_remove_deprecated

PiperOrigin-RevId: 310182962
Change-Id: Ic9c39c151181a99fa76981b4724d2e0f51fa5f69
This commit is contained in:
TensorFlower Gardener 2020-05-06 11:05:02 -07:00
commit 10b96cd214
6 changed files with 813 additions and 46 deletions

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@ -179,6 +180,10 @@ Status GpuSparse::Initialize() {
return Status::OK();
}
#define TF_CALL_CUSPARSE_DTYPES(m) \
m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)
// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
@ -359,23 +364,30 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
return Status::OK();
}
Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB,
const int* csrSortedColIndB,
const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
Status GpuSparse::CsrgeamNnz(
int m, int n, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
#if CUDA_VERSION >= 10000
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz(
*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace));
#else
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
#endif
return Status::OK();
}
#if CUDA_VERSION < 10020
template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
@ -416,6 +428,45 @@ static inline Status CsrmmImpl(
TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);
#else
#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMMBufferSize<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \
dtype, alg, bufferSize)); \
return Status::OK(); \
}
TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);
#define SPMM_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMM<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \
transB, alpha, matA, matB, beta, \
matC, dtype, alg, buffer)); \
return Status::OK(); \
}
TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
#endif
#if CUDA_VERSION < 10020
template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
@ -455,6 +506,115 @@ static inline Status CsrmvImpl(
TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);
#else
template <typename Scalar>
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n,
int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
// CUSPARSE_ALG_MERGE_PATH algo only supports non-transpose matrix.
DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE);
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
auto pBuffer = buffer.flat<int8>();
DCHECK(pBuffer.data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
return Status::OK();
}
template <typename Scalar>
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseSpMatDescr_t matA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));
cusparseDnVecDescr_t vecX, vecY;
int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
CUSPARSE_CSRMV_ALG1, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
auto pBuffer = buffer.flat<int8>();
DCHECK(pBuffer.data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
return Status::OK();
}
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
cusparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
DCHECK(initialized_); \
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \
m, n, nnz, alpha_host, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
} else { \
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \
n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, x, beta_host, y); \
} \
}
TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
#endif // CUDA_VERSION < 10020
#if CUDA_VERSION < 10000
template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
@ -483,7 +643,7 @@ static inline Status CsrgeamImpl(
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC) { \
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
DCHECK(initialized_); \
return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
@ -493,8 +653,113 @@ static inline Status CsrgeamImpl(
csrSortedRowPtrC, csrSortedColIndC); \
}
#else
template <typename Scalar, typename SparseFnT>
static inline Status Csrgeam2Impl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(op(
cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, workspace));
return Status::OK();
}
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgeam<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
DCHECK(initialized_); \
return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_, \
*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
csrSortedColIndB, descrC, csrSortedValC, \
csrSortedRowPtrC, csrSortedColIndC, workspace); \
}
#endif
TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);
#if CUDA_VERSION < 10000
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
*bufferSize = 0; \
return Status::OK(); \
}
#else
template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamBufferSizeExtImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle,
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
TF_RETURN_IF_GPUSPARSE_ERROR(op(
sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, bufferSize));
return Status::OK();
}
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
return CsrgeamBufferSizeExtImpl( \
SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \
*gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \
csrSortedRowPtrC, csrSortedColIndC, bufferSize); \
}
#endif
TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE);
#if CUDA_VERSION < 10000
Status GpuSparse::CsrgemmNnz(
cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
@ -551,6 +816,101 @@ static inline Status CsrgemmImpl(
TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
#else
template <typename T>
static const T* one_ptr() {
static const T one = static_cast<T>(1);
return &one;
}
template <typename T>
static const T* null_ptr() {
return nullptr;
}
#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgemmBufferSize<Scalar>( \
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, \
const int* csrSortedColIndB, csrgemm2Info_t info, \
size_t* workspaceBytes) { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \
sparse_prefix)( \
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes)); \
return Status::OK(); \
}
TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);
Status GpuSparse::CsrgemmNnz(
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
*gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
nnzTotalDevHostPtr, info, workspace));
return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
const csrgemm2Info_t info, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
csrSortedColIndC, info, workspace));
return Status::OK();
}
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgemm<Scalar>( \
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC, \
const csrgemm2Info_t info, void* workspace) { \
DCHECK(initialized_); \
return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_, \
*gpusparse_handle_, m, n, k, descrA, nnzA, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
csrSortedColIndB, descrC, csrSortedValC, \
csrSortedRowPtrC, csrSortedColIndC, info, workspace); \
}
TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
#endif // CUDA_VERSION < 10000
template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
OpKernelContext* context,
@ -596,6 +956,8 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);
#if CUDA_VERSION < 10010
template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
cusparseHandle_t cusparse_handle, int m, int n,
@ -624,6 +986,53 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);
#else
template <typename Scalar>
static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle, int m, int n,
int nnz, const Scalar* csrVal,
const int* csrRowPtr, const int* csrColInd,
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
const cusparseAction_t copyValues) {
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataTypeToEnum<Scalar>::value,
TensorShape({static_cast<int64>(bufferSize)}), &buffer));
DCHECK(buffer.flat<Scalar>().data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));
return Status::OK();
}
#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const cusparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \
cscColPtr, copyValues); \
}
TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);
#endif // CUDA_VERSION < 10010
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -26,6 +26,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"
using gpusparseStatus_t = cusparseStatus_t;
@ -34,6 +35,11 @@ using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#if CUDA_VERSION >= 10020
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#endif
#define GPUSPARSE(postfix) CUSPARSE_##postfix
#define gpusparse(postfix) cusparse##postfix
@ -253,6 +259,7 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
#if CUDA_VERSION < 10020
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense tall
// matrices. This routine allows transposition of matrix B, which
@ -272,18 +279,64 @@ class GpuSparse {
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
int ldc) const;
#else
// Workspace size query for sparse-dense matrix multiplication. Helper
// function for SpMM which computes y = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense matricies in
// column-major format. Returns needed workspace size in bytes.
template <typename Scalar>
Status SpMMBufferSize(gpusparseOperation_t transA,
gpusparseOperation_t transB, const Scalar* alpha,
const gpusparseSpMatDescr_t matA,
const gpusparseDnMatDescr_t matB, const Scalar* beta,
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
size_t* bufferSize) const;
// Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense matricies in
// column-major format. Buffer is assumed to be at least as large as the
// workspace size returned by SpMMBufferSize().
//
// **NOTE** This is an in-place operation for data in C.
template <typename Scalar>
Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
const Scalar* alpha, const gpusparseSpMatDescr_t matA,
const gpusparseDnMatDescr_t matB, const Scalar* beta,
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
int8* buffer) const;
#endif
// Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
// where A is a sparse matrix in CSR format, x and y are dense vectors. See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
//
// **NOTE** This is an in-place operation for data in y.
#if CUDA_VERSION < 10020
template <typename Scalar>
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) const;
#else
template <typename Scalar>
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host, const Scalar* csrSortedValA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* x, const Scalar* beta_host, Scalar* y) const;
#endif // CUDA_VERSION < 10020
// Computes workspace size for sparse - sparse matrix addition of matrices
// stored in CSR format.
template <typename Scalar>
Status CsrgeamBufferSizeExt(
int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);
// Computes sparse-sparse matrix addition of matrices
// stored in CSR format. This is part one: calculate nnz of the
@ -295,7 +348,7 @@ class GpuSparse {
const gpusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
int* nnzTotalDevHostPtr);
int* nnzTotalDevHostPtr, void* workspace);
// Computes sparse - sparse matrix addition of matrices
// stored in CSR format. This is part two: perform sparse-sparse
@ -311,13 +364,26 @@ class GpuSparse {
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC);
int* csrSortedColIndC, void* workspace);
#if CUDA_VERSION >= 10000
// Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part zero: calculate required workspace
// size.
template <typename Scalar>
Status CsrgemmBufferSize(
int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif
// Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part one: calculate nnz of the
// output. csrSortedRowPtrC must be preallocated on device with
// m + 1 entries. See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
@ -326,12 +392,23 @@ class GpuSparse {
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
int* nnzTotalDevHostPtr);
#else
Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
int* nnzTotalDevHostPtr, csrgemm2Info_t info,
void* workspace);
#endif
// Computes sparse - sparse matrix matmul of matrices
// stored in CSR format. This is part two: perform sparse-sparse
// addition. csrValC and csrColIndC must be allocated on the device
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if CUDA_VERSION < 10000
template <typename Scalar>
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
int m, int k, int n, const gpusparseMatDescr_t descrA,
@ -342,6 +419,18 @@ class GpuSparse {
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC);
#else
template <typename Scalar>
Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, const csrgemm2Info_t info,
void* workspace);
#endif
// In-place reordering of unsorted CSR to sorted CSR.
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr

View File

@ -107,6 +107,26 @@ class CSRSparseMatrixAddFunctor {
const Device& d = ctx_->eigen_device<Device>();
set_zero(d, c_row_ptr_t.flat<int32>());
size_t maxWorkspaceSize = 0;
for (int i = 0; i < batch_size; ++i) {
ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
a.values_vec<T>(i), a_dense_shape};
ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
b.values_vec<T>(i), b_dense_shape};
size_t thisWorkspaceSize;
TF_RETURN_IF_ERROR(
csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
if (thisWorkspaceSize > maxWorkspaceSize) {
maxWorkspaceSize = thisWorkspaceSize;
}
}
Tensor temp;
TF_RETURN_IF_ERROR(ctx_->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}), &temp));
void* workspace = temp.flat<int8>().data();
for (int i = 0; i < batch_size; ++i) {
// Calculate output sizes for all minibatch entries.
// Store in c_batch_ptr and update c_row_ptrs.
@ -121,8 +141,8 @@ class CSRSparseMatrixAddFunctor {
TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
rows + 1);
int c_nnz_i;
TF_RETURN_IF_ERROR(
csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i));
TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}
@ -151,7 +171,7 @@ class CSRSparseMatrixAddFunctor {
CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
c->values_vec<T>(i), c_dense_shape_t.vec<int64>()};
TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp));
TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace));
}
return Status::OK();
@ -269,10 +289,36 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
return Status::OK();
}
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);
const int m = a.row_ptr.size() - 1;
DCHECK_EQ(m, b.row_ptr.size() - 1);
const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
DCHECK_EQ(m, a.dense_shape_host(row_dim));
DCHECK_EQ(m, b.dense_shape_host(row_dim));
const int nnzA = a.col_ind.size();
const int nnzB = b.col_ind.size();
const int n = a.dense_shape_host(row_dim + 1);
DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));
T* null_T = nullptr;
int* null_int = nullptr;
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
null_int, bufferSize));
return Status::OK();
}
Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
int* output_nnz) {
int* output_nnz, void* workspace) {
DCHECK(initialized_);
const int m = a.row_ptr.size() - 1;
@ -290,7 +336,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c_row_ptr.data(), output_nnz));
descrC_.descr(), c_row_ptr.data(), output_nnz, workspace));
if (*output_nnz < 0) {
return errors::Internal(
@ -300,7 +346,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
}
Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
CSRComponent<T>* c) {
CSRComponent<T>* c, void* workspace) {
DCHECK(initialized_);
const int m = a.row_ptr.size() - 1;
@ -319,7 +365,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
c->row_ptr.data(), c->col_ind.data()));
c->row_ptr.data(), c->col_ind.data(), workspace));
return Status::OK();
}

View File

@ -167,13 +167,18 @@ struct CSRStructureModifyingFunctor {
virtual Status Initialize() = 0;
virtual Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
size_t* bufferSize) = 0;
virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
int* output_nnz) = 0;
int* output_nnz, void* workspace) = 0;
virtual Status Compute(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, CSRComponent<T>* c) = 0;
const ConstCSRComponent<T>& b, CSRComponent<T>* c,
void* workspace) = 0;
};
// Calculates C = alpha * A + beta * B, where A and B are in CSR

View File

@ -721,6 +721,56 @@ REGISTER_GPU(complex128)
namespace functor {
namespace {
// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
// cudaDataType_t (e.g. CUDA_R_32F).
template <typename T>
struct CUDADataType;
template <>
struct CUDADataType<Eigen::half> {
static constexpr cudaDataType_t type = CUDA_R_16F;
};
template <>
struct CUDADataType<float> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr cudaDataType_t type = HIPBLAS_R_32F;
#endif
};
template <>
struct CUDADataType<std::complex<float>> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_32F;
#elif TENSORFLOW_USE_ROCM
static constexpr cudaDataType_t type = HIPBLAS_C_32F;
#endif
};
template <>
struct CUDADataType<double> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_R_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr cudaDataType_t type = HIPBLAS_R_64F;
#endif
};
template <>
struct CUDADataType<std::complex<double>> {
#if GOOGLE_CUDA
static constexpr cudaDataType_t type = CUDA_C_64F;
#elif TENSORFLOW_USE_ROCM
static constexpr cudaDataType_t type = HIPBLAS_C_64F;
#endif
};
} // namespace
template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
public:
@ -733,10 +783,10 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
GpuSparse cuda_sparse(ctx);
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
{
// Use Csrmm to calculate:
// Use Csrmm/SpMM to calculate:
// C = alpha * op(A) * op(B) + beta * C
// where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
// Note that Csrmm assumes B and C are in column-major form; so we
// Note that Csrmm/Spmm assumes B and C are in column-major form; so we
// use transB == true, and manually transpose the output in place
// using blas<t>geam.
// TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.
@ -746,22 +796,6 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
const T alpha = 1;
const T beta = 0;
// transA must be non-transpose if transB is transpose (cusparse
// limitation).
const gpusparseOperation_t transA = GPUSPARSE(OPERATION_NON_TRANSPOSE);
// transB: b is row-major, and cusparse requires col-major b (or
// equivalently transB == transpose). this version is actually more
// efficient.
const gpusparseOperation_t transB = GPUSPARSE(OPERATION_TRANSPOSE);
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
TF_RETURN_IF_GPUSPARSE_ERROR(
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
const int k = b.dimension(0);
DCHECK_EQ(k, a.dense_shape_host(1));
@ -786,10 +820,87 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
// op(A) = A and at least max(1, k) otherwise.
const int ldc = m;
// transA must be non-transpose if transB is transpose (cusparse
// limitation).
#if GOOGLE_CUDA
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
#endif
// transB: b is row-major, and cusparse requires col-major b (or
// equivalently transB == transpose). this version is actually more
// efficient.
#if GOOGLE_CUDA && CUDA_VERSION >= 10020
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
gpusparseSpMatDescr_t matA;
gpusparseDnMatDescr_t matB, matC;
// NOTE: the following APIs are not available in ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
CUDADataType<T>::type));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType<T>::type,
CUSPARSE_ORDER_COL));
size_t bufferSize = 0;
TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
transA, transB, &alpha, matA, matB, &beta, matC,
CUSPARSE_MM_ALG_DEFAULT, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
DCHECK(buffer.flat<int8>().data() != nullptr);
TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
&beta, matC, CUSPARSE_MM_ALG_DEFAULT,
buffer.flat<int8>().data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
#else
#if GOOGLE_CUDA
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA
TF_RETURN_IF_ERROR(
cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
a.values.data(), a.row_ptr.data(), a.col_ind.data(),
b.data(), ldb, &beta, c.data(), ldc));
#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020
}
return Status::OK();
@ -822,20 +933,35 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
const T alpha = 1;
const T beta = 0;
#if GOOGLE_CUDA && CUDA_VERSION < 10020
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
#elif TENSORFLOW_USE_ROCM
gpusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
TF_RETURN_IF_GPUSPARSE_ERROR(
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const int m = a.dense_shape_host(0);
const int n = a.dense_shape_host(1);
const int nnz = a.values.size();
DCHECK_EQ(nnz, a.col_ind.size());
#if CUDA_VERSION >= 10020
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
a.values.data(), a.row_ptr.data(),
a.col_ind.data(), x, &beta, y));
#else
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
a.values.data(), a.row_ptr.data(),
a.col_ind.data(), x, &beta, y));
#endif
}
return Status::OK();

View File

@ -417,6 +417,36 @@ class CSRSparseMatMulGPUOp : public OpKernel {
}
auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
#if CUDA_VERSION >= 10000
size_t maxWorkspaceSize = 0;
for (int i = 0; i < batch_size; ++i) {
// Calculate maximum workspace size over batch.
ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
a_input_matrix->col_indices_vec(i),
a_input_matrix->values_vec<T>(i),
a_input_dense_shape};
ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
b_input_matrix->col_indices_vec(i),
b_input_matrix->values_vec<T>(i),
b_input_dense_shape};
size_t thisWorkspaceSize;
OP_REQUIRES_OK(
ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
if (thisWorkspaceSize > maxWorkspaceSize) {
maxWorkspaceSize = thisWorkspaceSize;
}
}
Tensor temp;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}),
&temp));
void* workspace = temp.flat<int8>().data();
#else
void* workspace = nullptr;
#endif
for (int i = 0; i < batch_size; ++i) {
// Calculate output sizes for all minibatch entries.
// Store in c_batch_ptr and update c_row_ptrs.
@ -433,8 +463,9 @@ class CSRSparseMatMulGPUOp : public OpKernel {
rows + 1);
int c_nnz_i;
OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
c_row_ptr_i, &c_nnz_i));
OP_REQUIRES_OK(ctx,
csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
&c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}
@ -464,7 +495,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
b_input_dense_shape};
CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
c.values_vec<T>(i), c_dense_shape};
OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp));
OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace));
}
Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
@ -527,7 +558,12 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
initialized_(false),
transpose_a_(transpose_a),
adjoint_a_(adjoint_a),
#if CUDA_VERSION < 10000
transpose_b_(transpose_b) {
#else
transpose_b_(transpose_b),
info_(nullptr) {
#endif // CUDA_VERSION < 10000
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
transA_ = transpose_a
? (adjoint_a ? GPUSPARSE(OPERATION_TRANSPOSE)
@ -537,6 +573,14 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
: GPUSPARSE(OPERATION_NON_TRANSPOSE);
}
#if CUDA_VERSION >= 10000
~CSRSparseSparseMatrixMatMul() {
if (initialized_) {
cusparseDestroyCsrgemm2Info(info_);
}
}
#endif
Status Initialize() {
if (adjoint_a_ && transpose_a_) {
return errors::InvalidArgument(
@ -547,14 +591,44 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
TF_RETURN_IF_ERROR(descrA_.Initialize());
TF_RETURN_IF_ERROR(descrB_.Initialize());
TF_RETURN_IF_ERROR(descrC_.Initialize());
#if CUDA_VERSION >= 10000
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
#endif
initialized_ = true;
return Status::OK();
}
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);
const int m =
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
if (!transpose_a_) {
DCHECK_EQ(m, a.row_ptr.size() - 1);
}
const int k =
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
if (!transpose_b_) {
DCHECK_EQ(k, b.row_ptr.size() - 1);
}
const int nnzA = a.col_ind.size();
const int nnzB = b.col_ind.size();
const int n =
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
bufferSize));
return Status::OK();
}
Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
int* output_nnz) {
int* output_nnz, void* workspace) {
DCHECK(initialized_);
const int m =
@ -576,10 +650,17 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
*output_nnz = -1;
#if CUDA_VERSION < 10000
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
#endif
if (*output_nnz < 0) {
return errors::Internal(
@ -590,7 +671,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
}
Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
CSRComponent<T>* c) {
CSRComponent<T>* c, void* workspace) {
DCHECK(initialized_);
const int m =
@ -612,11 +693,19 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
#if CUDA_VERSION < 10000
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
c->values.data(), c->row_ptr.data(), c->col_ind.data()));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
c->row_ptr.data(), c->col_ind.data(), info_, workspace));
#endif
// TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
// columns are sorted? Above operation leads to unsorted columns.
@ -643,6 +732,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
GpuSparseMatrixDescriptor descrC_;
gpusparseOperation_t transA_;
gpusparseOperation_t transB_;
#if CUDA_VERSION >= 10000
csrgemm2Info_t info_;
#endif
};
} // namespace functor