Merge pull request #38802 from nluehr:cusparse_remove_deprecated

PiperOrigin-RevId: 310182962
Change-Id: Ic9c39c151181a99fa76981b4724d2e0f51fa5f69
Commit: 10b96cd214
@@ -24,6 +24,7 @@ limitations under the License.
 #include <vector>
 
 #include "third_party/gpus/cuda/include/cusparse.h"
+#include "third_party/gpus/cuda/include/library_types.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
@@ -179,6 +180,10 @@ Status GpuSparse::Initialize() {
   return Status::OK();
 }
 
+#define TF_CALL_CUSPARSE_DTYPES(m)                   \
+  m(float, CUDA_R_32F) m(double, CUDA_R_64F)         \
+      m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)
+
 // Macro that specializes a sparse method for all 4 standard
 // numeric types.
 // TODO: reuse with cuda_solvers
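TF_CALL_CUSPARSE_DTYPES(m) above simply invokes a per-type instantiation macro once for each of the four supported scalar types, pairing each with its matching cudaDataType_t tag. The short sketch below is illustrative only (DTYPE_NAME and main are not part of the change); it shows what one such expansion amounts to.

    // Illustration of what TF_CALL_CUSPARSE_DTYPES(m) expands to for a
    // hypothetical per-type macro; compiles with any C++ compiler.
    #include <complex>
    #include <cstdio>
    #define DTYPE_NAME(Scalar, dtype) std::printf("%s -> %s\n", #Scalar, #dtype);
    int main() {
      // TF_CALL_CUSPARSE_DTYPES(DTYPE_NAME) would expand to these four lines:
      DTYPE_NAME(float, CUDA_R_32F)
      DTYPE_NAME(double, CUDA_R_64F)
      DTYPE_NAME(std::complex<float>, CUDA_C_32F)
      DTYPE_NAME(std::complex<double>, CUDA_C_64F)
      return 0;
    }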
@@ -359,23 +364,30 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
   return Status::OK();
 }
 
-Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
-                             int nnzA, const int* csrSortedRowPtrA,
-                             const int* csrSortedColIndA,
-                             const cusparseMatDescr_t descrB, int nnzB,
-                             const int* csrSortedRowPtrB,
-                             const int* csrSortedColIndB,
-                             const cusparseMatDescr_t descrC,
-                             int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
+Status GpuSparse::CsrgeamNnz(
+    int m, int n, const cusparseMatDescr_t descrA, int nnzA,
+    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
+    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
+    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
+    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
   DCHECK(initialized_);
   DCHECK(nnzTotalDevHostPtr != nullptr);
+#if CUDA_VERSION >= 10000
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz(
+      *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
+      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
+      descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace));
+#else
   TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
       *gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
       csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
       descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
+#endif
   return Status::OK();
 }
 
+#if CUDA_VERSION < 10020
+
 template <typename Scalar, typename SparseFnT>
 static inline Status CsrmmImpl(
     SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
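For orientation: under csrgeam2 the caller sizes one scratch buffer and threads it through both phases of the CSR addition. A rough caller-side sketch using only the wrappers touched by this PR (CsrgeamBufferSizeExt, CsrgeamNnz and Csrgeam, all shown later in this diff) follows; cuda_sparse, ctx and the device pointers are placeholders, and error handling is elided.

    // Sketch: two-phase CSR addition C = alpha*A + beta*B with the
    // workspace-aware wrappers. All a_*/b_*/c_* pointers are device pointers.
    size_t buffer_size = 0;
    TF_RETURN_IF_ERROR(cuda_sparse.CsrgeamBufferSizeExt(
        m, n, &alpha, descrA, nnzA, a_val, a_row_ptr, a_col_ind,
        &beta, descrB, nnzB, b_val, b_row_ptr, b_col_ind,
        descrC, /*csrSortedValC=*/nullptr, /*csrSortedRowPtrC=*/nullptr,
        /*csrSortedColIndC=*/nullptr, &buffer_size));

    Tensor temp;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(buffer_size)}), &temp));
    void* workspace = temp.flat<int8>().data();

    int nnzC = 0;
    TF_RETURN_IF_ERROR(cuda_sparse.CsrgeamNnz(
        m, n, descrA, nnzA, a_row_ptr, a_col_ind, descrB, nnzB, b_row_ptr,
        b_col_ind, descrC, c_row_ptr, &nnzC, workspace));

    // ...allocate c_val / c_col_ind with nnzC entries, then:
    TF_RETURN_IF_ERROR(cuda_sparse.Csrgeam(
        m, n, &alpha, descrA, nnzA, a_val, a_row_ptr, a_col_ind,
        &beta, descrB, nnzB, b_val, b_row_ptr, b_col_ind,
        descrC, c_val, c_row_ptr, c_col_ind, workspace));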
@@ -416,6 +428,45 @@ static inline Status CsrmmImpl(
 
 TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);
 
+#else
+
+#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype)                               \
+  template <>                                                                 \
+  Status GpuSparse::SpMMBufferSize<Scalar>(                                   \
+      cusparseOperation_t transA, cusparseOperation_t transB,                 \
+      const Scalar* alpha, const cusparseSpMatDescr_t matA,                   \
+      const gpusparseDnMatDescr_t matB, const Scalar* beta,                   \
+      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize)  \
+      const {                                                                 \
+    DCHECK(initialized_);                                                     \
+    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize(                     \
+        *gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC,    \
+        dtype, alg, bufferSize));                                             \
+    return Status::OK();                                                      \
+  }
+
+TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);
+
+#define SPMM_INSTANCE(Scalar, dtype)                                          \
+  template <>                                                                 \
+  Status GpuSparse::SpMM<Scalar>(                                             \
+      cusparseOperation_t transA, cusparseOperation_t transB,                 \
+      const Scalar* alpha, const cusparseSpMatDescr_t matA,                   \
+      const gpusparseDnMatDescr_t matB, const Scalar* beta,                   \
+      gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
+    DCHECK(initialized_);                                                     \
+    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA,     \
+                                              transB, alpha, matA, matB, beta, \
+                                              matC, dtype, alg, buffer));     \
+    return Status::OK();                                                      \
+  }
+
+TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
+
+#endif
+
+#if CUDA_VERSION < 10020
+
 template <typename Scalar, typename SparseFnT>
 static inline Status CsrmvImpl(
     SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
@@ -455,6 +506,115 @@ static inline Status CsrmvImpl(
 
 TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);
 
+#else
+
+template <typename Scalar>
+static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
+                                 cusparseHandle_t cusparse_handle,
+                                 cusparseOperation_t transA, int m, int n,
+                                 int nnz, const Scalar* alpha_host,
+                                 const Scalar* csrSortedValA,
+                                 const int* csrSortedRowPtrA,
+                                 const int* csrSortedColIndA, const Scalar* x,
+                                 const Scalar* beta_host, Scalar* y) {
+  cusparseMatDescr_t descrA;
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
+  // CUSPARSE_ALG_MERGE_PATH algo only supports non-transpose matrix.
+  DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE);
+
+  size_t bufferSize;
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
+      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
+      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
+      x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));
+
+  Tensor buffer;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
+  auto pBuffer = buffer.flat<int8>();
+  DCHECK(pBuffer.data() != nullptr);
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
+      cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
+      dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
+      x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
+  return Status::OK();
+}
+
+template <typename Scalar>
+static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
+                              cusparseHandle_t cusparse_handle,
+                              cusparseOperation_t transA, int m, int n, int nnz,
+                              const Scalar* alpha_host,
+                              const Scalar* csrSortedValA,
+                              const int* csrSortedRowPtrA,
+                              const int* csrSortedColIndA, const Scalar* x,
+                              const Scalar* beta_host, Scalar* y) {
+  cusparseSpMatDescr_t matA;
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
+      &matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
+      const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
+      CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));
+
+  cusparseDnVecDescr_t vecX, vecY;
+  int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
+  int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));
+
+  size_t bufferSize;
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
+      cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
+      CUSPARSE_CSRMV_ALG1, &bufferSize));
+
+  Tensor buffer;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
+  auto pBuffer = buffer.flat<int8>();
+  DCHECK(pBuffer.data() != nullptr);
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
+                   vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
+  return Status::OK();
+}
+
+#define CSRMV_INSTANCE(Scalar, cudaDataType)                                   \
+  template <>                                                                  \
+  Status GpuSparse::Csrmv<Scalar>(                                             \
+      cusparseOperation_t transA, int m, int n, int nnz,                       \
+      const Scalar* alpha_host, const Scalar* csrSortedValA,                   \
+      const int* csrSortedRowPtrA, const int* csrSortedColIndA,                \
+      const Scalar* x, const Scalar* beta_host, Scalar* y) const {             \
+    DCHECK(initialized_);                                                      \
+    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) {                          \
+      return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA,   \
+                         m, n, nnz, alpha_host, csrSortedValA,                 \
+                         csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
+    } else {                                                                   \
+      return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m,   \
+                      n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA,     \
+                      csrSortedColIndA, x, beta_host, y);                      \
+    }                                                                          \
+  }
+
+TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
+
+#endif  // CUDA_VERSION < 10020
+
+#if CUDA_VERSION < 10000
+
 template <typename Scalar, typename SparseFnT>
 static inline Status CsrgeamImpl(
     SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
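In short, the CUDA >= 10.2 Csrmv wrapper defined just above drops the matrix-descriptor argument and picks the kernel itself: the merge-path CsrmvEx path for the non-transpose case, the generic SpMV path otherwise. A call site reduces to the fragment below (placeholder names; alpha and beta live on the host).

    // Sketch of a call through the new descriptor-less wrapper.
    TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(CUSPARSE_OPERATION_NON_TRANSPOSE,
                                         m, n, nnz, &alpha,
                                         a_val, a_row_ptr, a_col_ind,
                                         x, &beta, y));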
@@ -483,7 +643,7 @@ static inline Status CsrgeamImpl(
       const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
       const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
       const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
-      int* csrSortedRowPtrC, int* csrSortedColIndC) {                         \
+      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {        \
     DCHECK(initialized_);                                                     \
     return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_,           \
                        *gpusparse_handle_, m, n, alpha, descrA, nnzA,         \
@@ -493,8 +653,113 @@ static inline Status CsrgeamImpl(
                        csrSortedRowPtrC, csrSortedColIndC);                   \
   }
 
+#else
+
+template <typename Scalar, typename SparseFnT>
+static inline Status Csrgeam2Impl(
+    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
+    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
+    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
+    const int* csrSortedColIndA, const Scalar* beta,
+    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
+    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
+    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
+    int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
+  TF_RETURN_IF_GPUSPARSE_ERROR(op(
+      cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
+      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
+      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
+      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
+      csrSortedRowPtrC, csrSortedColIndC, workspace));
+  return Status::OK();
+}
+
+#define CSRGEAM_INSTANCE(Scalar, sparse_prefix)                               \
+  template <>                                                                 \
+  Status GpuSparse::Csrgeam<Scalar>(                                          \
+      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
+      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
+      const int* csrSortedColIndA, const Scalar* beta,                        \
+      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
+      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
+      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
+      int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {        \
+    DCHECK(initialized_);                                                     \
+    return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_,         \
+                        *gpusparse_handle_, m, n, alpha, descrA, nnzA,        \
+                        csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,    \
+                        beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB,  \
+                        csrSortedColIndB, descrC, csrSortedValC,              \
+                        csrSortedRowPtrC, csrSortedColIndC, workspace);       \
+  }
+
+#endif
+
 TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);
 
+#if CUDA_VERSION < 10000
+
+#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                    \
+  template <>                                                                 \
+  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>(                             \
+      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
+      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
+      const int* csrSortedColIndA, const Scalar* beta,                        \
+      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
+      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
+      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
+      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {     \
+    DCHECK(initialized_);                                                     \
+    *bufferSize = 0;                                                          \
+    return Status::OK();                                                      \
+  }
+
+#else
+
+template <typename Scalar, typename SparseFnT>
+static inline Status CsrgeamBufferSizeExtImpl(
+    SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle,
+    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
+    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
+    const int* csrSortedColIndA, const Scalar* beta,
+    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
+    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
+    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
+    int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
+  TF_RETURN_IF_GPUSPARSE_ERROR(op(
+      sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
+      AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
+      AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
+      csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
+      csrSortedRowPtrC, csrSortedColIndC, bufferSize));
+  return Status::OK();
+}
+
+#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                     \
+  template <>                                                                  \
+  Status GpuSparse::CsrgeamBufferSizeExt<Scalar>(                              \
+      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,      \
+      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,      \
+      const int* csrSortedColIndA, const Scalar* beta,                         \
+      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,  \
+      const int* csrSortedRowPtrB, const int* csrSortedColIndB,                \
+      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                  \
+      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {      \
+    DCHECK(initialized_);                                                      \
+    return CsrgeamBufferSizeExtImpl(                                           \
+        SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_,            \
+        *gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA,          \
+        csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
+        csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC,             \
+        csrSortedRowPtrC, csrSortedColIndC, bufferSize);                       \
+  }
+
+#endif
+
+TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE);
+
+#if CUDA_VERSION < 10000
+
 Status GpuSparse::CsrgemmNnz(
     cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
     const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
@@ -551,6 +816,101 @@ static inline Status CsrgemmImpl(
 
 TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
 
+#else
+
+template <typename T>
+static const T* one_ptr() {
+  static const T one = static_cast<T>(1);
+  return &one;
+}
+
+template <typename T>
+static const T* null_ptr() {
+  return nullptr;
+}
+
+#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix)                      \
+  template <>                                                                   \
+  Status GpuSparse::CsrgemmBufferSize<Scalar>(                                  \
+      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,           \
+      const int* csrSortedRowPtrA, const int* csrSortedColIndA,                 \
+      const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,   \
+      const int* csrSortedColIndB, csrgemm2Info_t info,                         \
+      size_t* workspaceBytes) {                                                 \
+    DCHECK(initialized_);                                                       \
+    TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt,              \
+                                           sparse_prefix)(                      \
+        *gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,  \
+        nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB,                 \
+        csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()),  \
+        descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes));    \
+    return Status::OK();                                                        \
+  }
+
+TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);
+
+Status GpuSparse::CsrgemmNnz(
+    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
+    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
+    const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
+    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
+    int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
+    void* workspace) {
+  DCHECK(initialized_);
+  DCHECK(nnzTotalDevHostPtr != nullptr);
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
+      *gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
+      csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
+      descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
+      nnzTotalDevHostPtr, info, workspace));
+  return Status::OK();
+}
+
+template <typename Scalar, typename SparseFnT>
+static inline Status CsrgemmImpl(
+    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
+    int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
+    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
+    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
+    const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
+    const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
+    Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
+    const csrgemm2Info_t info, void* workspace) {
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
+         nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
+         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
+         csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
+         AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
+         descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
+         csrSortedColIndC, info, workspace));
+  return Status::OK();
+}
+
+#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
+  template <>                                                                 \
+  Status GpuSparse::Csrgemm<Scalar>(                                          \
+      int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,         \
+      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
+      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
+      const Scalar* csrSortedValB, const int* csrSortedRowPtrB,               \
+      const int* csrSortedColIndB, const cusparseMatDescr_t descrC,           \
+      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,    \
+      const csrgemm2Info_t info, void* workspace) {                           \
+    DCHECK(initialized_);                                                     \
+    return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_,          \
+                       *gpusparse_handle_, m, n, k, descrA, nnzA,             \
+                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,     \
+                       descrB, nnzB, csrSortedValB, csrSortedRowPtrB,         \
+                       csrSortedColIndB, descrC, csrSortedValC,               \
+                       csrSortedRowPtrC, csrSortedColIndC, info, workspace);  \
+  }
+
+TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
+
+#endif  // CUDA_VERSION < 10000
+
 template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
 static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                   OpKernelContext* context,
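The csrgemm2 path above splits sparse-sparse matmul into three calls that share one csrgemm2Info_t and one scratch buffer. A rough caller-side sketch using the wrappers defined in this hunk follows; creating the info object with cusparseCreateCsrgemm2Info is assumed here (this excerpt does not show it), and all other names are placeholders.

    // Sketch: C = A * B in CSR form via the csrgemm2-based wrappers.
    csrgemm2Info_t info;
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info));

    size_t workspace_bytes = 0;
    TF_RETURN_IF_ERROR(cuda_sparse.CsrgemmBufferSize<float>(
        m, n, k, descrA, nnzA, a_row_ptr, a_col_ind,
        descrB, nnzB, b_row_ptr, b_col_ind, info, &workspace_bytes));
    // ...allocate `workspace` with workspace_bytes bytes (e.g. an int8 Tensor).

    int nnzC = 0;
    TF_RETURN_IF_ERROR(cuda_sparse.CsrgemmNnz(
        m, n, k, descrA, nnzA, a_row_ptr, a_col_ind, descrB, nnzB, b_row_ptr,
        b_col_ind, descrC, c_row_ptr, &nnzC, info, workspace));

    // ...allocate c_val / c_col_ind with nnzC entries, then:
    TF_RETURN_IF_ERROR(cuda_sparse.Csrgemm(
        m, n, k, descrA, nnzA, a_val, a_row_ptr, a_col_ind,
        descrB, nnzB, b_val, b_row_ptr, b_col_ind,
        descrC, c_val, c_row_ptr, c_col_ind, info, workspace));

    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyCsrgemm2Info(info));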
@@ -596,6 +956,8 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
 
 TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);
 
+#if CUDA_VERSION < 10010
+
 template <typename Scalar, typename SparseFnT>
 static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m, int n,
@@ -624,6 +986,53 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
 
 TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);
 
+#else
+
+template <typename Scalar>
+static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
+                                 cusparseHandle_t cusparse_handle, int m, int n,
+                                 int nnz, const Scalar* csrVal,
+                                 const int* csrRowPtr, const int* csrColInd,
+                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
+                                 const cusparseAction_t copyValues) {
+  size_t bufferSize;
+  TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
+      cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
+      AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
+      CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));
+
+  Tensor buffer;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataTypeToEnum<Scalar>::value,
+      TensorShape({static_cast<int64>(bufferSize)}), &buffer));
+
+  DCHECK(buffer.flat<Scalar>().data() != nullptr);
+
+  TF_RETURN_IF_GPUSPARSE_ERROR(
+      cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
+                         csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
+                         cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
+                         CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));
+
+  return Status::OK();
+}
+
+#define CSR2CSC_INSTANCE(Scalar, cudaDataType)                                \
+  template <>                                                                 \
+  Status GpuSparse::Csr2csc<Scalar>(                                          \
+      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,      \
+      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,   \
+      const cusparseAction_t copyValues) {                                    \
+    DCHECK(initialized_);                                                     \
+    return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
+                       csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd,       \
+                       cscColPtr, copyValues);                                \
+  }
+
+TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);
+
+#endif  // CUDA_VERSION < 10010
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
@@ -26,6 +26,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 
+#include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cusparse.h"
 
 using gpusparseStatus_t = cusparseStatus_t;
@@ -34,6 +35,11 @@ using gpusparseMatDescr_t = cusparseMatDescr_t;
 using gpusparseAction_t = cusparseAction_t;
 using gpusparseHandle_t = cusparseHandle_t;
 using gpuStream_t = cudaStream_t;
+#if CUDA_VERSION >= 10020
+using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
+using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
+using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
+#endif
 
 #define GPUSPARSE(postfix) CUSPARSE_##postfix
 #define gpusparse(postfix) cusparse##postfix
@@ -253,6 +259,7 @@ class GpuSparse {
   // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
   Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
 
+#if CUDA_VERSION < 10020
   // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
   // where A is a sparse matrix in CSR format, B and C are dense tall
   // matrices. This routine allows transposition of matrix B, which
|
|||||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||||
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
|
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
|
||||||
int ldc) const;
|
int ldc) const;
|
||||||
|
#else
|
||||||
|
// Workspace size query for sparse-dense matrix multiplication. Helper
|
||||||
|
// function for SpMM which computes y = alpha * op(A) * op(B) + beta * C,
|
||||||
|
// where A is a sparse matrix in CSR format, B and C are dense matricies in
|
||||||
|
// column-major format. Returns needed workspace size in bytes.
|
||||||
|
template <typename Scalar>
|
||||||
|
Status SpMMBufferSize(gpusparseOperation_t transA,
|
||||||
|
gpusparseOperation_t transB, const Scalar* alpha,
|
||||||
|
const gpusparseSpMatDescr_t matA,
|
||||||
|
const gpusparseDnMatDescr_t matB, const Scalar* beta,
|
||||||
|
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
|
||||||
|
size_t* bufferSize) const;
|
||||||
|
|
||||||
|
// Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C,
|
||||||
|
// where A is a sparse matrix in CSR format, B and C are dense matricies in
|
||||||
|
// column-major format. Buffer is assumed to be at least as large as the
|
||||||
|
// workspace size returned by SpMMBufferSize().
|
||||||
|
//
|
||||||
|
// **NOTE** This is an in-place operation for data in C.
|
||||||
|
template <typename Scalar>
|
||||||
|
Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||||
|
const Scalar* alpha, const gpusparseSpMatDescr_t matA,
|
||||||
|
const gpusparseDnMatDescr_t matB, const Scalar* beta,
|
||||||
|
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
|
||||||
|
int8* buffer) const;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
|
// Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
|
||||||
// where A is a sparse matrix in CSR format, x and y are dense vectors. See:
|
// where A is a sparse matrix in CSR format, x and y are dense vectors. See:
|
||||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
|
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
|
||||||
//
|
//
|
||||||
// **NOTE** This is an in-place operation for data in y.
|
// **NOTE** This is an in-place operation for data in y.
|
||||||
|
#if CUDA_VERSION < 10020
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
||||||
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
|
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
|
||||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||||
const int* csrSortedColIndA, const Scalar* x,
|
const int* csrSortedColIndA, const Scalar* x,
|
||||||
const Scalar* beta_host, Scalar* y) const;
|
const Scalar* beta_host, Scalar* y) const;
|
||||||
|
#else
|
||||||
|
template <typename Scalar>
|
||||||
|
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
||||||
|
const Scalar* alpha_host, const Scalar* csrSortedValA,
|
||||||
|
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||||
|
const Scalar* x, const Scalar* beta_host, Scalar* y) const;
|
||||||
|
#endif // CUDA_VERSION < 10020
|
||||||
|
|
||||||
|
// Computes workspace size for sparse - sparse matrix addition of matrices
|
||||||
|
// stored in CSR format.
|
||||||
|
template <typename Scalar>
|
||||||
|
Status CsrgeamBufferSizeExt(
|
||||||
|
int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
|
||||||
|
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||||
|
const int* csrSortedColIndA, const Scalar* beta,
|
||||||
|
const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
|
||||||
|
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||||
|
const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||||
|
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);
|
||||||
|
|
||||||
// Computes sparse-sparse matrix addition of matrices
|
// Computes sparse-sparse matrix addition of matrices
|
||||||
// stored in CSR format. This is part one: calculate nnz of the
|
// stored in CSR format. This is part one: calculate nnz of the
|
||||||
@ -295,7 +348,7 @@ class GpuSparse {
|
|||||||
const gpusparseMatDescr_t descrB, int nnzB,
|
const gpusparseMatDescr_t descrB, int nnzB,
|
||||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||||
int* nnzTotalDevHostPtr);
|
int* nnzTotalDevHostPtr, void* workspace);
|
||||||
|
|
||||||
// Computes sparse - sparse matrix addition of matrices
|
// Computes sparse - sparse matrix addition of matrices
|
||||||
// stored in CSR format. This is part two: perform sparse-sparse
|
// stored in CSR format. This is part two: perform sparse-sparse
|
||||||
@ -311,13 +364,26 @@ class GpuSparse {
|
|||||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||||
int* csrSortedColIndC);
|
int* csrSortedColIndC, void* workspace);
|
||||||
|
|
||||||
|
#if CUDA_VERSION >= 10000
|
||||||
|
// Computes sparse-sparse matrix multiplication of matrices
|
||||||
|
// stored in CSR format. This is part zero: calculate required workspace
|
||||||
|
// size.
|
||||||
|
template <typename Scalar>
|
||||||
|
Status CsrgemmBufferSize(
|
||||||
|
int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
|
||||||
|
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||||
|
const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
|
||||||
|
const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Computes sparse-sparse matrix multiplication of matrices
|
// Computes sparse-sparse matrix multiplication of matrices
|
||||||
// stored in CSR format. This is part one: calculate nnz of the
|
// stored in CSR format. This is part one: calculate nnz of the
|
||||||
// output. csrSortedRowPtrC must be preallocated on device with
|
// output. csrSortedRowPtrC must be preallocated on device with
|
||||||
// m + 1 entries. See:
|
// m + 1 entries. See:
|
||||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||||
|
#if CUDA_VERSION < 10000
|
||||||
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||||
int nnzA, const int* csrSortedRowPtrA,
|
int nnzA, const int* csrSortedRowPtrA,
|
||||||
@ -326,12 +392,23 @@ class GpuSparse {
|
|||||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||||
int* nnzTotalDevHostPtr);
|
int* nnzTotalDevHostPtr);
|
||||||
|
#else
|
||||||
|
Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
|
||||||
|
int nnzA, const int* csrSortedRowPtrA,
|
||||||
|
const int* csrSortedColIndA,
|
||||||
|
const gpusparseMatDescr_t descrB, int nnzB,
|
||||||
|
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||||
|
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||||
|
int* nnzTotalDevHostPtr, csrgemm2Info_t info,
|
||||||
|
void* workspace);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Computes sparse - sparse matrix matmul of matrices
|
// Computes sparse - sparse matrix matmul of matrices
|
||||||
// stored in CSR format. This is part two: perform sparse-sparse
|
// stored in CSR format. This is part two: perform sparse-sparse
|
||||||
// addition. csrValC and csrColIndC must be allocated on the device
|
// addition. csrValC and csrColIndC must be allocated on the device
|
||||||
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
|
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
|
||||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||||
|
#if CUDA_VERSION < 10000
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||||
@ -342,6 +419,18 @@ class GpuSparse {
|
|||||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||||
int* csrSortedColIndC);
|
int* csrSortedColIndC);
|
||||||
|
#else
|
||||||
|
template <typename Scalar>
|
||||||
|
Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
|
||||||
|
int nnzA, const Scalar* csrSortedValA,
|
||||||
|
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||||
|
const gpusparseMatDescr_t descrB, int nnzB,
|
||||||
|
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||||
|
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||||
|
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||||
|
int* csrSortedColIndC, const csrgemm2Info_t info,
|
||||||
|
void* workspace);
|
||||||
|
#endif
|
||||||
|
|
||||||
// In-place reordering of unsorted CSR to sorted CSR.
|
// In-place reordering of unsorted CSR to sorted CSR.
|
||||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
|
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
|
||||||
|
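The SpMM declarations above assume the caller has already wrapped its operands in cuSPARSE generic-API descriptors; the buffer-size query then bounds the scratch allocation for the actual multiply. A condensed, hedged sketch of that sequence follows (placeholder names, single precision; the same steps appear in full in the CSRSparseMatrixMatMul hunk further down).

    // Sketch: C = A * B with A sparse (CSR) and B, C dense, column-major.
    gpusparseSpMatDescr_t matA;
    gpusparseDnMatDescr_t matB, matC;
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
        &matA, m, k, nnz, a_row_ptr, a_col_ind, a_val, CUSPARSE_INDEX_32I,
        CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(
        &matB, k, n, ldb, b_data, CUDA_R_32F, CUSPARSE_ORDER_COL));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(
        &matC, m, n, ldc, c_data, CUDA_R_32F, CUSPARSE_ORDER_COL));

    size_t buffer_size = 0;
    TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
        CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
        &alpha, matA, matB, &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
        &buffer_size));
    // ...allocate an int8 scratch buffer of buffer_size bytes, then:
    TF_RETURN_IF_ERROR(cuda_sparse.SpMM(
        CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
        &alpha, matA, matB, &beta, matC, CUSPARSE_MM_ALG_DEFAULT, buffer));

    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));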
@@ -107,6 +107,26 @@ class CSRSparseMatrixAddFunctor {
     const Device& d = ctx_->eigen_device<Device>();
     set_zero(d, c_row_ptr_t.flat<int32>());
 
+    size_t maxWorkspaceSize = 0;
+    for (int i = 0; i < batch_size; ++i) {
+      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
+                                  a.values_vec<T>(i), a_dense_shape};
+      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
+                                  b.values_vec<T>(i), b_dense_shape};
+
+      size_t thisWorkspaceSize;
+      TF_RETURN_IF_ERROR(
+          csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
+      if (thisWorkspaceSize > maxWorkspaceSize) {
+        maxWorkspaceSize = thisWorkspaceSize;
+      }
+    }
+
+    Tensor temp;
+    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
+        DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}), &temp));
+    void* workspace = temp.flat<int8>().data();
+
     for (int i = 0; i < batch_size; ++i) {
       // Calculate output sizes for all minibatch entries.
       // Store in c_batch_ptr and update c_row_ptrs.
@@ -121,8 +141,8 @@ class CSRSparseMatrixAddFunctor {
       TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                               rows + 1);
       int c_nnz_i;
-      TF_RETURN_IF_ERROR(
-          csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i));
+      TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
+          a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
       c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
     }
 
@@ -151,7 +171,7 @@ class CSRSparseMatrixAddFunctor {
       CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
                              c->values_vec<T>(i), c_dense_shape_t.vec<int64>()};
 
-      TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp));
+      TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace));
     }
 
     return Status::OK();
@@ -269,10 +289,36 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
     return Status::OK();
   }
 
+  Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
+                          const ConstCSRComponent<T>& b, size_t* bufferSize) {
+    DCHECK(initialized_);
+
+    const int m = a.row_ptr.size() - 1;
+    DCHECK_EQ(m, b.row_ptr.size() - 1);
+    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
+    DCHECK_EQ(m, a.dense_shape_host(row_dim));
+    DCHECK_EQ(m, b.dense_shape_host(row_dim));
+    const int nnzA = a.col_ind.size();
+    const int nnzB = b.col_ind.size();
+
+    const int n = a.dense_shape_host(row_dim + 1);
+    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));
+    T* null_T = nullptr;
+    int* null_int = nullptr;
+
+    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
+        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
+        a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
+        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
+        null_int, bufferSize));
+
+    return Status::OK();
+  }
+
   Status GetOutputStructure(const ConstCSRComponent<T>& a,
                             const ConstCSRComponent<T>& b,
                             TTypes<int32>::UnalignedVec c_row_ptr,
-                            int* output_nnz) {
+                            int* output_nnz, void* workspace) {
     DCHECK(initialized_);
 
     const int m = a.row_ptr.size() - 1;
@@ -290,7 +336,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
     TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
         m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
         descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
-        descrC_.descr(), c_row_ptr.data(), output_nnz));
+        descrC_.descr(), c_row_ptr.data(), output_nnz, workspace));
 
     if (*output_nnz < 0) {
       return errors::Internal(
@@ -300,7 +346,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
   }
 
   Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
-                 CSRComponent<T>* c) {
+                 CSRComponent<T>* c, void* workspace) {
     DCHECK(initialized_);
 
     const int m = a.row_ptr.size() - 1;
@@ -319,7 +365,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
         m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
         a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
         b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
-        c->row_ptr.data(), c->col_ind.data()));
+        c->row_ptr.data(), c->col_ind.data(), workspace));
 
     return Status::OK();
   }
@@ -167,13 +167,18 @@ struct CSRStructureModifyingFunctor {
 
   virtual Status Initialize() = 0;
 
+  virtual Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
+                                  const ConstCSRComponent<T>& b,
+                                  size_t* bufferSize) = 0;
+
   virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
                                     const ConstCSRComponent<T>& b,
                                     TTypes<int32>::UnalignedVec c_row_ptr,
-                                    int* output_nnz) = 0;
+                                    int* output_nnz, void* workspace) = 0;
 
   virtual Status Compute(const ConstCSRComponent<T>& a,
-                         const ConstCSRComponent<T>& b, CSRComponent<T>* c) = 0;
+                         const ConstCSRComponent<T>& b, CSRComponent<T>* c,
+                         void* workspace) = 0;
 };
 
 // Calculates C = alpha * A + beta * B, where A and B are in CSR
@@ -721,6 +721,56 @@ REGISTER_GPU(complex128)
 
 namespace functor {
 
+namespace {
+
+// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
+// cudaDataType_t (e.g. CUDA_R_32F).
+template <typename T>
+struct CUDADataType;
+
+template <>
+struct CUDADataType<Eigen::half> {
+  static constexpr cudaDataType_t type = CUDA_R_16F;
+};
+
+template <>
+struct CUDADataType<float> {
+#if GOOGLE_CUDA
+  static constexpr cudaDataType_t type = CUDA_R_32F;
+#elif TENSORFLOW_USE_ROCM
+  static constexpr cudaDataType_t type = HIPBLAS_R_32F;
+#endif
+};
+
+template <>
+struct CUDADataType<std::complex<float>> {
+#if GOOGLE_CUDA
+  static constexpr cudaDataType_t type = CUDA_C_32F;
+#elif TENSORFLOW_USE_ROCM
+  static constexpr cudaDataType_t type = HIPBLAS_C_32F;
+#endif
+};
+
+template <>
+struct CUDADataType<double> {
+#if GOOGLE_CUDA
+  static constexpr cudaDataType_t type = CUDA_R_64F;
+#elif TENSORFLOW_USE_ROCM
+  static constexpr cudaDataType_t type = HIPBLAS_R_64F;
+#endif
+};
+
+template <>
+struct CUDADataType<std::complex<double>> {
+#if GOOGLE_CUDA
+  static constexpr cudaDataType_t type = CUDA_C_64F;
+#elif TENSORFLOW_USE_ROCM
+  static constexpr cudaDataType_t type = HIPBLAS_C_64F;
+#endif
+};
+
+}  // namespace
+
 template <typename T>
 class CSRSparseMatrixMatMul<GPUDevice, T> {
  public:
@@ -733,10 +783,10 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
     GpuSparse cuda_sparse(ctx);
     TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
     {
-      // Use Csrmm to calculate:
+      // Use Csrmm/SpMM to calculate:
       // C = alpha * op(A) * op(B) + beta * C
       // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
-      // Note that Csrmm assumes B and C are in column-major form; so we
+      // Note that Csrmm/Spmm assumes B and C are in column-major form; so we
       // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
       // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.
@@ -746,22 +796,6 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
       const T alpha = 1;
       const T beta = 0;
 
-      // transA must be non-transpose if transB is transpose (cusparse
-      // limitation).
-      const gpusparseOperation_t transA = GPUSPARSE(OPERATION_NON_TRANSPOSE);
-
-      // transB: b is row-major, and cusparse requires col-major b (or
-      // equivalently transB == transpose). this version is actually more
-      // efficient.
-      const gpusparseOperation_t transB = GPUSPARSE(OPERATION_TRANSPOSE);
-
-      gpusparseMatDescr_t descrA;
-      TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
-      TF_RETURN_IF_GPUSPARSE_ERROR(
-          gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
-      TF_RETURN_IF_GPUSPARSE_ERROR(
-          gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
-
       // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
       const int k = b.dimension(0);
       DCHECK_EQ(k, a.dense_shape_host(1));
@@ -786,10 +820,87 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
       // op(A) = A and at least max(1, k) otherwise.
       const int ldc = m;
 
+      // transA must be non-transpose if transB is transpose (cusparse
+      // limitation).
+#if GOOGLE_CUDA
+      const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
+#elif TENSORFLOW_USE_ROCM
+      const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
+#endif
+
+      // transB: b is row-major, and cusparse requires col-major b (or
+      // equivalently transB == transpose). this version is actually more
+      // efficient.
+#if GOOGLE_CUDA && CUDA_VERSION >= 10020
+
+      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
+      gpusparseSpMatDescr_t matA;
+      gpusparseDnMatDescr_t matB, matC;
+
+      // NOTE: the following APIs are not available in ROCM
+      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
+          &matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
+          const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
+          CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
+          CUDADataType<T>::type));
+
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
+                              CUDADataType<T>::type, CUSPARSE_ORDER_COL));
+
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType<T>::type,
+                              CUSPARSE_ORDER_COL));
+
+      size_t bufferSize = 0;
+      TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
+          transA, transB, &alpha, matA, matB, &beta, matC,
+          CUSPARSE_MM_ALG_DEFAULT, &bufferSize));
+
+      Tensor buffer;
+      TF_RETURN_IF_ERROR(ctx->allocate_temp(
+          DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
+      DCHECK(buffer.flat<int8>().data() != nullptr);
+
+      TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
+                                          &beta, matC, CUSPARSE_MM_ALG_DEFAULT,
+                                          buffer.flat<int8>().data()));
+
+      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
+      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
+      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
+
+#else
+
+#if GOOGLE_CUDA
+
+      const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
+
+      gpusparseMatDescr_t descrA;
+      TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
+
+#elif TENSORFLOW_USE_ROCM
+
+      const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
+
+      gpusparseMatDescr_t descrA;
+      TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+      TF_RETURN_IF_GPUSPARSE_ERROR(
+          hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+#endif  // GOOGLE_CUDA
+
       TF_RETURN_IF_ERROR(
           cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                             a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                             b.data(), ldb, &beta, c.data(), ldc));
+
+#endif  // GOOGLE_CUDA && CUDA_VERSION >= 10020
     }
 
     return Status::OK();
@@ -822,20 +933,35 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
     const T alpha = 1;
     const T beta = 0;
 
+#if GOOGLE_CUDA && CUDA_VERSION < 10020
     gpusparseMatDescr_t descrA;
-    TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
+    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
     TF_RETURN_IF_GPUSPARSE_ERROR(
-        gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
+        cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
     TF_RETURN_IF_GPUSPARSE_ERROR(
-        gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
+        cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
+#elif TENSORFLOW_USE_ROCM
+    gpusparseMatDescr_t descrA;
+    TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
+    TF_RETURN_IF_GPUSPARSE_ERROR(
+        hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
     const int m = a.dense_shape_host(0);
     const int n = a.dense_shape_host(1);
     const int nnz = a.values.size();
     DCHECK_EQ(nnz, a.col_ind.size());
+#if CUDA_VERSION >= 10020
+    TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
+                                         a.values.data(), a.row_ptr.data(),
+                                         a.col_ind.data(), x, &beta, y));
+#else
     TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                          a.values.data(), a.row_ptr.data(),
                                          a.col_ind.data(), x, &beta, y));
+#endif
   }
 
   return Status::OK();
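In the matrix-vector hunk above, the CUDA_VERSION >= 10020 branch drops the legacy gpusparseMatDescr_t argument; on those toolkits the Csrmv wrapper presumably forwards to the generic cusparseSpMV API, where the matrix description lives in a cusparseSpMatDescr_t instead. A hedged standalone sketch of that API follows (not part of this change; float data, illustrative pointer names, most error checking omitted). CUSPARSE_MV_ALG_DEFAULT is the CUDA 10.2-era name for the default algorithm.

#include <cuda_runtime.h>
#include <cusparse.h>

// Sketch only: y = A * x for an m x n CSR matrix A with nnz nonzeros.
// Only the SpMV status is propagated; other checks are omitted.
cusparseStatus_t SpMVSketch(cusparseHandle_t handle, int64_t m, int64_t n,
                            int64_t nnz, int* d_row_ptr, int* d_col_ind,
                            float* d_vals, float* d_x, float* d_y) {
  cusparseSpMatDescr_t matA;
  cusparseDnVecDescr_t vecX, vecY;
  cusparseCreateCsr(&matA, m, n, nnz, d_row_ptr, d_col_ind, d_vals,
                    CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
                    CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F);
  cusparseCreateDnVec(&vecX, n, d_x, CUDA_R_32F);
  cusparseCreateDnVec(&vecY, m, d_y, CUDA_R_32F);

  const float alpha = 1.0f;
  const float beta = 0.0f;
  size_t buffer_size = 0;
  // Query scratch space, allocate it, then run the matrix-vector product.
  cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
                          matA, vecX, &beta, vecY, CUDA_R_32F,
                          CUSPARSE_MV_ALG_DEFAULT, &buffer_size);
  void* d_buffer = nullptr;
  if (buffer_size > 0) cudaMalloc(&d_buffer, buffer_size);

  cusparseStatus_t status =
      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA,
                   vecX, &beta, vecY, CUDA_R_32F, CUSPARSE_MV_ALG_DEFAULT,
                   d_buffer);

  if (d_buffer != nullptr) cudaFree(d_buffer);
  cusparseDestroyDnVec(vecX);
  cusparseDestroyDnVec(vecY);
  cusparseDestroySpMat(matA);
  return status;
}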
@@ -417,6 +417,36 @@ class CSRSparseMatMulGPUOp : public OpKernel {
     }
     auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
 
+#if CUDA_VERSION >= 10000
+    size_t maxWorkspaceSize = 0;
+    for (int i = 0; i < batch_size; ++i) {
+      // Calculate maximum workspace size over batch.
+      ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
+                                  a_input_matrix->col_indices_vec(i),
+                                  a_input_matrix->values_vec<T>(i),
+                                  a_input_dense_shape};
+      ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
+                                  b_input_matrix->col_indices_vec(i),
+                                  b_input_matrix->values_vec<T>(i),
+                                  b_input_dense_shape};
+      size_t thisWorkspaceSize;
+      OP_REQUIRES_OK(
+          ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
+      if (thisWorkspaceSize > maxWorkspaceSize) {
+        maxWorkspaceSize = thisWorkspaceSize;
+      }
+    }
+
+    Tensor temp;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_temp(
+                 DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}),
+                 &temp));
+    void* workspace = temp.flat<int8>().data();
+#else
+    void* workspace = nullptr;
+#endif
+
     for (int i = 0; i < batch_size; ++i) {
       // Calculate output sizes for all minibatch entries.
       // Store in c_batch_ptr and update c_row_ptrs.
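The added block above sizes the csrgemm2 scratch buffer once for the worst case over the whole batch and then reuses it for every minibatch entry, instead of allocating a temporary per entry. A small host-side sketch of the same allocate-once pattern, with hypothetical get_workspace_size/run callables standing in for GetWorkspaceSize and the later per-batch GEMM calls (the kernel itself allocates a DT_INT8 device temp via allocate_temp rather than host memory):

#include <cstddef>
#include <vector>

// Sketch of the allocate-once pattern: query each batch member's scratch
// requirement, take the maximum, allocate a single buffer, and reuse it.
// `get_workspace_size` and `run` are hypothetical callables; they are not
// TensorFlow APIs.
template <typename Item, typename SizeFn, typename RunFn>
void RunBatchWithSharedWorkspace(const std::vector<Item>& batch,
                                 SizeFn get_workspace_size, RunFn run) {
  size_t max_bytes = 0;
  for (const Item& item : batch) {
    const size_t bytes = get_workspace_size(item);
    if (bytes > max_bytes) {
      max_bytes = bytes;
    }
  }
  std::vector<char> workspace(max_bytes);  // One allocation for all items.
  for (const Item& item : batch) {
    run(item, workspace.data());
  }
}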
@@ -433,8 +463,9 @@ class CSRSparseMatMulGPUOp : public OpKernel {
                                      rows + 1);
 
       int c_nnz_i;
-      OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
-                                                      c_row_ptr_i, &c_nnz_i));
+      OP_REQUIRES_OK(ctx,
+                     csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
+                                                 &c_nnz_i, workspace));
       c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
     }
 
@@ -464,7 +495,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
                                   b_input_dense_shape};
       CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
                              c.values_vec<T>(i), c_dense_shape};
-      OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp));
+      OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace));
     }
 
     Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
@@ -527,7 +558,12 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         initialized_(false),
         transpose_a_(transpose_a),
         adjoint_a_(adjoint_a),
+#if CUDA_VERSION < 10000
         transpose_b_(transpose_b) {
+#else
+        transpose_b_(transpose_b),
+        info_(nullptr) {
+#endif  // CUDA_VERSION < 10000
     // TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
     transA_ = transpose_a
                   ? (adjoint_a ? GPUSPARSE(OPERATION_TRANSPOSE)
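The constructor hunk above has to repeat the final initializer because the member list itself changes with CUDA_VERSION: without csrgemm2 support, transpose_b_ is the last member and takes the opening brace; with it, info_ is initialized as well. A generic sketch of that preprocessor-guarded initializer-list pattern, using a hypothetical Widget class and HAVE_EXTRA_STATE macro (not from this codebase):

// Sketch: the tail of the member-initializer list, and the member itself,
// only exist when HAVE_EXTRA_STATE is defined, so the list is split across
// preprocessor branches just like the constructor above.
class Widget {
 public:
  explicit Widget(bool flag)
      : flag_(flag)
#if defined(HAVE_EXTRA_STATE)
        ,
        extra_(nullptr)
#endif
  {
  }

 private:
  bool flag_;
#if defined(HAVE_EXTRA_STATE)
  void* extra_;
#endif
};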
@@ -537,6 +573,14 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
                   : GPUSPARSE(OPERATION_NON_TRANSPOSE);
   }
 
+#if CUDA_VERSION >= 10000
+  ~CSRSparseSparseMatrixMatMul() {
+    if (initialized_) {
+      cusparseDestroyCsrgemm2Info(info_);
+    }
+  }
+#endif
+
   Status Initialize() {
     if (adjoint_a_ && transpose_a_) {
       return errors::InvalidArgument(
|
|||||||
TF_RETURN_IF_ERROR(descrA_.Initialize());
|
TF_RETURN_IF_ERROR(descrA_.Initialize());
|
||||||
TF_RETURN_IF_ERROR(descrB_.Initialize());
|
TF_RETURN_IF_ERROR(descrB_.Initialize());
|
||||||
TF_RETURN_IF_ERROR(descrC_.Initialize());
|
TF_RETURN_IF_ERROR(descrC_.Initialize());
|
||||||
|
#if CUDA_VERSION >= 10000
|
||||||
|
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
|
||||||
|
#endif
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
|
||||||
|
const ConstCSRComponent<T>& b, size_t* bufferSize) {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
const int m =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
|
||||||
|
if (!transpose_a_) {
|
||||||
|
DCHECK_EQ(m, a.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
const int k =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
|
||||||
|
if (!transpose_b_) {
|
||||||
|
DCHECK_EQ(k, b.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
const int nnzA = a.col_ind.size();
|
||||||
|
const int nnzB = b.col_ind.size();
|
||||||
|
|
||||||
|
const int n =
|
||||||
|
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
|
||||||
|
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
|
||||||
|
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
|
||||||
|
bufferSize));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||||
const ConstCSRComponent<T>& b,
|
const ConstCSRComponent<T>& b,
|
||||||
TTypes<int32>::UnalignedVec c_row_ptr,
|
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||||
int* output_nnz) {
|
int* output_nnz, void* workspace) {
|
||||||
DCHECK(initialized_);
|
DCHECK(initialized_);
|
||||||
|
|
||||||
const int m =
|
const int m =
|
||||||
@@ -576,10 +650,17 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
 
     *output_nnz = -1;
 
+#if CUDA_VERSION < 10000
     TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
         transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
         a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
         b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
+#else
+    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
+        m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
+        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
+        descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
+#endif
 
     if (*output_nnz < 0) {
       return errors::Internal(
@@ -590,7 +671,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
   }
 
   Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
-                 CSRComponent<T>* c) {
+                 CSRComponent<T>* c, void* workspace) {
     DCHECK(initialized_);
 
     const int m =
@@ -612,11 +693,19 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
         b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
     DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
 
+#if CUDA_VERSION < 10000
     TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
         transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
         a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
         b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
         c->values.data(), c->row_ptr.data(), c->col_ind.data()));
+#else
+    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
+        m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
+        a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
+        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
+        c->row_ptr.data(), c->col_ind.data(), info_, workspace));
+#endif
 
     // TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
     // columns are sorted? Above operation leads to unsorted columns.
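Taken together, GetOutputStructure and Compute follow the classic two-pass CSR x CSR (SpGEMM) scheme: a symbolic pass fills C's row pointers and reports its nnz so the caller can allocate, then a numeric pass fills the column indices and values, with workspace and info_ carrying cuSPARSE's state between the two passes. A plain host-side C++ sketch of the same two-pass structure follows (illustrative only, no cuSPARSE, all names are hypothetical):

#include <map>
#include <set>
#include <vector>

struct Csr {
  int rows = 0, cols = 0;
  std::vector<int> row_ptr;   // size rows + 1
  std::vector<int> col_ind;   // size nnz
  std::vector<float> values;  // size nnz
};

// Pass 1 ("GetOutputStructure"): compute C's row pointers and total nnz
// without touching any values.
int SpGemmSymbolic(const Csr& a, const Csr& b, std::vector<int>* c_row_ptr) {
  c_row_ptr->assign(a.rows + 1, 0);
  for (int i = 0; i < a.rows; ++i) {
    std::set<int> row_cols;  // distinct columns present in row i of C
    for (int p = a.row_ptr[i]; p < a.row_ptr[i + 1]; ++p) {
      const int k = a.col_ind[p];
      for (int q = b.row_ptr[k]; q < b.row_ptr[k + 1]; ++q) {
        row_cols.insert(b.col_ind[q]);
      }
    }
    (*c_row_ptr)[i + 1] = (*c_row_ptr)[i] + static_cast<int>(row_cols.size());
  }
  return (*c_row_ptr)[a.rows];  // nnz of C
}

// Pass 2 ("Compute"): with row pointers known and output buffers allocated,
// accumulate and write C's column indices and values.
void SpGemmNumeric(const Csr& a, const Csr& b, Csr* c) {
  for (int i = 0; i < a.rows; ++i) {
    std::map<int, float> row;  // column -> accumulated value for row i of C
    for (int p = a.row_ptr[i]; p < a.row_ptr[i + 1]; ++p) {
      const int k = a.col_ind[p];
      const float a_ik = a.values[p];
      for (int q = b.row_ptr[k]; q < b.row_ptr[k + 1]; ++q) {
        row[b.col_ind[q]] += a_ik * b.values[q];
      }
    }
    int out = c->row_ptr[i];
    for (const auto& kv : row) {
      c->col_ind[out] = kv.first;
      c->values[out] = kv.second;
      ++out;
    }
  }
}

// Driver mirroring the op: symbolic pass, allocate, numeric pass.
Csr SpGemm(const Csr& a, const Csr& b) {
  Csr c;
  c.rows = a.rows;
  c.cols = b.cols;
  const int nnz = SpGemmSymbolic(a, b, &c.row_ptr);
  c.col_ind.resize(nnz);
  c.values.resize(nnz);
  SpGemmNumeric(a, b, &c);
  return c;
}

Unlike csrgemm2, which the TODO above notes can leave each row's columns unsorted, the std::map accumulator in this sketch happens to emit them in sorted order.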
@@ -643,6 +732,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
   GpuSparseMatrixDescriptor descrC_;
   gpusparseOperation_t transA_;
   gpusparseOperation_t transB_;
+#if CUDA_VERSION >= 10000
+  csrgemm2Info_t info_;
+#endif
 };
 
 }  // namespace functor