Merge pull request #38802 from nluehr:cusparse_remove_deprecated
PiperOrigin-RevId: 310182962 Change-Id: Ic9c39c151181a99fa76981b4724d2e0f51fa5f69
This commit is contained in:
commit
10b96cd214
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#include "third_party/gpus/cuda/include/library_types.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -179,6 +180,10 @@ Status GpuSparse::Initialize() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define TF_CALL_CUSPARSE_DTYPES(m) \
|
||||
m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
|
||||
m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)
|
||||
|
||||
// Macro that specializes a sparse method for all 4 standard
|
||||
// numeric types.
|
||||
// TODO: reuse with cuda_solvers
|
||||
@ -359,23 +364,30 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
|
||||
Status GpuSparse::CsrgeamNnz(
|
||||
int m, int n, const cusparseMatDescr_t descrA, int nnzA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
#if CUDA_VERSION >= 10000
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeam2Nnz(
|
||||
*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
|
||||
descrC, csrSortedRowPtrC, nnzTotalDevHostPtr, workspace));
|
||||
#else
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgeamNnz(
|
||||
*gpusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA,
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
|
||||
descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if CUDA_VERSION < 10020
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrmmImpl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
|
||||
@ -416,6 +428,45 @@ static inline Status CsrmmImpl(
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);
|
||||
|
||||
#else
|
||||
|
||||
#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
|
||||
template <> \
|
||||
Status GpuSparse::SpMMBufferSize<Scalar>( \
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, \
|
||||
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
|
||||
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
|
||||
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
|
||||
const { \
|
||||
DCHECK(initialized_); \
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
|
||||
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \
|
||||
dtype, alg, bufferSize)); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);
|
||||
|
||||
#define SPMM_INSTANCE(Scalar, dtype) \
|
||||
template <> \
|
||||
Status GpuSparse::SpMM<Scalar>( \
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, \
|
||||
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
|
||||
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
|
||||
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
|
||||
DCHECK(initialized_); \
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \
|
||||
transB, alpha, matA, matB, beta, \
|
||||
matC, dtype, alg, buffer)); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
|
||||
|
||||
#endif
|
||||
|
||||
#if CUDA_VERSION < 10020
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrmvImpl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
|
||||
@ -455,6 +506,115 @@ static inline Status CsrmvImpl(
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar>
|
||||
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
|
||||
cusparseHandle_t cusparse_handle,
|
||||
cusparseOperation_t transA, int m, int n,
|
||||
int nnz, const Scalar* alpha_host,
|
||||
const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) {
|
||||
cusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
// CUSPARSE_ALG_MERGE_PATH algo only supports non-transpose matrix.
|
||||
DCHECK(transA == CUSPARSE_OPERATION_NON_TRANSPOSE);
|
||||
|
||||
size_t bufferSize;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
|
||||
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
|
||||
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
|
||||
x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));
|
||||
|
||||
Tensor buffer;
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
|
||||
auto pBuffer = buffer.flat<int8>();
|
||||
DCHECK(pBuffer.data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
|
||||
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
|
||||
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
|
||||
x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
|
||||
cusparseHandle_t cusparse_handle,
|
||||
cusparseOperation_t transA, int m, int n, int nnz,
|
||||
const Scalar* alpha_host,
|
||||
const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) {
|
||||
cusparseSpMatDescr_t matA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
|
||||
&matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
|
||||
const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
|
||||
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));
|
||||
|
||||
cusparseDnVecDescr_t vecX, vecY;
|
||||
int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
|
||||
int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));
|
||||
|
||||
size_t bufferSize;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
|
||||
cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
|
||||
CUSPARSE_CSRMV_ALG1, &bufferSize));
|
||||
|
||||
Tensor buffer;
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
|
||||
auto pBuffer = buffer.flat<int8>();
|
||||
DCHECK(pBuffer.data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
|
||||
vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrmv<Scalar>( \
|
||||
cusparseOperation_t transA, int m, int n, int nnz, \
|
||||
const Scalar* alpha_host, const Scalar* csrSortedValA, \
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
|
||||
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
|
||||
DCHECK(initialized_); \
|
||||
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
|
||||
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \
|
||||
m, n, nnz, alpha_host, csrSortedValA, \
|
||||
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
|
||||
} else { \
|
||||
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \
|
||||
n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \
|
||||
csrSortedColIndA, x, beta_host, y); \
|
||||
} \
|
||||
}
|
||||
|
||||
TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
|
||||
|
||||
#endif // CUDA_VERSION < 10020
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrgeamImpl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
|
||||
@ -483,7 +643,7 @@ static inline Status CsrgeamImpl(
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC) { \
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
|
||||
@ -493,8 +653,113 @@ static inline Status CsrgeamImpl(
|
||||
csrSortedRowPtrC, csrSortedColIndC); \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status Csrgeam2Impl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* beta,
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(
|
||||
cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
|
||||
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
|
||||
csrSortedRowPtrC, csrSortedColIndC, workspace));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrgeam<Scalar>( \
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const Scalar* beta, \
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) { \
|
||||
DCHECK(initialized_); \
|
||||
return Csrgeam2Impl(SPARSE_FN(csrgeam2, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, m, n, alpha, descrA, nnzA, \
|
||||
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
|
||||
beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
|
||||
csrSortedColIndB, descrC, csrSortedValC, \
|
||||
csrSortedRowPtrC, csrSortedColIndC, workspace); \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
|
||||
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const Scalar* beta, \
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
|
||||
DCHECK(initialized_); \
|
||||
*bufferSize = 0; \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrgeamBufferSizeExtImpl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t sparse_handle,
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* beta,
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(
|
||||
sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
|
||||
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
|
||||
csrSortedRowPtrC, csrSortedColIndC, bufferSize));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
|
||||
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const Scalar* beta, \
|
||||
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
|
||||
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgeamBufferSizeExtImpl( \
|
||||
SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \
|
||||
csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
|
||||
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \
|
||||
csrSortedRowPtrC, csrSortedColIndC, bufferSize); \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEAM_BUFFERSIZE_INSTANCE);
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
|
||||
Status GpuSparse::CsrgemmNnz(
|
||||
cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
|
||||
const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
|
||||
@ -551,6 +816,101 @@ static inline Status CsrgemmImpl(
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
|
||||
|
||||
#else
|
||||
|
||||
template <typename T>
|
||||
static const T* one_ptr() {
|
||||
static const T one = static_cast<T>(1);
|
||||
return &one;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static const T* null_ptr() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#define CSRGEMM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::CsrgemmBufferSize<Scalar>( \
|
||||
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
|
||||
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, \
|
||||
const int* csrSortedColIndB, csrgemm2Info_t info, \
|
||||
size_t* workspaceBytes) { \
|
||||
DCHECK(initialized_); \
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \
|
||||
sparse_prefix)( \
|
||||
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
|
||||
nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
|
||||
csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
|
||||
descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes)); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);
|
||||
|
||||
Status GpuSparse::CsrgemmNnz(
|
||||
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
|
||||
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
|
||||
void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
DCHECK(nnzTotalDevHostPtr != nullptr);
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
|
||||
*gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
|
||||
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
|
||||
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
|
||||
nnzTotalDevHostPtr, info, workspace));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status CsrgemmImpl(
|
||||
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
|
||||
const csrgemm2Info_t info, void* workspace) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
|
||||
nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
|
||||
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
|
||||
csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
|
||||
AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
|
||||
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
|
||||
csrSortedColIndC, info, workspace));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSRGEMM_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Csrgemm<Scalar>( \
|
||||
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA, \
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
|
||||
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB, \
|
||||
const int* csrSortedColIndB, const cusparseMatDescr_t descrC, \
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC, \
|
||||
const csrgemm2Info_t info, void* workspace) { \
|
||||
DCHECK(initialized_); \
|
||||
return CsrgemmImpl(SPARSE_FN(csrgemm2, sparse_prefix), context_, \
|
||||
*gpusparse_handle_, m, n, k, descrA, nnzA, \
|
||||
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
|
||||
descrB, nnzB, csrSortedValB, csrSortedRowPtrB, \
|
||||
csrSortedColIndB, descrC, csrSortedValC, \
|
||||
csrSortedRowPtrC, csrSortedColIndC, info, workspace); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);
|
||||
|
||||
#endif // CUDA_VERSION < 10000
|
||||
|
||||
template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
|
||||
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
OpKernelContext* context,
|
||||
@ -596,6 +956,8 @@ static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);
|
||||
|
||||
#if CUDA_VERSION < 10010
|
||||
|
||||
template <typename Scalar, typename SparseFnT>
|
||||
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
|
||||
cusparseHandle_t cusparse_handle, int m, int n,
|
||||
@ -624,6 +986,53 @@ static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
|
||||
|
||||
TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);
|
||||
|
||||
#else
|
||||
|
||||
template <typename Scalar>
|
||||
static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
|
||||
cusparseHandle_t cusparse_handle, int m, int n,
|
||||
int nnz, const Scalar* csrVal,
|
||||
const int* csrRowPtr, const int* csrColInd,
|
||||
Scalar* cscVal, int* cscRowInd, int* cscColPtr,
|
||||
const cusparseAction_t copyValues) {
|
||||
size_t bufferSize;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
|
||||
cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
|
||||
AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
|
||||
CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));
|
||||
|
||||
Tensor buffer;
|
||||
TF_RETURN_IF_ERROR(context->allocate_temp(
|
||||
DataTypeToEnum<Scalar>::value,
|
||||
TensorShape({static_cast<int64>(bufferSize)}), &buffer));
|
||||
|
||||
DCHECK(buffer.flat<Scalar>().data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
|
||||
csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
|
||||
cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
|
||||
CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
|
||||
template <> \
|
||||
Status GpuSparse::Csr2csc<Scalar>( \
|
||||
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
|
||||
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
|
||||
const cusparseAction_t copyValues) { \
|
||||
DCHECK(initialized_); \
|
||||
return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
|
||||
csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \
|
||||
cscColPtr, copyValues); \
|
||||
}
|
||||
|
||||
TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);
|
||||
|
||||
#endif // CUDA_VERSION < 10010
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#include "third_party/gpus/cuda/include/cuda.h"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
using gpusparseStatus_t = cusparseStatus_t;
|
||||
@ -34,6 +35,11 @@ using gpusparseMatDescr_t = cusparseMatDescr_t;
|
||||
using gpusparseAction_t = cusparseAction_t;
|
||||
using gpusparseHandle_t = cusparseHandle_t;
|
||||
using gpuStream_t = cudaStream_t;
|
||||
#if CUDA_VERSION >= 10020
|
||||
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
|
||||
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
|
||||
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
|
||||
#endif
|
||||
|
||||
#define GPUSPARSE(postfix) CUSPARSE_##postfix
|
||||
#define gpusparse(postfix) cusparse##postfix
|
||||
@ -253,6 +259,7 @@ class GpuSparse {
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
|
||||
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
|
||||
|
||||
#if CUDA_VERSION < 10020
|
||||
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
|
||||
// where A is a sparse matrix in CSR format, B and C are dense tall
|
||||
// matrices. This routine allows transposition of matrix B, which
|
||||
@ -272,18 +279,64 @@ class GpuSparse {
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
|
||||
int ldc) const;
|
||||
#else
|
||||
// Workspace size query for sparse-dense matrix multiplication. Helper
|
||||
// function for SpMM which computes y = alpha * op(A) * op(B) + beta * C,
|
||||
// where A is a sparse matrix in CSR format, B and C are dense matricies in
|
||||
// column-major format. Returns needed workspace size in bytes.
|
||||
template <typename Scalar>
|
||||
Status SpMMBufferSize(gpusparseOperation_t transA,
|
||||
gpusparseOperation_t transB, const Scalar* alpha,
|
||||
const gpusparseSpMatDescr_t matA,
|
||||
const gpusparseDnMatDescr_t matB, const Scalar* beta,
|
||||
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
|
||||
size_t* bufferSize) const;
|
||||
|
||||
// Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C,
|
||||
// where A is a sparse matrix in CSR format, B and C are dense matricies in
|
||||
// column-major format. Buffer is assumed to be at least as large as the
|
||||
// workspace size returned by SpMMBufferSize().
|
||||
//
|
||||
// **NOTE** This is an in-place operation for data in C.
|
||||
template <typename Scalar>
|
||||
Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||
const Scalar* alpha, const gpusparseSpMatDescr_t matA,
|
||||
const gpusparseDnMatDescr_t matB, const Scalar* beta,
|
||||
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
|
||||
int8* buffer) const;
|
||||
#endif
|
||||
|
||||
// Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
|
||||
// where A is a sparse matrix in CSR format, x and y are dense vectors. See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
|
||||
//
|
||||
// **NOTE** This is an in-place operation for data in y.
|
||||
#if CUDA_VERSION < 10020
|
||||
template <typename Scalar>
|
||||
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
||||
const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
|
||||
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* x,
|
||||
const Scalar* beta_host, Scalar* y) const;
|
||||
#else
|
||||
template <typename Scalar>
|
||||
Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
|
||||
const Scalar* alpha_host, const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const Scalar* x, const Scalar* beta_host, Scalar* y) const;
|
||||
#endif // CUDA_VERSION < 10020
|
||||
|
||||
// Computes workspace size for sparse - sparse matrix addition of matrices
|
||||
// stored in CSR format.
|
||||
template <typename Scalar>
|
||||
Status CsrgeamBufferSizeExt(
|
||||
int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA, const Scalar* beta,
|
||||
const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
|
||||
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);
|
||||
|
||||
// Computes sparse-sparse matrix addition of matrices
|
||||
// stored in CSR format. This is part one: calculate nnz of the
|
||||
@ -295,7 +348,7 @@ class GpuSparse {
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
int* nnzTotalDevHostPtr);
|
||||
int* nnzTotalDevHostPtr, void* workspace);
|
||||
|
||||
// Computes sparse - sparse matrix addition of matrices
|
||||
// stored in CSR format. This is part two: perform sparse-sparse
|
||||
@ -311,13 +364,26 @@ class GpuSparse {
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||
int* csrSortedColIndC);
|
||||
int* csrSortedColIndC, void* workspace);
|
||||
|
||||
#if CUDA_VERSION >= 10000
|
||||
// Computes sparse-sparse matrix multiplication of matrices
|
||||
// stored in CSR format. This is part zero: calculate required workspace
|
||||
// size.
|
||||
template <typename Scalar>
|
||||
Status CsrgemmBufferSize(
|
||||
int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
|
||||
#endif
|
||||
|
||||
// Computes sparse-sparse matrix multiplication of matrices
|
||||
// stored in CSR format. This is part one: calculate nnz of the
|
||||
// output. csrSortedRowPtrC must be preallocated on device with
|
||||
// m + 1 entries. See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||
#if CUDA_VERSION < 10000
|
||||
Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
@ -326,12 +392,23 @@ class GpuSparse {
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
int* nnzTotalDevHostPtr);
|
||||
#else
|
||||
Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const int* csrSortedRowPtrA,
|
||||
const int* csrSortedColIndA,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
|
||||
const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
|
||||
int* nnzTotalDevHostPtr, csrgemm2Info_t info,
|
||||
void* workspace);
|
||||
#endif
|
||||
|
||||
// Computes sparse - sparse matrix matmul of matrices
|
||||
// stored in CSR format. This is part two: perform sparse-sparse
|
||||
// addition. csrValC and csrColIndC must be allocated on the device
|
||||
// with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
|
||||
#if CUDA_VERSION < 10000
|
||||
template <typename Scalar>
|
||||
Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
|
||||
int m, int k, int n, const gpusparseMatDescr_t descrA,
|
||||
@ -342,6 +419,18 @@ class GpuSparse {
|
||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||
int* csrSortedColIndC);
|
||||
#else
|
||||
template <typename Scalar>
|
||||
Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
|
||||
int nnzA, const Scalar* csrSortedValA,
|
||||
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
|
||||
const gpusparseMatDescr_t descrB, int nnzB,
|
||||
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
|
||||
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
|
||||
Scalar* csrSortedValC, int* csrSortedRowPtrC,
|
||||
int* csrSortedColIndC, const csrgemm2Info_t info,
|
||||
void* workspace);
|
||||
#endif
|
||||
|
||||
// In-place reordering of unsorted CSR to sorted CSR.
|
||||
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
|
||||
|
@ -107,6 +107,26 @@ class CSRSparseMatrixAddFunctor {
|
||||
const Device& d = ctx_->eigen_device<Device>();
|
||||
set_zero(d, c_row_ptr_t.flat<int32>());
|
||||
|
||||
size_t maxWorkspaceSize = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
|
||||
a.values_vec<T>(i), a_dense_shape};
|
||||
ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
|
||||
b.values_vec<T>(i), b_dense_shape};
|
||||
|
||||
size_t thisWorkspaceSize;
|
||||
TF_RETURN_IF_ERROR(
|
||||
csr_geam.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
|
||||
if (thisWorkspaceSize > maxWorkspaceSize) {
|
||||
maxWorkspaceSize = thisWorkspaceSize;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor temp;
|
||||
TF_RETURN_IF_ERROR(ctx_->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}), &temp));
|
||||
void* workspace = temp.flat<int8>().data();
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Calculate output sizes for all minibatch entries.
|
||||
// Store in c_batch_ptr and update c_row_ptrs.
|
||||
@ -121,8 +141,8 @@ class CSRSparseMatrixAddFunctor {
|
||||
TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
|
||||
rows + 1);
|
||||
int c_nnz_i;
|
||||
TF_RETURN_IF_ERROR(
|
||||
csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i));
|
||||
TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
|
||||
a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
|
||||
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
|
||||
}
|
||||
|
||||
@ -151,7 +171,7 @@ class CSRSparseMatrixAddFunctor {
|
||||
CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
|
||||
c->values_vec<T>(i), c_dense_shape_t.vec<int64>()};
|
||||
|
||||
TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp));
|
||||
TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp, workspace));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -269,10 +289,36 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b, size_t* bufferSize) {
|
||||
DCHECK(initialized_);
|
||||
|
||||
const int m = a.row_ptr.size() - 1;
|
||||
DCHECK_EQ(m, b.row_ptr.size() - 1);
|
||||
const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
|
||||
DCHECK_EQ(m, a.dense_shape_host(row_dim));
|
||||
DCHECK_EQ(m, b.dense_shape_host(row_dim));
|
||||
const int nnzA = a.col_ind.size();
|
||||
const int nnzB = b.col_ind.size();
|
||||
|
||||
const int n = a.dense_shape_host(row_dim + 1);
|
||||
DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));
|
||||
T* null_T = nullptr;
|
||||
int* null_int = nullptr;
|
||||
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
|
||||
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
|
||||
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
|
||||
null_int, bufferSize));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b,
|
||||
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||
int* output_nnz) {
|
||||
int* output_nnz, void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
|
||||
const int m = a.row_ptr.size() - 1;
|
||||
@ -290,7 +336,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
|
||||
m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
|
||||
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
|
||||
descrC_.descr(), c_row_ptr.data(), output_nnz));
|
||||
descrC_.descr(), c_row_ptr.data(), output_nnz, workspace));
|
||||
|
||||
if (*output_nnz < 0) {
|
||||
return errors::Internal(
|
||||
@ -300,7 +346,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
}
|
||||
|
||||
Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
|
||||
CSRComponent<T>* c) {
|
||||
CSRComponent<T>* c, void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
|
||||
const int m = a.row_ptr.size() - 1;
|
||||
@ -319,7 +365,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
|
||||
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
|
||||
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
|
||||
c->row_ptr.data(), c->col_ind.data()));
|
||||
c->row_ptr.data(), c->col_ind.data(), workspace));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -167,13 +167,18 @@ struct CSRStructureModifyingFunctor {
|
||||
|
||||
virtual Status Initialize() = 0;
|
||||
|
||||
virtual Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b,
|
||||
size_t* bufferSize) = 0;
|
||||
|
||||
virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b,
|
||||
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||
int* output_nnz) = 0;
|
||||
int* output_nnz, void* workspace) = 0;
|
||||
|
||||
virtual Status Compute(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b, CSRComponent<T>* c) = 0;
|
||||
const ConstCSRComponent<T>& b, CSRComponent<T>* c,
|
||||
void* workspace) = 0;
|
||||
};
|
||||
|
||||
// Calculates C = alpha * A + beta * B, where A and B are in CSR
|
||||
|
@ -721,6 +721,56 @@ REGISTER_GPU(complex128)
|
||||
|
||||
namespace functor {
|
||||
|
||||
namespace {
|
||||
|
||||
// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
|
||||
// cudaDataType_t (e.g. CUDA_R_32F).
|
||||
template <typename T>
|
||||
struct CUDADataType;
|
||||
|
||||
template <>
|
||||
struct CUDADataType<Eigen::half> {
|
||||
static constexpr cudaDataType_t type = CUDA_R_16F;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<float> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_R_32F;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
static constexpr cudaDataType_t type = HIPBLAS_R_32F;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<std::complex<float>> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_C_32F;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
static constexpr cudaDataType_t type = HIPBLAS_C_32F;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<double> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_R_64F;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
static constexpr cudaDataType_t type = HIPBLAS_R_64F;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CUDADataType<std::complex<double>> {
|
||||
#if GOOGLE_CUDA
|
||||
static constexpr cudaDataType_t type = CUDA_C_64F;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
static constexpr cudaDataType_t type = HIPBLAS_C_64F;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
public:
|
||||
@ -733,10 +783,10 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
GpuSparse cuda_sparse(ctx);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
{
|
||||
// Use Csrmm to calculate:
|
||||
// Use Csrmm/SpMM to calculate:
|
||||
// C = alpha * op(A) * op(B) + beta * C
|
||||
// where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
|
||||
// Note that Csrmm assumes B and C are in column-major form; so we
|
||||
// Note that Csrmm/Spmm assumes B and C are in column-major form; so we
|
||||
// use transB == true, and manually transpose the output in place
|
||||
// using blas<t>geam.
|
||||
// TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.
|
||||
@ -746,22 +796,6 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
const T alpha = 1;
|
||||
const T beta = 0;
|
||||
|
||||
// transA must be non-transpose if transB is transpose (cusparse
|
||||
// limitation).
|
||||
const gpusparseOperation_t transA = GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
|
||||
// transB: b is row-major, and cusparse requires col-major b (or
|
||||
// equivalently transB == transpose). this version is actually more
|
||||
// efficient.
|
||||
const gpusparseOperation_t transB = GPUSPARSE(OPERATION_TRANSPOSE);
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
|
||||
|
||||
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
|
||||
const int k = b.dimension(0);
|
||||
DCHECK_EQ(k, a.dense_shape_host(1));
|
||||
@ -786,10 +820,87 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
// op(A) = A and at least max(1, k) otherwise.
|
||||
const int ldc = m;
|
||||
|
||||
// transA must be non-transpose if transB is transpose (cusparse
|
||||
// limitation).
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#endif
|
||||
|
||||
// transB: b is row-major, and cusparse requires col-major b (or
|
||||
// equivalently transB == transpose). this version is actually more
|
||||
// efficient.
|
||||
#if GOOGLE_CUDA && CUDA_VERSION >= 10020
|
||||
|
||||
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
|
||||
gpusparseSpMatDescr_t matA;
|
||||
gpusparseDnMatDescr_t matB, matC;
|
||||
|
||||
// NOTE: the following APIs are not available in ROCM
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
|
||||
&matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
|
||||
const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
|
||||
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
|
||||
CUDADataType<T>::type));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
|
||||
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType<T>::type,
|
||||
CUSPARSE_ORDER_COL));
|
||||
|
||||
size_t bufferSize = 0;
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
|
||||
transA, transB, &alpha, matA, matB, &beta, matC,
|
||||
CUSPARSE_MM_ALG_DEFAULT, &bufferSize));
|
||||
|
||||
Tensor buffer;
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
|
||||
DCHECK(buffer.flat<int8>().data() != nullptr);
|
||||
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
|
||||
&beta, matC, CUSPARSE_MM_ALG_DEFAULT,
|
||||
buffer.flat<int8>().data()));
|
||||
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA));
|
||||
|
||||
#else
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
|
||||
a.values.data(), a.row_ptr.data(), a.col_ind.data(),
|
||||
b.data(), ldb, &beta, c.data(), ldc));
|
||||
|
||||
#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -822,20 +933,35 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
const T alpha = 1;
|
||||
const T beta = 0;
|
||||
|
||||
#if GOOGLE_CUDA && CUDA_VERSION < 10020
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
const int m = a.dense_shape_host(0);
|
||||
const int n = a.dense_shape_host(1);
|
||||
const int nnz = a.values.size();
|
||||
DCHECK_EQ(nnz, a.col_ind.size());
|
||||
#if CUDA_VERSION >= 10020
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha,
|
||||
a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), x, &beta, y));
|
||||
#else
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
|
||||
a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), x, &beta, y));
|
||||
#endif
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -417,6 +417,36 @@ class CSRSparseMatMulGPUOp : public OpKernel {
|
||||
}
|
||||
auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
|
||||
|
||||
#if CUDA_VERSION >= 10000
|
||||
size_t maxWorkspaceSize = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Calculate maximum workspace size over batch.
|
||||
ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
|
||||
a_input_matrix->col_indices_vec(i),
|
||||
a_input_matrix->values_vec<T>(i),
|
||||
a_input_dense_shape};
|
||||
ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
|
||||
b_input_matrix->col_indices_vec(i),
|
||||
b_input_matrix->values_vec<T>(i),
|
||||
b_input_dense_shape};
|
||||
size_t thisWorkspaceSize;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, csr_gemm.GetWorkspaceSize(a_comp, b_comp, &thisWorkspaceSize));
|
||||
if (thisWorkspaceSize > maxWorkspaceSize) {
|
||||
maxWorkspaceSize = thisWorkspaceSize;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor temp;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(maxWorkspaceSize)}),
|
||||
&temp));
|
||||
void* workspace = temp.flat<int8>().data();
|
||||
#else
|
||||
void* workspace = nullptr;
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Calculate output sizes for all minibatch entries.
|
||||
// Store in c_batch_ptr and update c_row_ptrs.
|
||||
@ -433,8 +463,9 @@ class CSRSparseMatMulGPUOp : public OpKernel {
|
||||
rows + 1);
|
||||
|
||||
int c_nnz_i;
|
||||
OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
|
||||
c_row_ptr_i, &c_nnz_i));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
|
||||
&c_nnz_i, workspace));
|
||||
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
|
||||
}
|
||||
|
||||
@ -464,7 +495,7 @@ class CSRSparseMatMulGPUOp : public OpKernel {
|
||||
b_input_dense_shape};
|
||||
CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
|
||||
c.values_vec<T>(i), c_dense_shape};
|
||||
OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp));
|
||||
OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp, workspace));
|
||||
}
|
||||
|
||||
Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
||||
@ -527,7 +558,12 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
initialized_(false),
|
||||
transpose_a_(transpose_a),
|
||||
adjoint_a_(adjoint_a),
|
||||
#if CUDA_VERSION < 10000
|
||||
transpose_b_(transpose_b) {
|
||||
#else
|
||||
transpose_b_(transpose_b),
|
||||
info_(nullptr) {
|
||||
#endif // CUDA_VERSION < 10000
|
||||
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
|
||||
transA_ = transpose_a
|
||||
? (adjoint_a ? GPUSPARSE(OPERATION_TRANSPOSE)
|
||||
@ -537,6 +573,14 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
: GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 10000
|
||||
~CSRSparseSparseMatrixMatMul() {
|
||||
if (initialized_) {
|
||||
cusparseDestroyCsrgemm2Info(info_);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
Status Initialize() {
|
||||
if (adjoint_a_ && transpose_a_) {
|
||||
return errors::InvalidArgument(
|
||||
@ -547,14 +591,44 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
TF_RETURN_IF_ERROR(descrA_.Initialize());
|
||||
TF_RETURN_IF_ERROR(descrB_.Initialize());
|
||||
TF_RETURN_IF_ERROR(descrC_.Initialize());
|
||||
#if CUDA_VERSION >= 10000
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info_));
|
||||
#endif
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b, size_t* bufferSize) {
|
||||
DCHECK(initialized_);
|
||||
const int m =
|
||||
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
|
||||
if (!transpose_a_) {
|
||||
DCHECK_EQ(m, a.row_ptr.size() - 1);
|
||||
}
|
||||
const int k =
|
||||
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
|
||||
if (!transpose_b_) {
|
||||
DCHECK_EQ(k, b.row_ptr.size() - 1);
|
||||
}
|
||||
const int nnzA = a.col_ind.size();
|
||||
const int nnzB = b.col_ind.size();
|
||||
|
||||
const int n =
|
||||
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
|
||||
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
|
||||
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
|
||||
bufferSize));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||
const ConstCSRComponent<T>& b,
|
||||
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||
int* output_nnz) {
|
||||
int* output_nnz, void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
|
||||
const int m =
|
||||
@ -576,10 +650,17 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
|
||||
*output_nnz = -1;
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
|
||||
transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
|
||||
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
|
||||
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
|
||||
#else
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
|
||||
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
|
||||
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
|
||||
descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
|
||||
#endif
|
||||
|
||||
if (*output_nnz < 0) {
|
||||
return errors::Internal(
|
||||
@ -590,7 +671,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
}
|
||||
|
||||
Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
|
||||
CSRComponent<T>* c) {
|
||||
CSRComponent<T>* c, void* workspace) {
|
||||
DCHECK(initialized_);
|
||||
|
||||
const int m =
|
||||
@ -612,11 +693,19 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||
DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
|
||||
|
||||
#if CUDA_VERSION < 10000
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
|
||||
transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
|
||||
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
|
||||
b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
|
||||
c->values.data(), c->row_ptr.data(), c->col_ind.data()));
|
||||
#else
|
||||
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
|
||||
m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
|
||||
a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
|
||||
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
|
||||
c->row_ptr.data(), c->col_ind.data(), info_, workspace));
|
||||
#endif
|
||||
|
||||
// TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
|
||||
// columns are sorted? Above operation leads to unsorted columns.
|
||||
@ -643,6 +732,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
GpuSparseMatrixDescriptor descrC_;
|
||||
gpusparseOperation_t transA_;
|
||||
gpusparseOperation_t transB_;
|
||||
#if CUDA_VERSION >= 10000
|
||||
csrgemm2Info_t info_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
Loading…
Reference in New Issue
Block a user