diff --git a/tensorflow/core/kernels/cuda_sparse.cc b/tensorflow/core/kernels/cuda_sparse.cc index 9d4ddc13d0d..141aae61571 100644 --- a/tensorflow/core/kernels/cuda_sparse.cc +++ b/tensorflow/core/kernels/cuda_sparse.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "third_party/gpus/cuda/include/cusparse.h" +#include "third_party/gpus/cuda/include/library_types.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -179,6 +180,10 @@ Status GpuSparse::Initialize() { return Status::OK(); } +#define TF_CALL_CUSPARSE_DTYPES(m) \ + m(float, CUDA_R_32F) m(double, CUDA_R_64F) \ + m(std::complex, CUDA_C_32F) m(std::complex, CUDA_C_64F) + // Macro that specializes a sparse method for all 4 standard // numeric types. // TODO: reuse with cuda_solvers @@ -359,23 +364,30 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m, return Status::OK(); } -Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, - int nnzA, const int* csrSortedRowPtrA, - const int* csrSortedColIndA, - const cusparseMatDescr_t descrB, int nnzB, - const int* csrSortedRowPtrB, - const int* csrSortedColIndB, - const cusparseMatDescr_t descrC, - int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) { +Status GpuSparse::CsrgeamNnz( + int m, int n, const cusparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const cusparseMatDescr_t descrC, + int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) { DCHECK(initialized_); DCHECK(nnzTotalDevHostPtr != nullptr); +#if CUDA_VERSION >= 10000 + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz( + *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, + descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace)); +#else TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz( *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr)); +#endif return Status::OK(); } +#if CUDA_VERSION < 10020 + template static inline Status CsrmmImpl( SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, @@ -416,6 +428,45 @@ static inline Status CsrmmImpl( TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE); +#else + +#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \ + template <> \ + Status GpuSparse::SpMMBufferSize( \ + cusparseOperation_t transA, cusparseOperation_t transB, \ + const Scalar* alpha, const cusparseSpMatDescr_t matA, \ + const gpusparseDnMatDescr_t matB, const Scalar* beta, \ + gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \ + const { \ + DCHECK(initialized_); \ + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \ + *gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \ + dtype, alg, bufferSize)); \ + return Status::OK(); \ + } + +TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE); + +#define SPMM_INSTANCE(Scalar, dtype) \ + template <> \ + Status GpuSparse::SpMM( \ + cusparseOperation_t transA, cusparseOperation_t transB, \ + const Scalar* alpha, const cusparseSpMatDescr_t matA, \ + const gpusparseDnMatDescr_t matB, const Scalar* beta, \ + gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \ + DCHECK(initialized_); \ + 
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \ + transB, alpha, matA, matB, beta, \ + matC, dtype, alg, buffer)); \ + return Status::OK(); \ + } + +TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE); + +#endif + +#if CUDA_VERSION < 10020 + template static inline Status CsrmvImpl( SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, @@ -455,6 +506,115 @@ static inline Status CsrmvImpl( TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE); +#else + +template +static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context, + cusparseHandle_t cusparse_handle, + cusparseOperation_t transA, int m, int n, + int nnz, const Scalar* alpha_host, + const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* x, + const Scalar* beta_host, Scalar* y) { + cusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); + // CUSPARSE_ALG_MERGE_PATH algo only supports non-transpose matrix. + DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE); + + size_t bufferSize; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize( + cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host, + dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA, + x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(bufferSize)}), &buffer)); + auto pBuffer = buffer.flat(); + DCHECK(pBuffer.data() != nullptr); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx( + cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host, + dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA, + x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data())); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA)); + return Status::OK(); +} + +template +static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context, + cusparseHandle_t cusparse_handle, + cusparseOperation_t transA, int m, int n, int nnz, + const Scalar* alpha_host, + const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* x, + const Scalar* beta_host, Scalar* y) { + cusparseSpMatDescr_t matA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr( + &matA, m, n, nnz, const_cast(csrSortedRowPtrA), + const_cast(csrSortedColIndA), const_cast(csrSortedValA), + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype)); + + cusparseDnVecDescr_t vecX, vecY; + int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m; + int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? 
m : n; + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseCreateDnVec(&vecX, sizeX, const_cast(x), dtype)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype)); + + size_t bufferSize; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize( + cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype, + CUSPARSE_CSRMV_ALG1, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(bufferSize)}), &buffer)); + auto pBuffer = buffer.flat(); + DCHECK(pBuffer.data() != nullptr); + + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host, + vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data())); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA)); + return Status::OK(); +} + +#define CSRMV_INSTANCE(Scalar, cudaDataType) \ + template <> \ + Status GpuSparse::Csrmv( \ + cusparseOperation_t transA, int m, int n, int nnz, \ + const Scalar* alpha_host, const Scalar* csrSortedValA, \ + const int* csrSortedRowPtrA, const int* csrSortedColIndA, \ + const Scalar* x, const Scalar* beta_host, Scalar* y) const { \ + DCHECK(initialized_); \ + if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \ + return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \ + m, n, nnz, alpha_host, csrSortedValA, \ + csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \ + } else { \ + return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \ + n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \ + csrSortedColIndA, x, beta_host, y); \ + } \ + } + +TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE); + +#endif // CUDA_VERSION < 10020 + +#if CUDA_VERSION < 10000 + template static inline Status CsrgeamImpl( SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, @@ -483,7 +643,7 @@ static inline Status CsrgeamImpl( const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \ const int* csrSortedRowPtrB, const int* csrSortedColIndB, \ const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \ - int* csrSortedRowPtrC, int* csrSortedColIndC) { \ + int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \ DCHECK(initialized_); \ return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \ *gpusparse_handle_, m, n, alpha, descrA, nnzA, \ @@ -493,8 +653,113 @@ static inline Status CsrgeamImpl( csrSortedRowPtrC, csrSortedColIndC); \ } +#else + +template +static inline Status Csrgeam2Impl( + SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, + int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* beta, + const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const cusparseMatDescr_t descrC, Scalar* csrSortedValC, + int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { + TF_RETURN_IF_GPUSPARSE_ERROR(op( + cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA, + AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, + AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB), + csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC), + csrSortedRowPtrC, csrSortedColIndC, workspace)); + return Status::OK(); +} + +#define CSRGEAM_INSTANCE(Scalar, 
sparse_prefix) \ + template <> \ + Status GpuSparse::Csrgeam( \ + int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \ + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const Scalar* beta, \ + const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \ + const int* csrSortedRowPtrB, const int* csrSortedColIndB, \ + const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \ + int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \ + DCHECK(initialized_); \ + return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_, \ + *gpusparse_handle_, m, n, alpha, descrA, nnzA, \ + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \ + beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \ + csrSortedColIndB, descrC, csrSortedValC, \ + csrSortedRowPtrC, csrSortedColIndC, workspace); \ + } + +#endif + TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE); +#if CUDA_VERSION < 10000 + +#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::CsrgeamBufferSizeExt( \ + int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \ + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const Scalar* beta, \ + const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \ + const int* csrSortedRowPtrB, const int* csrSortedColIndB, \ + const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \ + int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \ + DCHECK(initialized_); \ + *bufferSize = 0; \ + return Status::OK(); \ + } + +#else + +template +static inline Status CsrgeamBufferSizeExtImpl( + SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle, + int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* beta, + const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const cusparseMatDescr_t descrC, Scalar* csrSortedValC, + int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { + TF_RETURN_IF_GPUSPARSE_ERROR(op( + sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA, + AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, + AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB), + csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC), + csrSortedRowPtrC, csrSortedColIndC, bufferSize)); + return Status::OK(); +} + +#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::CsrgeamBufferSizeExt( \ + int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \ + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const Scalar* beta, \ + const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \ + const int* csrSortedRowPtrB, const int* csrSortedColIndB, \ + const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \ + int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \ + DCHECK(initialized_); \ + return CsrgeamBufferSizeExtImpl( \ + SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \ + *gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \ + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \ + csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \ + csrSortedRowPtrC, 
csrSortedColIndC, bufferSize); \ + } + +#endif + +TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE); + +#if CUDA_VERSION < 10000 + Status GpuSparse::CsrgemmNnz( cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n, const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA, @@ -551,6 +816,101 @@ static inline Status CsrgemmImpl( TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE); +#else + +template +static const T* one_ptr() { + static const T one = static_cast(1); + return &one; +} + +template +static const T* null_ptr() { + return nullptr; +} + +#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::CsrgemmBufferSize( \ + int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \ + const int* csrSortedRowPtrA, const int* csrSortedColIndA, \ + const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, \ + const int* csrSortedColIndB, csrgemm2Info_t info, \ + size_t* workspaceBytes) { \ + DCHECK(initialized_); \ + TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \ + sparse_prefix)( \ + *gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr()), descrA, \ + nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \ + csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr()), \ + descrA, 0, null_ptr(), null_ptr(), info, workspaceBytes)); \ + return Status::OK(); \ + } + +TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE); + +Status GpuSparse::CsrgemmNnz( + int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const cusparseMatDescr_t descrC, + int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info, + void* workspace) { + DCHECK(initialized_); + DCHECK(nnzTotalDevHostPtr != nullptr); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz( + *gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, + descrA, 0, null_ptr(), null_ptr(), descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, workspace)); + return Status::OK(); +} + +template +static inline Status CsrgemmImpl( + SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, + int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const cusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC, + const csrgemm2Info_t info, void* workspace) { + TF_RETURN_IF_GPUSPARSE_ERROR( + op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr()), descrA, + nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA, + descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB, + csrSortedColIndB, AsCudaComplex(null_ptr()), descrA, 0, + AsCudaComplex(null_ptr()), null_ptr(), null_ptr(), + descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC, + csrSortedColIndC, info, workspace)); + return Status::OK(); +} + +#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \ + template <> \ + Status GpuSparse::Csrgemm( \ + int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \ + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \ + const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \ + 
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \ + const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \ + Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC, \ + const csrgemm2Info_t info, void* workspace) { \ + DCHECK(initialized_); \ + return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_, \ + *gpusparse_handle_, m, n, k, descrA, nnzA, \ + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \ + descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \ + csrSortedColIndB, descrC, csrSortedValC, \ + csrSortedRowPtrC, csrSortedColIndC, info, workspace); \ + } + +TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE); + +#endif // CUDA_VERSION < 10000 + template static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, OpKernelContext* context, @@ -596,6 +956,8 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op, TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE); +#if CUDA_VERSION < 10010 + template static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle, int m, int n, @@ -624,6 +986,53 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context, TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE); +#else + +template +static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context, + cusparseHandle_t cusparse_handle, int m, int n, + int nnz, const Scalar* csrVal, + const int* csrRowPtr, const int* csrColInd, + Scalar* cscVal, int* cscRowInd, int* cscColPtr, + const cusparseAction_t copyValues) { + size_t bufferSize; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize( + cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd, + AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues, + CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, + TensorShape({static_cast(bufferSize)}), &buffer)); + + DCHECK(buffer.flat().data() != nullptr); + + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), + csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr, + cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO, + CUSPARSE_CSR2CSC_ALG2, buffer.flat().data())); + + return Status::OK(); +} + +#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \ + template <> \ + Status GpuSparse::Csr2csc( \ + int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \ + const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \ + const cusparseAction_t copyValues) { \ + DCHECK(initialized_); \ + return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \ + csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \ + cscColPtr, copyValues); \ + } + +TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE); + +#endif // CUDA_VERSION < 10010 + } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cuda_sparse.h b/tensorflow/core/kernels/cuda_sparse.h index 35bd5ccf0d7..eb69469b615 100644 --- a/tensorflow/core/kernels/cuda_sparse.h +++ b/tensorflow/core/kernels/cuda_sparse.h @@ -26,6 +26,7 @@ limitations under the License. 
#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cusparse.h" using gpusparseStatus_t = cusparseStatus_t; @@ -34,6 +35,11 @@ using gpusparseMatDescr_t = cusparseMatDescr_t; using gpusparseAction_t = cusparseAction_t; using gpusparseHandle_t = cusparseHandle_t; using gpuStream_t = cudaStream_t; +#if CUDA_VERSION >= 10020 +using gpusparseDnMatDescr_t = cusparseDnMatDescr_t; +using gpusparseSpMatDescr_t = cusparseSpMatDescr_t; +using gpusparseSpMMAlg_t = cusparseSpMMAlg_t; +#endif #define GPUSPARSE(postfix) CUSPARSE_##postfix #define gpusparse(postfix) cusparse##postfix @@ -253,6 +259,7 @@ class GpuSparse { // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr. Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const; +#if CUDA_VERSION < 10020 // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C, // where A is a sparse matrix in CSR format, B and C are dense tall // matrices. This routine allows transposition of matrix B, which @@ -272,18 +279,64 @@ class GpuSparse { const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc) const; +#else + // Workspace size query for sparse-dense matrix multiplication. Helper + // function for SpMM which computes y = alpha * op(A) * op(B) + beta * C, + // where A is a sparse matrix in CSR format, B and C are dense matricies in + // column-major format. Returns needed workspace size in bytes. + template + Status SpMMBufferSize(gpusparseOperation_t transA, + gpusparseOperation_t transB, const Scalar* alpha, + const gpusparseSpMatDescr_t matA, + const gpusparseDnMatDescr_t matB, const Scalar* beta, + gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg, + size_t* bufferSize) const; + + // Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C, + // where A is a sparse matrix in CSR format, B and C are dense matricies in + // column-major format. Buffer is assumed to be at least as large as the + // workspace size returned by SpMMBufferSize(). + // + // **NOTE** This is an in-place operation for data in C. + template + Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB, + const Scalar* alpha, const gpusparseSpMatDescr_t matA, + const gpusparseDnMatDescr_t matB, const Scalar* beta, + gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg, + int8* buffer) const; +#endif // Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y, // where A is a sparse matrix in CSR format, x and y are dense vectors. See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath // // **NOTE** This is an in-place operation for data in y. +#if CUDA_VERSION < 10020 template Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host, const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host, Scalar* y) const; +#else + template + Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, + const Scalar* alpha_host, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const Scalar* x, const Scalar* beta_host, Scalar* y) const; +#endif // CUDA_VERSION < 10020 + + // Computes workspace size for sparse - sparse matrix addition of matrices + // stored in CSR format. 
+ template + Status CsrgeamBufferSizeExt( + int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* beta, + const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, Scalar* csrSortedValC, + int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize); // Computes sparse-sparse matrix addition of matrices // stored in CSR format. This is part one: calculate nnz of the @@ -295,7 +348,7 @@ class GpuSparse { const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, - int* nnzTotalDevHostPtr); + int* nnzTotalDevHostPtr, void* workspace); // Computes sparse - sparse matrix addition of matrices // stored in CSR format. This is part two: perform sparse-sparse @@ -311,13 +364,26 @@ class GpuSparse { const Scalar* csrSortedValB, const int* csrSortedRowPtrB, const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, Scalar* csrSortedValC, int* csrSortedRowPtrC, - int* csrSortedColIndC); + int* csrSortedColIndC, void* workspace); + +#if CUDA_VERSION >= 10000 + // Computes sparse-sparse matrix multiplication of matrices + // stored in CSR format. This is part zero: calculate required workspace + // size. + template + Status CsrgemmBufferSize( + int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes); +#endif // Computes sparse-sparse matrix multiplication of matrices // stored in CSR format. This is part one: calculate nnz of the // output. csrSortedRowPtrC must be preallocated on device with // m + 1 entries. See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. +#if CUDA_VERSION < 10000 Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB, int m, int k, int n, const gpusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA, @@ -326,12 +392,23 @@ class GpuSparse { const int* csrSortedRowPtrB, const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, int* nnzTotalDevHostPtr); +#else + Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA, + int nnzA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, + int* nnzTotalDevHostPtr, csrgemm2Info_t info, + void* workspace); +#endif // Computes sparse - sparse matrix matmul of matrices // stored in CSR format. This is part two: perform sparse-sparse // addition. csrValC and csrColIndC must be allocated on the device // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See: // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm. 
+#if CUDA_VERSION < 10000 template Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m, int k, int n, const gpusparseMatDescr_t descrA, @@ -342,6 +419,18 @@ class GpuSparse { const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC); +#else + template + Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, + int* csrSortedColIndC, const csrgemm2Info_t info, + void* workspace); +#endif // In-place reordering of unsorted CSR to sorted CSR. // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr diff --git a/tensorflow/core/kernels/sparse/add_op.cc b/tensorflow/core/kernels/sparse/add_op.cc index 81bc7dfdb7d..b6265a1412c 100644 --- a/tensorflow/core/kernels/sparse/add_op.cc +++ b/tensorflow/core/kernels/sparse/add_op.cc @@ -107,6 +107,26 @@ class CSRSparseMatrixAddFunctor { const Device& d = ctx_->eigen_device(); set_zero(d, c_row_ptr_t.flat()); + size_t maxWorkspaceSize = 0; + for (int i = 0; i < batch_size; ++i) { + ConstCSRComponent a_comp{a.row_pointers_vec(i), a.col_indices_vec(i), + a.values_vec(i), a_dense_shape}; + ConstCSRComponent b_comp{b.row_pointers_vec(i), b.col_indices_vec(i), + b.values_vec(i), b_dense_shape}; + + size_t thisWorkspaceSize; + TF_RETURN_IF_ERROR( + csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize)); + if (thisWorkspaceSize > maxWorkspaceSize) { + maxWorkspaceSize = thisWorkspaceSize; + } + } + + Tensor temp; + TF_RETURN_IF_ERROR(ctx_->allocate_temp( + DT_INT8, TensorShape({static_cast(maxWorkspaceSize)}), &temp)); + void* workspace = temp.flat().data(); + for (int i = 0; i < batch_size; ++i) { // Calculate output sizes for all minibatch entries. // Store in c_batch_ptr and update c_row_ptrs. @@ -121,8 +141,8 @@ class CSRSparseMatrixAddFunctor { TTypes::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)), rows + 1); int c_nnz_i; - TF_RETURN_IF_ERROR( - csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i)); + TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure( + a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace)); c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i; } @@ -151,7 +171,7 @@ class CSRSparseMatrixAddFunctor { CSRComponent c_comp{c->row_pointers_vec(i), c->col_indices_vec(i), c->values_vec(i), c_dense_shape_t.vec()}; - TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp)); + TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace)); } return Status::OK(); @@ -269,10 +289,36 @@ struct CSRSparseMatrixAdd return Status::OK(); } + Status GetWorkspaceSize(const ConstCSRComponent& a, + const ConstCSRComponent& b, size_t* bufferSize) { + DCHECK(initialized_); + + const int m = a.row_ptr.size() - 1; + DCHECK_EQ(m, b.row_ptr.size() - 1); + const int row_dim = a.dense_shape_host.size() == 2 ? 
0 : 1; + DCHECK_EQ(m, a.dense_shape_host(row_dim)); + DCHECK_EQ(m, b.dense_shape_host(row_dim)); + const int nnzA = a.col_ind.size(); + const int nnzB = b.col_ind.size(); + + const int n = a.dense_shape_host(row_dim + 1); + DCHECK_EQ(n, b.dense_shape_host(row_dim + 1)); + T* null_T = nullptr; + int* null_int = nullptr; + + TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt( + m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(), + a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(), + b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int, + null_int, bufferSize)); + + return Status::OK(); + } + Status GetOutputStructure(const ConstCSRComponent& a, const ConstCSRComponent& b, TTypes::UnalignedVec c_row_ptr, - int* output_nnz) { + int* output_nnz, void* workspace) { DCHECK(initialized_); const int m = a.row_ptr.size() - 1; @@ -290,7 +336,7 @@ struct CSRSparseMatrixAdd TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz( m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), - descrC_.descr(), c_row_ptr.data(), output_nnz)); + descrC_.descr(), c_row_ptr.data(), output_nnz, workspace)); if (*output_nnz < 0) { return errors::Internal( @@ -300,7 +346,7 @@ struct CSRSparseMatrixAdd } Status Compute(const ConstCSRComponent& a, const ConstCSRComponent& b, - CSRComponent* c) { + CSRComponent* c, void* workspace) { DCHECK(initialized_); const int m = a.row_ptr.size() - 1; @@ -319,7 +365,7 @@ struct CSRSparseMatrixAdd m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(), a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(), - c->row_ptr.data(), c->col_ind.data())); + c->row_ptr.data(), c->col_ind.data(), workspace)); return Status::OK(); } diff --git a/tensorflow/core/kernels/sparse/kernels.h b/tensorflow/core/kernels/sparse/kernels.h index f795829af05..0c4ef9e26dc 100644 --- a/tensorflow/core/kernels/sparse/kernels.h +++ b/tensorflow/core/kernels/sparse/kernels.h @@ -167,13 +167,18 @@ struct CSRStructureModifyingFunctor { virtual Status Initialize() = 0; + virtual Status GetWorkspaceSize(const ConstCSRComponent& a, + const ConstCSRComponent& b, + size_t* bufferSize) = 0; + virtual Status GetOutputStructure(const ConstCSRComponent& a, const ConstCSRComponent& b, TTypes::UnalignedVec c_row_ptr, - int* output_nnz) = 0; + int* output_nnz, void* workspace) = 0; virtual Status Compute(const ConstCSRComponent& a, - const ConstCSRComponent& b, CSRComponent* c) = 0; + const ConstCSRComponent& b, CSRComponent* c, + void* workspace) = 0; }; // Calculates C = alpha * A + beta * B, where A and B are in CSR diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.cc b/tensorflow/core/kernels/sparse/mat_mul_op.cc index 36b1ec18ded..23b01e29b5b 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/mat_mul_op.cc @@ -721,6 +721,56 @@ REGISTER_GPU(complex128) namespace functor { +namespace { + +// CUDADataType::type translates from a C++ type (e.g. float) to a +// cudaDataType_t (e.g. CUDA_R_32F). 
+template +struct CUDADataType; + +template <> +struct CUDADataType { + static constexpr cudaDataType_t type = CUDA_R_16F; +}; + +template <> +struct CUDADataType { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_R_32F; +#elif TENSORFLOW_USE_ROCM + static constexpr cudaDataType_t type = HIPBLAS_R_32F; +#endif +}; + +template <> +struct CUDADataType> { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_C_32F; +#elif TENSORFLOW_USE_ROCM + static constexpr cudaDataType_t type = HIPBLAS_C_32F; +#endif +}; + +template <> +struct CUDADataType { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_R_64F; +#elif TENSORFLOW_USE_ROCM + static constexpr cudaDataType_t type = HIPBLAS_R_64F; +#endif +}; + +template <> +struct CUDADataType> { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_C_64F; +#elif TENSORFLOW_USE_ROCM + static constexpr cudaDataType_t type = HIPBLAS_C_64F; +#endif +}; + +} // namespace + template class CSRSparseMatrixMatMul { public: @@ -733,10 +783,10 @@ class CSRSparseMatrixMatMul { GpuSparse cuda_sparse(ctx); TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); { - // Use Csrmm to calculate: + // Use Csrmm/SpMM to calculate: // C = alpha * op(A) * op(B) + beta * C // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense. - // Note that Csrmm assumes B and C are in column-major form; so we + // Note that Csrmm/Spmm assumes B and C are in column-major form; so we // use transB == true, and manually transpose the output in place // using blasgeam. // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint. @@ -746,22 +796,6 @@ class CSRSparseMatrixMatMul { const T alpha = 1; const T beta = 0; - // transA must be non-transpose if transB is transpose (cusparse - // limitation). - const gpusparseOperation_t transA = GPUSPARSE(OPERATION_NON_TRANSPOSE); - - // transB: b is row-major, and cusparse requires col-major b (or - // equivalently transB == transpose). this version is actually more - // efficient. - const gpusparseOperation_t transB = GPUSPARSE(OPERATION_TRANSPOSE); - - gpusparseMatDescr_t descrA; - TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA)); - TF_RETURN_IF_GPUSPARSE_ERROR( - gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL))); - TF_RETURN_IF_GPUSPARSE_ERROR( - gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO))); - // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n) const int k = b.dimension(0); DCHECK_EQ(k, a.dense_shape_host(1)); @@ -786,10 +820,87 @@ class CSRSparseMatrixMatMul { // op(A) = A and at least max(1, k) otherwise. const int ldc = m; + // transA must be non-transpose if transB is transpose (cusparse + // limitation). +#if GOOGLE_CUDA + const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE; +#elif TENSORFLOW_USE_ROCM + const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE; +#endif + + // transB: b is row-major, and cusparse requires col-major b (or + // equivalently transB == transpose). this version is actually more + // efficient. 
+#if GOOGLE_CUDA && CUDA_VERSION >= 10020 + + const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; + gpusparseSpMatDescr_t matA; + gpusparseDnMatDescr_t matB, matC; + + // NOTE: the following APIs are not available in ROCM + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr( + &matA, m, k, nnz, const_cast(a.row_ptr.data()), + const_cast(a.col_ind.data()), const_cast(a.values.data()), + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, + CUDADataType::type)); + + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseCreateDnMat(&matB, n, k, ldb, const_cast(b.data()), + CUDADataType::type, CUSPARSE_ORDER_COL)); + + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType::type, + CUSPARSE_ORDER_COL)); + + size_t bufferSize = 0; + TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize( + transA, transB, &alpha, matA, matB, &beta, matC, + CUSPARSE_MM_ALG_DEFAULT, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(bufferSize)}), &buffer)); + DCHECK(buffer.flat().data() != nullptr); + + TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, + &beta, matC, CUSPARSE_MM_ALG_DEFAULT, + buffer.flat().data())); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA)); + +#else + +#if GOOGLE_CUDA + + const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; + + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); + +#elif TENSORFLOW_USE_ROCM + + const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE; + + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // GOOGLE_CUDA + TF_RETURN_IF_ERROR( cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA, a.values.data(), a.row_ptr.data(), a.col_ind.data(), b.data(), ldb, &beta, c.data(), ldc)); + +#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020 } return Status::OK(); @@ -822,20 +933,35 @@ class CSRSparseMatrixMatVec { const T alpha = 1; const T beta = 0; +#if GOOGLE_CUDA && CUDA_VERSION < 10020 gpusparseMatDescr_t descrA; - TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); TF_RETURN_IF_GPUSPARSE_ERROR( - gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL))); + cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); TF_RETURN_IF_GPUSPARSE_ERROR( - gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO))); + cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); +#elif TENSORFLOW_USE_ROCM + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM const int m = a.dense_shape_host(0); const int n = a.dense_shape_host(1); const int nnz = a.values.size(); DCHECK_EQ(nnz, 
a.col_ind.size()); +#if CUDA_VERSION >= 10020 + TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, + a.values.data(), a.row_ptr.data(), + a.col_ind.data(), x, &beta, y)); +#else TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA, a.values.data(), a.row_ptr.data(), a.col_ind.data(), x, &beta, y)); +#endif } return Status::OK(); diff --git a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc index 7a66c8af163..7325d5f6873 100644 --- a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc @@ -417,6 +417,36 @@ class CSRSparseMatMulGPUOp : public OpKernel { } auto b_input_dense_shape = b_input_matrix->dense_shape().vec(); +#if CUDA_VERSION >= 10000 + size_t maxWorkspaceSize = 0; + for (int i = 0; i < batch_size; ++i) { + // Calculate maximum workspace size over batch. + ConstCSRComponent a_comp{a_input_matrix->row_pointers_vec(i), + a_input_matrix->col_indices_vec(i), + a_input_matrix->values_vec(i), + a_input_dense_shape}; + ConstCSRComponent b_comp{b_input_matrix->row_pointers_vec(i), + b_input_matrix->col_indices_vec(i), + b_input_matrix->values_vec(i), + b_input_dense_shape}; + size_t thisWorkspaceSize; + OP_REQUIRES_OK( + ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize)); + if (thisWorkspaceSize > maxWorkspaceSize) { + maxWorkspaceSize = thisWorkspaceSize; + } + } + + Tensor temp; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(maxWorkspaceSize)}), + &temp)); + void* workspace = temp.flat().data(); +#else + void* workspace = nullptr; +#endif + for (int i = 0; i < batch_size; ++i) { // Calculate output sizes for all minibatch entries. // Store in c_batch_ptr and update c_row_ptrs. @@ -433,8 +463,9 @@ class CSRSparseMatMulGPUOp : public OpKernel { rows + 1); int c_nnz_i; - OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp, - c_row_ptr_i, &c_nnz_i)); + OP_REQUIRES_OK(ctx, + csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, + &c_nnz_i, workspace)); c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i; } @@ -464,7 +495,7 @@ class CSRSparseMatMulGPUOp : public OpKernel { b_input_dense_shape}; CSRComponent c_comp{c.row_pointers_vec(i), c.col_indices_vec(i), c.values_vec(i), c_dense_shape}; - OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp)); + OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace)); } Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({})); @@ -527,7 +558,12 @@ struct CSRSparseSparseMatrixMatMul initialized_(false), transpose_a_(transpose_a), adjoint_a_(adjoint_a), +#if CUDA_VERSION < 10000 transpose_b_(transpose_b) { +#else + transpose_b_(transpose_b), + info_(nullptr) { +#endif // CUDA_VERSION < 10000 // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse. transA_ = transpose_a ? (adjoint_a ? 
GPUSPARSE(OPERATION_TRANSPOSE) @@ -537,6 +573,14 @@ struct CSRSparseSparseMatrixMatMul : GPUSPARSE(OPERATION_NON_TRANSPOSE); } +#if CUDA_VERSION >= 10000 + ~CSRSparseSparseMatrixMatMul() { + if (initialized_) { + cusparseDestroyCsrgemm2Info(info_); + } + } +#endif + Status Initialize() { if (adjoint_a_ && transpose_a_) { return errors::InvalidArgument( @@ -547,14 +591,44 @@ struct CSRSparseSparseMatrixMatMul TF_RETURN_IF_ERROR(descrA_.Initialize()); TF_RETURN_IF_ERROR(descrB_.Initialize()); TF_RETURN_IF_ERROR(descrC_.Initialize()); +#if CUDA_VERSION >= 10000 + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_)); +#endif initialized_ = true; return Status::OK(); } + Status GetWorkspaceSize(const ConstCSRComponent& a, + const ConstCSRComponent& b, size_t* bufferSize) { + DCHECK(initialized_); + const int m = + a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2)); + if (!transpose_a_) { + DCHECK_EQ(m, a.row_ptr.size() - 1); + } + const int k = + a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1)); + if (!transpose_b_) { + DCHECK_EQ(k, b.row_ptr.size() - 1); + } + const int nnzA = a.col_ind.size(); + const int nnzB = b.col_ind.size(); + + const int n = + b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1)); + + TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize( + m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(), + descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_, + bufferSize)); + + return Status::OK(); + } + Status GetOutputStructure(const ConstCSRComponent& a, const ConstCSRComponent& b, TTypes::UnalignedVec c_row_ptr, - int* output_nnz) { + int* output_nnz, void* workspace) { DCHECK(initialized_); const int m = @@ -576,10 +650,17 @@ struct CSRSparseSparseMatrixMatMul *output_nnz = -1; +#if CUDA_VERSION < 10000 TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz( transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz)); +#else + TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz( + m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(), + descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), + descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace)); +#endif if (*output_nnz < 0) { return errors::Internal( @@ -590,7 +671,7 @@ struct CSRSparseSparseMatrixMatMul } Status Compute(const ConstCSRComponent& a, const ConstCSRComponent& b, - CSRComponent* c) { + CSRComponent* c, void* workspace) { DCHECK(initialized_); const int m = @@ -612,11 +693,19 @@ struct CSRSparseSparseMatrixMatMul b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1)); DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1)); +#if CUDA_VERSION < 10000 TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm( transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(), c->row_ptr.data(), c->col_ind.data())); +#else + TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm( + m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(), + a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(), + b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(), + c->row_ptr.data(), c->col_ind.data(), info_, workspace)); +#endif // TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix // columns are sorted? 
Above operation leads to unsorted columns.
@@ -643,6 +732,9 @@ struct CSRSparseSparseMatrixMatMul
   GpuSparseMatrixDescriptor descrC_;
   gpusparseOperation_t transA_;
   gpusparseOperation_t transB_;
+#if CUDA_VERSION >= 10000
+  csrgemm2Info_t info_;
+#endif
 };
 }  // namespace functor
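
For reference, a minimal standalone program (not from this patch) showing the CUDA 10.2+ generic SpMM sequence that the new SpMMBufferSize()/SpMM() wrappers and the mat_mul_op.cc path drive: build a CSR descriptor plus two dense column-major descriptors, query the workspace size, run cusparseSpMM, then destroy the descriptors. All sizes and values below are made up, and error handling is reduced to a single macro.

// Hypothetical standalone example; not part of the patch above.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cusparse.h>

#define CHECK_CUSPARSE(call)                                      \
  do {                                                            \
    cusparseStatus_t s = (call);                                  \
    if (s != CUSPARSE_STATUS_SUCCESS) {                           \
      std::printf("cuSPARSE error %d at line %d\n", s, __LINE__); \
      return 1;                                                   \
    }                                                             \
  } while (0)

int main() {
  // A is 2x3 in CSR form, B is 3x2 dense (column-major), C = A * B is 2x2.
  // A = [[1, 0, 2], [0, 3, 4]], B = [[1, 4], [2, 5], [3, 6]],
  // so the expected result is C = [[7, 16], [18, 39]].
  const int m = 2, k = 3, n = 2, nnz = 4;
  std::vector<int> h_row_ptr = {0, 2, 4};
  std::vector<int> h_col_ind = {0, 2, 1, 2};
  std::vector<float> h_val = {1.f, 2.f, 3.f, 4.f};
  std::vector<float> h_b = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // column-major
  std::vector<float> h_c(m * n, 0.f);

  int *d_row_ptr, *d_col_ind;
  float *d_val, *d_b, *d_c;
  cudaMalloc(&d_row_ptr, h_row_ptr.size() * sizeof(int));
  cudaMalloc(&d_col_ind, h_col_ind.size() * sizeof(int));
  cudaMalloc(&d_val, h_val.size() * sizeof(float));
  cudaMalloc(&d_b, h_b.size() * sizeof(float));
  cudaMalloc(&d_c, h_c.size() * sizeof(float));
  cudaMemcpy(d_row_ptr, h_row_ptr.data(), h_row_ptr.size() * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_col_ind, h_col_ind.data(), h_col_ind.size() * sizeof(int),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_val, h_val.data(), h_val.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), h_b.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_c, h_c.data(), h_c.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  cusparseHandle_t handle;
  CHECK_CUSPARSE(cusparseCreate(&handle));

  // Descriptors: sparse A (CSR, 32-bit indices), dense B and C (column-major).
  cusparseSpMatDescr_t matA;
  cusparseDnMatDescr_t matB, matC;
  CHECK_CUSPARSE(cusparseCreateCsr(&matA, m, k, nnz, d_row_ptr, d_col_ind,
                                   d_val, CUSPARSE_INDEX_32I,
                                   CUSPARSE_INDEX_32I,
                                   CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F));
  CHECK_CUSPARSE(cusparseCreateDnMat(&matB, k, n, /*ld=*/k, d_b, CUDA_R_32F,
                                     CUSPARSE_ORDER_COL));
  CHECK_CUSPARSE(cusparseCreateDnMat(&matC, m, n, /*ld=*/m, d_c, CUDA_R_32F,
                                     CUSPARSE_ORDER_COL));

  // Two-step pattern, as in SpMMBufferSize() + SpMM(): query the workspace
  // size, allocate it, then run the multiplication.
  const float alpha = 1.f, beta = 0.f;
  size_t buffer_size = 0;
  CHECK_CUSPARSE(cusparseSpMM_bufferSize(
      handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC,
      CUDA_R_32F, CUSPARSE_MM_ALG_DEFAULT, &buffer_size));
  void* d_buffer = nullptr;
  cudaMalloc(&d_buffer, buffer_size);
  CHECK_CUSPARSE(cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                              CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA,
                              matB, &beta, matC, CUDA_R_32F,
                              CUSPARSE_MM_ALG_DEFAULT, d_buffer));

  cudaMemcpy(h_c.data(), d_c, h_c.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  std::printf("C (column-major): %f %f %f %f\n", h_c[0], h_c[1], h_c[2],
              h_c[3]);

  cusparseDestroySpMat(matA);
  cusparseDestroyDnMat(matB);
  cusparseDestroyDnMat(matC);
  cusparseDestroy(handle);
  cudaFree(d_row_ptr); cudaFree(d_col_ind); cudaFree(d_val);
  cudaFree(d_b); cudaFree(d_c); cudaFree(d_buffer);
  return 0;
}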
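
Similarly, a sketch of the generic cusparseSpMV sequence that SpMVImpl uses for the transposed Csrmv case on CUDA 10.2+ (the non-transposed case goes through cusparseCsrmvEx with CUSPARSE_ALG_MERGE_PATH instead). SpMvCsrFloat is an invented helper name; all arrays are device pointers and most status checks are omitted.

// Hypothetical helper (not from the patch); float version, minimal checks.
#include <cusparse.h>
#include <cuda_runtime.h>

cusparseStatus_t SpMvCsrFloat(cusparseHandle_t handle, cusparseOperation_t op,
                              int m, int n, int nnz, const float* alpha_host,
                              float* val, int* row_ptr, int* col_ind,
                              float* x, const float* beta_host, float* y) {
  cusparseSpMatDescr_t matA;
  cusparseDnVecDescr_t vecX, vecY;
  // Vector lengths follow the same convention as SpMVImpl in the patch.
  const int size_x = (op == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
  const int size_y = (op == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
  cusparseCreateCsr(&matA, m, n, nnz, row_ptr, col_ind, val,
                    CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
                    CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
  cusparseCreateDnVec(&vecX, size_x, x, CUDA_R_32F);
  cusparseCreateDnVec(&vecY, size_y, y, CUDA_R_32F);

  // Workspace query followed by the multiply, mirroring SpMVImpl.
  size_t buffer_size = 0;
  cusparseSpMV_bufferSize(handle, op, alpha_host, matA, vecX, beta_host, vecY,
                          CUDA_R_32F, CUSPARSE_CSRMV_ALG1, &buffer_size);
  void* buffer = nullptr;
  cudaMalloc(&buffer, buffer_size);
  cusparseStatus_t status =
      cusparseSpMV(handle, op, alpha_host, matA, vecX, beta_host, vecY,
                   CUDA_R_32F, CUSPARSE_CSRMV_ALG1, buffer);

  cusparseDestroyDnVec(vecY);
  cusparseDestroyDnVec(vecX);
  cusparseDestroySpMat(matA);
  cudaFree(buffer);
  return status;
}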
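
The sparse-sparse addition path now runs in three steps — csrgeam2_bufferSizeExt, Xcsrgeam2Nnz, csrgeam2 — with one workspace reused across them, which is why CsrgeamNnz() and Csrgeam() grew a workspace argument and add_op.cc sizes the workspace over the whole batch up front. A sketch of the first two steps, with AddCsrNnzFloat as an invented helper (float only, device pointers, host nnz output):

// Hypothetical helper (not from the patch); float version, minimal checks.
#include <cusparse.h>
#include <cuda_runtime.h>

cusparseStatus_t AddCsrNnzFloat(
    cusparseHandle_t handle, int m, int n, float alpha,
    cusparseMatDescr_t descr_a, int nnz_a, const float* val_a,
    const int* row_ptr_a, const int* col_ind_a, float beta,
    cusparseMatDescr_t descr_b, int nnz_b, const float* val_b,
    const int* row_ptr_b, const int* col_ind_b, cusparseMatDescr_t descr_c,
    int* row_ptr_c /* device, m + 1 entries */, int* nnz_c /* host */,
    void** workspace /* device; caller frees with cudaFree */) {
  // Step 1: workspace size query. The C arrays may be null here, which is
  // how GetWorkspaceSize() in add_op.cc calls the wrapper.
  size_t buffer_size = 0;
  cusparseStatus_t status = cusparseScsrgeam2_bufferSizeExt(
      handle, m, n, &alpha, descr_a, nnz_a, val_a, row_ptr_a, col_ind_a,
      &beta, descr_b, nnz_b, val_b, row_ptr_b, col_ind_b, descr_c,
      /*csrSortedValC=*/nullptr, /*csrSortedRowPtrC=*/nullptr,
      /*csrSortedColIndC=*/nullptr, &buffer_size);
  if (status != CUSPARSE_STATUS_SUCCESS) return status;
  cudaMalloc(workspace, buffer_size);

  // Step 2: count the output nonzeros and fill row_ptr_c, reusing the
  // workspace (this is what CsrgeamNnz(..., workspace) forwards to).
  return cusparseXcsrgeam2Nnz(handle, m, n, descr_a, nnz_a, row_ptr_a,
                              col_ind_a, descr_b, nnz_b, row_ptr_b, col_ind_b,
                              descr_c, row_ptr_c, nnz_c, *workspace);
  // Step 3, once val_c / col_ind_c of size *nnz_c are allocated:
  //   cusparseScsrgeam2(handle, m, n, &alpha, descr_a, nnz_a, val_a,
  //                     row_ptr_a, col_ind_a, &beta, descr_b, nnz_b, val_b,
  //                     row_ptr_b, col_ind_b, descr_c, val_c, row_ptr_c,
  //                     col_ind_c, *workspace);
}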
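
The sparse-sparse matmul path is analogous but uses csrgemm2 with a csrgemm2Info_t created by cusparseCreateCsrgemm2Info() (held as info_ and destroyed in the new destructor). The "+ beta * D" term is disabled by passing a null beta and an empty D (descrA reused with nnzD = 0), as the new wrappers do. A sketch of the buffer-size and nnz steps, with MatMulCsrNnzFloat as an invented helper:

// Hypothetical helper (not from the patch); float version, minimal checks.
#include <cusparse.h>
#include <cuda_runtime.h>

cusparseStatus_t MatMulCsrNnzFloat(
    cusparseHandle_t handle, int m, int n, int k, cusparseMatDescr_t descr_a,
    int nnz_a, const int* row_ptr_a, const int* col_ind_a,
    cusparseMatDescr_t descr_b, int nnz_b, const int* row_ptr_b,
    const int* col_ind_b, cusparseMatDescr_t descr_c,
    int* row_ptr_c /* device, m + 1 entries */, int* nnz_c /* host */,
    csrgemm2Info_t info /* from cusparseCreateCsrgemm2Info() */,
    void** workspace /* device; caller frees with cudaFree */) {
  const float one = 1.0f;
  // Step 1: workspace size query. The optional D term is disabled: beta is
  // null and D is an empty matrix (descr_a reused, nnzD = 0).
  size_t buffer_size = 0;
  cusparseStatus_t status = cusparseScsrgemm2_bufferSizeExt(
      handle, m, n, k, &one, descr_a, nnz_a, row_ptr_a, col_ind_a, descr_b,
      nnz_b, row_ptr_b, col_ind_b, /*beta=*/nullptr, /*descrD=*/descr_a,
      /*nnzD=*/0, /*csrRowPtrD=*/nullptr, /*csrColIndD=*/nullptr, info,
      &buffer_size);
  if (status != CUSPARSE_STATUS_SUCCESS) return status;
  cudaMalloc(workspace, buffer_size);

  // Step 2: compute nnz(C) and row_ptr_c with the same info and workspace,
  // as the new CsrgemmNnz(..., info, workspace) overload does.
  return cusparseXcsrgemm2Nnz(handle, m, n, k, descr_a, nnz_a, row_ptr_a,
                              col_ind_a, descr_b, nnz_b, row_ptr_b, col_ind_b,
                              /*descrD=*/descr_a, /*nnzD=*/0, nullptr, nullptr,
                              descr_c, row_ptr_c, nnz_c, info, *workspace);
  // Step 3, once val_c / col_ind_c of size *nnz_c are allocated:
  //   cusparseScsrgemm2(handle, m, n, k, &one, descr_a, nnz_a, val_a,
  //                     row_ptr_a, col_ind_a, descr_b, nnz_b, val_b,
  //                     row_ptr_b, col_ind_b, /*beta=*/nullptr, descr_a, 0,
  //                     nullptr, nullptr, nullptr, descr_c, val_c, row_ptr_c,
  //                     col_ind_c, info, *workspace);
}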