Cleanup formatting with clang-format

Nathan Luehr 2020-05-04 17:40:14 -05:00
parent 7e07d1fe65
commit 54e5bedaaf
6 changed files with 223 additions and 254 deletions

View File

@ -23,8 +23,6 @@ limitations under the License.
#include <utility>
#include <vector>
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@ -38,6 +36,8 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
@ -180,11 +180,9 @@ Status GpuSparse::Initialize() {
return Status::OK();
}
#define TF_CALL_CUSPARSE_DTYPES(m) \
m(float, CUDA_R_32F) \
m(double, CUDA_R_64F) \
m(std::complex<float>, CUDA_C_32F) \
m(std::complex<double>, CUDA_C_64F)
#define TF_CALL_CUSPARSE_DTYPES(m) \
m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)
// Macro that specializes a sparse method for all 4 standard
// numeric types.
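For context, TF_CALL_CUSPARSE_DTYPES is an X-macro: it takes another macro m and invokes it once per (Scalar, cudaDataType) pair, so an invocation such as TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE) stamps out one explicit specialization per supported type. A minimal, standalone sketch of the pattern (illustration only; CALL_DTYPES, FakeDataType, and PRINT_INSTANCE are made-up names, not TensorFlow code):

#include <complex>
#include <cstdio>

enum FakeDataType { R_32F, R_64F, C_32F, C_64F };  // stand-ins for CUDA_R_32F etc.

// Invokes the macro `m` once per (Scalar, dtype) pair.
#define CALL_DTYPES(m)                       \
  m(float, R_32F) m(double, R_64F)           \
  m(std::complex<float>, C_32F) m(std::complex<double>, C_64F)

// A "method instance" macro: each invocation defines one function.
#define PRINT_INSTANCE(Scalar, dtype)                          \
  void Print_##dtype() {                                       \
    std::printf("dtype enum %d, sizeof(Scalar) = %zu\n",       \
                static_cast<int>(dtype), sizeof(Scalar));      \
  }

CALL_DTYPES(PRINT_INSTANCE)  // defines Print_R_32F, Print_R_64F, Print_C_32F, Print_C_64F

int main() {
  Print_R_32F();  // float instance
  Print_C_64F();  // std::complex<double> instance
}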
@ -366,15 +364,12 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
return Status::OK();
}
Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB,
const int* csrSortedColIndB,
const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr,
void* workspace) {
Status GpuSparse::CsrgeamNnz(
int m, int n, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
#if CUDA_VERSION >= 10000
@ -435,36 +430,35 @@ TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);
#else
#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
template<> \
Status GpuSparse::SpMMBufferSize<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, \
size_t* bufferSize) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, \
matC, dtype, alg, bufferSize)); \
return Status::OK(); \
#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMMBufferSize<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \
dtype, alg, bufferSize)); \
return Status::OK(); \
}
TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);
#define SPMM_INSTANCE(Scalar, dtype) \
template<> \
Status GpuSparse::SpMM<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, \
int8* buffer) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, \
matC, dtype, alg, buffer)); \
return Status::OK(); \
#define SPMM_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMM<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \
transB, alpha, matA, matB, beta, \
matC, dtype, alg, buffer)); \
return Status::OK(); \
}
TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
@ -515,11 +509,14 @@ TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);
#else
template <typename Scalar>
static inline Status CsrmvExImpl(
cudaDataType_t dtype, OpKernelContext* context, cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* x, const Scalar* beta_host, Scalar* y) {
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n,
int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
@ -531,10 +528,9 @@ static inline Status CsrmvExImpl(
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz,
alpha_host, dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA,
csrSortedColIndA, x, dtype, beta_host, dtype, y, dtype, dtype,
&bufferSize));
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@ -543,40 +539,40 @@ static inline Status CsrmvExImpl(
DCHECK(pBuffer.data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz,
alpha_host, dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA,
csrSortedColIndA, x, dtype, beta_host, dtype, y, dtype, dtype,
pBuffer.data()));
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
return Status::OK();
}
template <typename Scalar>
static inline Status SpMVImpl(
cudaDataType_t dtype, OpKernelContext* context, cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* x, const Scalar* beta_host, Scalar* y) {
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseSpMatDescr_t matA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(&matA, m, n, nnz,
const_cast<int*>(csrSortedRowPtrA),
const_cast<int*>(csrSortedColIndA),
const_cast<Scalar*>(csrSortedValA),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));
cusparseDnVecDescr_t vecX, vecY;
int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(
&vecX, sizeX, const_cast<Scalar*>(x), dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
cusparse_handle, transA, alpha_host, matA, vecX,
beta_host, vecY, dtype, CUSPARSE_CSRMV_ALG1, &bufferSize));
cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
CUSPARSE_CSRMV_ALG1, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@ -584,9 +580,9 @@ static inline Status SpMVImpl(
auto pBuffer = buffer.flat<int8>();
DCHECK(pBuffer.data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV(
cusparse_handle, transA, alpha_host, matA, vecX,
beta_host, vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
@ -594,25 +590,23 @@ static inline Status SpMVImpl(
return Status::OK();
}
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
cusparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
DCHECK(initialized_); \
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, \
transA, m, n, nnz, alpha_host, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
x, beta_host, y); \
} else { \
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, \
transA, m, n, nnz, alpha_host, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
x, beta_host, y); \
} \
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
cusparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
DCHECK(initialized_); \
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \
m, n, nnz, alpha_host, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
} else { \
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \
n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, x, beta_host, y); \
} \
}
TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
@ -671,17 +665,16 @@ static inline Status Csrgeam2Impl(
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC,
AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
workspace));
TF_RETURN_IF_GPUSPARSE_ERROR(op(
cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, workspace));
return Status::OK();
}
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgeam<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
@ -733,34 +726,32 @@ static inline Status CsrgeamBufferSizeExtImpl(
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC,
AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
bufferSize));
TF_RETURN_IF_GPUSPARSE_ERROR(op(
sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, bufferSize));
return Status::OK();
}
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
return CsrgeamBufferSizeExtImpl(SPARSE_FN(csrgeam2_bufferSizeExt, \
sparse_prefix), context_, *gpusparse_handle_, m, n, \
alpha, descrA, nnzA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, \
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, \
bufferSize); \
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
return CsrgeamBufferSizeExtImpl( \
SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \
*gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \
csrSortedRowPtrC, csrSortedColIndC, bufferSize); \
}
#endif
@ -827,13 +818,13 @@ TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
#else
template<typename T>
template <typename T>
static const T* one_ptr() {
static const T one = static_cast<T>(1);
return &one;
}
template<typename T>
template <typename T>
static const T* null_ptr() {
return nullptr;
}
@ -847,55 +838,53 @@ static const T* null_ptr() {
const int* csrSortedColIndB, csrgemm2Info_t info, \
size_t* workspaceBytes) { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR( \
SPARSE_FN(csrgemm2_bufferSizeExt, sparse_prefix)( \
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), \
descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
csrSortedRowPtrB, csrSortedColIndB, \
AsCudaComplex(null_ptr<Scalar>()), descrA, 0, null_ptr<int>(), \
null_ptr<int>(), info, workspaceBytes)); \
TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \
sparse_prefix)( \
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes)); \
return Status::OK(); \
}
TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);
Status GpuSparse::CsrgemmNnz(
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
int* nnzTotalDevHostPtr, csrgemm2Info_t info, void* workspace) {
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
*gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC,
csrSortedRowPtrC, nnzTotalDevHostPtr, info, workspace));
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
nnzTotalDevHostPtr, info, workspace));
return Status::OK();
}
template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
int m, int n, int k, const cusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
const csrgemm2Info_t info, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()),
descrA, nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB, csrSortedColIndB,
AsCudaComplex(null_ptr<Scalar>()),
descrA, 0, AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
info, workspace));
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
csrSortedColIndC, info, workspace));
return Status::OK();
}
@ -1008,11 +997,9 @@ static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
const cusparseAction_t copyValues) {
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
cusparse_handle, m, n, nnz,
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd,
dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2, &bufferSize));
cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@ -1021,27 +1008,25 @@ static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
DCHECK(buffer.flat<Scalar>().data() != nullptr);
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2(
cusparse_handle, m, n, nnz,
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd,
dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2,
buffer.flat<Scalar>().data()));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));
return Status::OK();
}
#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const cusparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(cudaDataType, context_, \
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const cusparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \
cscColPtr, copyValues); \
}
TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);

View File

@ -26,8 +26,8 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"
using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
@ -253,7 +253,6 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
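Coo2csr wraps cusparseXcoo2csr, which compresses a row-sorted COO row-index array of length nnz into a CSR row-pointer array of length m + 1. The following host-side sketch reproduces the same transformation for intuition only (Coo2CsrHost is a made-up helper; the real conversion runs on the GPU):

#include <cstdio>
#include <vector>

// cooRowInd must be sorted by row; m is the number of rows.
std::vector<int> Coo2CsrHost(const std::vector<int>& cooRowInd, int m) {
  std::vector<int> csrRowPtr(m + 1, 0);
  for (int r : cooRowInd) ++csrRowPtr[r + 1];                    // count entries per row
  for (int i = 0; i < m; ++i) csrRowPtr[i + 1] += csrRowPtr[i];  // exclusive prefix sum
  return csrRowPtr;
}

int main() {
  // A 3-row matrix whose nonzeros sit in rows {0, 0, 1, 2, 2, 2}.
  std::vector<int> coo_rows = {0, 0, 1, 2, 2, 2};
  std::vector<int> row_ptr = Coo2CsrHost(coo_rows, /*m=*/3);
  for (int v : row_ptr) std::printf("%d ", v);  // prints: 0 2 3 6
  std::printf("\n");
}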
#if CUDA_VERSION < 10020
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense tall
@ -275,13 +274,14 @@ class GpuSparse {
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
int ldc) const;
#else
// Workspace size query for sparse-dense matrix multiplication. Helper function
// for SpMM which computes C = alpha * op(A) * op(B) + beta * C, where A is
// a sparse matrix in CSR format, B and C are dense matrices in column-major
// format. Returns needed workspace size in bytes.
// Workspace size query for sparse-dense matrix multiplication. Helper
// function for SpMM which computes C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense matrices in
// column-major format. Returns needed workspace size in bytes.
template <typename Scalar>
Status SpMMBufferSize(gpusparseOperation_t transA, gpusparseOperation_t transB,
const Scalar* alpha, const gpusparseSpMatDescr_t matA,
Status SpMMBufferSize(gpusparseOperation_t transA,
gpusparseOperation_t transB, const Scalar* alpha,
const gpusparseSpMatDescr_t matA,
const gpusparseDnMatDescr_t matB, const Scalar* beta,
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
size_t* bufferSize) const;
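SpMMBufferSize and the SpMM method it pairs with follow the cuSPARSE generic-API convention: query the scratch size, allocate a workspace, then run the multiply. A hedged sketch of the underlying call sequence for float data, assuming matA/matB/matC were already created with cusparseCreateCsr / cusparseCreateDnMat (SpMMOnce is a made-up name, and all status checks are omitted for brevity):

#include <cuda_runtime.h>
#include <cusparse.h>

void SpMMOnce(cusparseHandle_t handle, cusparseSpMatDescr_t matA,
              cusparseDnMatDescr_t matB, cusparseDnMatDescr_t matC) {
  const float alpha = 1.0f;
  const float beta = 0.0f;

  // Phase 1: ask cuSPARSE how much scratch space this SpMM needs.
  size_t buffer_size = 0;
  cusparseSpMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB,
                          &beta, matC, CUDA_R_32F, CUSPARSE_MM_ALG_DEFAULT,
                          &buffer_size);

  // Phase 2: allocate the workspace (TensorFlow uses context->allocate_temp).
  void* workspace = nullptr;
  cudaMalloc(&workspace, buffer_size);

  // Phase 3: C = alpha * A * B + beta * C.
  cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
               CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
               matC, CUDA_R_32F, CUSPARSE_MM_ALG_DEFAULT, workspace);

  cudaFree(workspace);
}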
@ -323,15 +323,14 @@ class GpuSparse {
// Computes workspace size for sparse-sparse matrix addition of matrices
// stored in CSR format.
template <typename Scalar>
Status CsrgeamBufferSizeExt(int m, int n, const Scalar* alpha,
const gpusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const gpusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, size_t* bufferSize);
Status CsrgeamBufferSizeExt(
int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);
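Sparse-sparse addition is split across three wrappers: CsrgeamBufferSizeExt (workspace query), CsrgeamNnz (row pointers and nnz of C), and Csrgeam (values and column indices of C). A rough outline of the raw csrgeam2 sequence for float data, assuming every pointer is a device array prepared by the caller (CsrAddOutline is a made-up name; descriptor setup and error checks are omitted):

#include <cuda_runtime.h>
#include <cusparse.h>

void CsrAddOutline(cusparseHandle_t handle, int m, int n,
                   cusparseMatDescr_t descrA, int nnzA, const float* valA,
                   const int* rowPtrA, const int* colIndA,
                   cusparseMatDescr_t descrB, int nnzB, const float* valB,
                   const int* rowPtrB, const int* colIndB,
                   cusparseMatDescr_t descrC, int* rowPtrC /* length m + 1 */) {
  const float alpha = 1.0f;
  const float beta = 1.0f;

  // Phase 0: workspace query; C's value/column arrays may still be null here.
  size_t ws_bytes = 0;
  cusparseScsrgeam2_bufferSizeExt(handle, m, n, &alpha, descrA, nnzA, valA,
                                  rowPtrA, colIndA, &beta, descrB, nnzB, valB,
                                  rowPtrB, colIndB, descrC,
                                  /*csrSortedValC=*/nullptr, rowPtrC,
                                  /*csrSortedColIndC=*/nullptr, &ws_bytes);
  void* workspace = nullptr;
  cudaMalloc(&workspace, ws_bytes);

  // Phase 1: row pointers of C and its total number of nonzeros.
  int nnzC = 0;
  cusparseXcsrgeam2Nnz(handle, m, n, descrA, nnzA, rowPtrA, colIndA, descrB,
                       nnzB, rowPtrB, colIndB, descrC, rowPtrC, &nnzC,
                       workspace);

  // Phase 2 (not shown): allocate valC / colIndC of length nnzC, then call
  // cusparseScsrgeam2 with the same arguments plus those output arrays.
  cudaFree(workspace);
}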
// Computes sparse-sparse matrix addition of matrices
// stored in CSR format. This is part one: calculate nnz of the
@ -365,15 +364,12 @@ class GpuSparse {
// Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part zero: calculate required workspace
// size.
template<typename Scalar>
Status CsrgemmBufferSize(int m, int n, int k,
const gpusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB,
const int* csrSortedColIndB,
csrgemm2Info_t info, size_t* workspaceBytes);
template <typename Scalar>
Status CsrgemmBufferSize(
int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif
// Computes sparse-sparse matrix multiplication of matrices
@ -419,13 +415,12 @@ class GpuSparse {
int* csrSortedColIndC);
#else
template <typename Scalar>
Status Csrgemm(int m, int n, int k,
const gpusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const gpusparseMatDescr_t descrB,
int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC,
Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, const csrgemm2Info_t info,
void* workspace);
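The csrgemm2-based sparse-sparse multiply mirrors the csrgeam2 flow above, but it also needs a csrgemm2Info_t object and a dummy "D" matrix, because csrgemm2 computes C = alpha * A * B + beta * D; the wrappers pass descrA with nnz == 0 for D and a null beta. A rough outline for float data (CsrMatMulOutline is a made-up name; allocation of C's arrays and error checking are omitted):

#include <cuda_runtime.h>
#include <cusparse.h>

void CsrMatMulOutline(cusparseHandle_t handle, int m, int n, int k,
                      cusparseMatDescr_t descrA, int nnzA,
                      const int* rowPtrA, const int* colIndA,
                      cusparseMatDescr_t descrB, int nnzB,
                      const int* rowPtrB, const int* colIndB,
                      cusparseMatDescr_t descrC, int* rowPtrC /* m + 1 */) {
  const float one = 1.0f;
  csrgemm2Info_t info;
  cusparseCreateCsrgemm2Info(&info);

  // Phase 0: workspace query for C = 1 * A * B (no D term, so beta is null
  // and D is given as an empty matrix reusing descrA).
  size_t ws_bytes = 0;
  cusparseScsrgemm2_bufferSizeExt(handle, m, n, k, &one, descrA, nnzA, rowPtrA,
                                  colIndA, descrB, nnzB, rowPtrB, colIndB,
                                  /*beta=*/nullptr, descrA, 0, nullptr, nullptr,
                                  info, &ws_bytes);
  void* workspace = nullptr;
  cudaMalloc(&workspace, ws_bytes);

  // Phase 1: row pointers of C and its total number of nonzeros.
  int nnzC = 0;
  cusparseXcsrgemm2Nnz(handle, m, n, k, descrA, nnzA, rowPtrA, colIndA, descrB,
                       nnzB, rowPtrB, colIndB, descrA, 0, nullptr, nullptr,
                       descrC, rowPtrC, &nnzC, info, workspace);

  // Phase 2 (not shown): allocate valC / colIndC of length nnzC and call
  // cusparseScsrgemm2 with the value arrays of A and B plus the same empty-D
  // arguments.
  cusparseDestroyCsrgemm2Info(info);
  cudaFree(workspace);
}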

View File

@ -108,7 +108,7 @@ class CSRSparseMatrixAddFunctor {
set_zero(d, c_row_ptr_t.flat<int32>());
size_t maxWorkspaceSize = 0;
for (int i=0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
a.values_vec<T>(i), a_dense_shape};
ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
@ -140,9 +140,8 @@ class CSRSparseMatrixAddFunctor {
TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
rows + 1);
int c_nnz_i;
TF_RETURN_IF_ERROR(
csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i,
workspace));
TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}
@ -290,8 +289,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
}
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
size_t* bufferSize) {
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);
const int m = a.row_ptr.size() - 1;
@ -310,13 +308,12 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T,
null_int, null_int, bufferSize));
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
null_int, bufferSize));
return Status::OK();
}
Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,

View File

@ -174,8 +174,7 @@ struct CSRStructureModifyingFunctor {
virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
int* output_nnz,
void* workspace) = 0;
int* output_nnz, void* workspace) = 0;
virtual Status Compute(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, CSRComponent<T>* c,

View File

@ -838,35 +838,33 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
gpusparseDnMatDescr_t matB, matC;
// NOTE: the following APIs are not available in ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(&matA, m, k, nnz,
const_cast<int*>(a.row_ptr.data()),
const_cast<int*>(a.col_ind.data()),
const_cast<T*>(a.values.data()),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDADataType<T>::type));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
CUDADataType<T>::type));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(&matB, n, k, ldb,
const_cast<T*>(b.data()),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(&matC, m, n, ldc, c.data(),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType<T>::type,
CUSPARSE_ORDER_COL));
size_t bufferSize = 0;
TF_RETURN_IF_ERROR(
cuda_sparse.SpMMBufferSize(transA, transB, &alpha, matA, matB, &beta,
matC, CUSPARSE_MM_ALG_DEFAULT,
&bufferSize));
TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
transA, transB, &alpha, matA, matB, &beta, matC,
CUSPARSE_MM_ALG_DEFAULT, &bufferSize));
Tensor buffer;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
DCHECK(buffer.flat<int8>().data() != nullptr);
TF_RETURN_IF_ERROR(
cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, &beta,
matC, CUSPARSE_MM_ALG_DEFAULT,
buffer.flat<int8>().data()));
TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
&beta, matC, CUSPARSE_MM_ALG_DEFAULT,
buffer.flat<int8>().data()));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
@ -897,14 +895,12 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA
TF_RETURN_IF_ERROR(
cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
a.values.data(), a.row_ptr.data(), a.col_ind.data(),
b.data(), ldb, &beta, c.data(), ldc));
#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020
}
return Status::OK();

View File

@ -462,9 +462,9 @@ class CSRSparseMatMulGPUOp : public OpKernel {
rows + 1);
int c_nnz_i;
OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
c_row_ptr_i, &c_nnz_i,
workspace));
OP_REQUIRES_OK(ctx,
csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
&c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}
@ -604,8 +604,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
}
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
size_t* bufferSize) {
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);
const int m =
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
@ -624,9 +623,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
b.col_ind.data(), info_, bufferSize));
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
bufferSize));
return Status::OK();
}
@ -663,10 +662,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz,
info_, workspace));
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
#endif
if (*output_nnz < 0) {
@ -708,11 +706,10 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
c->values.data(), c->row_ptr.data(), c->col_ind.data()));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
m, n, k, descrA_.descr(), nnzA, a.values.data(),
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
b.values.data(), b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c->values.data(), c->row_ptr.data(), c->col_ind.data(),
info_, workspace));
m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
c->row_ptr.data(), c->col_ind.data(), info_, workspace));
#endif
// TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix