Cleanup formatting with clang-format
This commit is contained in:
parent 7e07d1fe65
commit 54e5bedaaf
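This is a formatting-only change; no functional edits are intended. For illustration, a typical clang-format invocation that produces this kind of cleanup looks like the sketch below. The file paths and the style flag are assumptions for the example, not details recorded in this commit:

    # Hypothetical invocation; the actual files and style settings are not shown on this page.
    clang-format -i --style=file path/to/modified_file.cc path/to/modified_file.h

--style=file picks up a .clang-format configuration from the source tree if one is present; --style=google is the usual fallback for Google-style C++ code.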
@@ -23,8 +23,6 @@ limitations under the License.
#include <utility>
#include <vector>

#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -38,6 +36,8 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/library_types.h"

// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
@@ -180,11 +180,9 @@ Status GpuSparse::Initialize() {
return Status::OK();
}

#define TF_CALL_CUSPARSE_DTYPES(m) \
m(float, CUDA_R_32F) \
m(double, CUDA_R_64F) \
m(std::complex<float>, CUDA_C_32F) \
m(std::complex<double>, CUDA_C_64F)
#define TF_CALL_CUSPARSE_DTYPES(m) \
m(float, CUDA_R_32F) m(double, CUDA_R_64F) \
m(std::complex<float>, CUDA_C_32F) m(std::complex<double>, CUDA_C_64F)

// Macro that specializes a sparse method for all 4 standard
// numeric types.
@@ -366,15 +364,12 @@ Status GpuSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
return Status::OK();
}

Status GpuSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
int nnzA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB,
const int* csrSortedColIndB,
const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr,
void* workspace) {
Status GpuSparse::CsrgeamNnz(
int m, int n, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);
#if CUDA_VERSION >= 10000
@@ -435,36 +430,35 @@ TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

#else

#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
template<> \
Status GpuSparse::SpMMBufferSize<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, \
size_t* bufferSize) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, \
matC, dtype, alg, bufferSize)); \
return Status::OK(); \
#define SPMM_BUFFERSIZE_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMMBufferSize<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, size_t* bufferSize) \
const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM_bufferSize( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, matC, \
dtype, alg, bufferSize)); \
return Status::OK(); \
}

TF_CALL_CUSPARSE_DTYPES(SPMM_BUFFERSIZE_INSTANCE);

#define SPMM_INSTANCE(Scalar, dtype) \
template<> \
Status GpuSparse::SpMM<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, \
int8* buffer) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM( \
*gpusparse_handle_, transA, transB, alpha, matA, matB, beta, \
matC, dtype, alg, buffer)); \
return Status::OK(); \
#define SPMM_INSTANCE(Scalar, dtype) \
template <> \
Status GpuSparse::SpMM<Scalar>( \
cusparseOperation_t transA, cusparseOperation_t transB, \
const Scalar* alpha, const cusparseSpMatDescr_t matA, \
const gpusparseDnMatDescr_t matB, const Scalar* beta, \
gpusparseDnMatDescr_t matC, cusparseSpMMAlg_t alg, int8* buffer) const { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMM(*gpusparse_handle_, transA, \
transB, alpha, matA, matB, beta, \
matC, dtype, alg, buffer)); \
return Status::OK(); \
}

TF_CALL_CUSPARSE_DTYPES(SPMM_INSTANCE);
@@ -515,11 +509,14 @@ TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);
#else

template <typename Scalar>
static inline Status CsrmvExImpl(
cudaDataType_t dtype, OpKernelContext* context, cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* x, const Scalar* beta_host, Scalar* y) {
static inline Status CsrmvExImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n,
int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseMatDescr_t descrA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
TF_RETURN_IF_GPUSPARSE_ERROR(
@@ -531,10 +528,9 @@ static inline Status CsrmvExImpl(

size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx_bufferSize(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz,
alpha_host, dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA,
csrSortedColIndA, x, dtype, beta_host, dtype, y, dtype, dtype,
&bufferSize));
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, &bufferSize));

Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -543,40 +539,40 @@ static inline Status CsrmvExImpl(
DCHECK(pBuffer.data() != nullptr);

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsrmvEx(
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz,
alpha_host, dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA,
csrSortedColIndA, x, dtype, beta_host, dtype, y, dtype, dtype,
pBuffer.data()));
cusparse_handle, CUSPARSE_ALG_MERGE_PATH, transA, m, n, nnz, alpha_host,
dtype, descrA, csrSortedValA, dtype, csrSortedRowPtrA, csrSortedColIndA,
x, dtype, beta_host, dtype, y, dtype, dtype, pBuffer.data()));

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyMatDescr(descrA));
return Status::OK();
}
template <typename Scalar>
static inline Status SpMVImpl(
cudaDataType_t dtype, OpKernelContext* context, cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const Scalar* x, const Scalar* beta_host, Scalar* y) {
static inline Status SpMVImpl(cudaDataType_t dtype, OpKernelContext* context,
cusparseHandle_t cusparse_handle,
cusparseOperation_t transA, int m, int n, int nnz,
const Scalar* alpha_host,
const Scalar* csrSortedValA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* x,
const Scalar* beta_host, Scalar* y) {
cusparseSpMatDescr_t matA;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(&matA, m, n, nnz,
const_cast<int*>(csrSortedRowPtrA),
const_cast<int*>(csrSortedColIndA),
const_cast<Scalar*>(csrSortedValA),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, n, nnz, const_cast<int*>(csrSortedRowPtrA),
const_cast<int*>(csrSortedColIndA), const_cast<Scalar*>(csrSortedValA),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, dtype));

cusparseDnVecDescr_t vecX, vecY;
int sizeX = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? n : m;
int sizeY = (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) ? m : n;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(
&vecX, sizeX, const_cast<Scalar*>(x), dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnVec(&vecX, sizeX, const_cast<Scalar*>(x), dtype));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnVec(&vecY, sizeY, y, dtype));

size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV_bufferSize(
cusparse_handle, transA, alpha_host, matA, vecX,
beta_host, vecY, dtype, CUSPARSE_CSRMV_ALG1, &bufferSize));
cusparse_handle, transA, alpha_host, matA, vecX, beta_host, vecY, dtype,
CUSPARSE_CSRMV_ALG1, &bufferSize));

Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -584,9 +580,9 @@ static inline Status SpMVImpl(
auto pBuffer = buffer.flat<int8>();
DCHECK(pBuffer.data() != nullptr);

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpMV(
cusparse_handle, transA, alpha_host, matA, vecX,
beta_host, vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseSpMV(cusparse_handle, transA, alpha_host, matA, vecX, beta_host,
vecY, dtype, CUSPARSE_CSRMV_ALG1, pBuffer.data()));

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecY));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnVec(vecX));
@@ -594,25 +590,23 @@ static inline Status SpMVImpl(
return Status::OK();
}
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
cusparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
DCHECK(initialized_); \
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, \
transA, m, n, nnz, alpha_host, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
x, beta_host, y); \
} else { \
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, \
transA, m, n, nnz, alpha_host, \
csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, \
x, beta_host, y); \
} \
#define CSRMV_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csrmv<Scalar>( \
cusparseOperation_t transA, int m, int n, int nnz, \
const Scalar* alpha_host, const Scalar* csrSortedValA, \
const int* csrSortedRowPtrA, const int* csrSortedColIndA, \
const Scalar* x, const Scalar* beta_host, Scalar* y) const { \
DCHECK(initialized_); \
if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) { \
return CsrmvExImpl(cudaDataType, context_, *gpusparse_handle_, transA, \
m, n, nnz, alpha_host, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, x, beta_host, y); \
} else { \
return SpMVImpl(cudaDataType, context_, *gpusparse_handle_, transA, m, \
n, nnz, alpha_host, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, x, beta_host, y); \
} \
}

TF_CALL_CUSPARSE_DTYPES(CSRMV_INSTANCE);
@@ -671,17 +665,16 @@ static inline Status Csrgeam2Impl(
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC,
AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
workspace));
TF_RETURN_IF_GPUSPARSE_ERROR(op(
cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, workspace));
return Status::OK();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
#define CSRGEAM_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::Csrgeam<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
@@ -733,34 +726,32 @@ static inline Status CsrgeamBufferSizeExtImpl(
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC,
AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
bufferSize));
TF_RETURN_IF_GPUSPARSE_ERROR(op(
sparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
csrSortedRowPtrB, csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
csrSortedRowPtrC, csrSortedColIndC, bufferSize));
return Status::OK();
}
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
return CsrgeamBufferSizeExtImpl(SPARSE_FN(csrgeam2_bufferSizeExt, \
sparse_prefix), context_, *gpusparse_handle_, m, n, \
alpha, descrA, nnzA, csrSortedValA, csrSortedRowPtrA, \
csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, \
csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, \
bufferSize); \
#define CSRGEAM_BUFFERSIZE_INSTANCE(Scalar, sparse_prefix) \
template <> \
Status GpuSparse::CsrgeamBufferSizeExt<Scalar>( \
int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA, \
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, \
const int* csrSortedColIndA, const Scalar* beta, \
const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
const int* csrSortedRowPtrB, const int* csrSortedColIndB, \
const cusparseMatDescr_t descrC, Scalar* csrSortedValC, \
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize) { \
DCHECK(initialized_); \
return CsrgeamBufferSizeExtImpl( \
SPARSE_FN(csrgeam2_bufferSizeExt, sparse_prefix), context_, \
*gpusparse_handle_, m, n, alpha, descrA, nnzA, csrSortedValA, \
csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, csrSortedValB, \
csrSortedRowPtrB, csrSortedColIndB, descrC, csrSortedValC, \
csrSortedRowPtrC, csrSortedColIndC, bufferSize); \
}

#endif
@@ -827,13 +818,13 @@ TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

#else

template<typename T>
template <typename T>
static const T* one_ptr() {
static const T one = static_cast<T>(1);
return &one;
}

template<typename T>
template <typename T>
static const T* null_ptr() {
return nullptr;
}
@@ -847,55 +838,53 @@ static const T* null_ptr() {
const int* csrSortedColIndB, csrgemm2Info_t info, \
size_t* workspaceBytes) { \
DCHECK(initialized_); \
TF_RETURN_IF_GPUSPARSE_ERROR( \
SPARSE_FN(csrgemm2_bufferSizeExt, sparse_prefix)( \
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), \
descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
csrSortedRowPtrB, csrSortedColIndB, \
AsCudaComplex(null_ptr<Scalar>()), descrA, 0, null_ptr<int>(), \
null_ptr<int>(), info, workspaceBytes)); \
TF_RETURN_IF_GPUSPARSE_ERROR(SPARSE_FN(csrgemm2_bufferSizeExt, \
sparse_prefix)( \
*gpusparse_handle_, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA, \
nnzA, csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, \
csrSortedRowPtrB, csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), \
descrA, 0, null_ptr<int>(), null_ptr<int>(), info, workspaceBytes)); \
return Status::OK(); \
}

TF_CALL_LAPACK_TYPES(CSRGEMM_BUFFERSIZE_INSTANCE);

Status GpuSparse::CsrgemmNnz(
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const cusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
int* nnzTotalDevHostPtr, csrgemm2Info_t info, void* workspace) {
const cusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
int* csrSortedRowPtrC, int* nnzTotalDevHostPtr, csrgemm2Info_t info,
void* workspace) {
DCHECK(initialized_);
DCHECK(nnzTotalDevHostPtr != nullptr);

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseXcsrgemm2Nnz(
*gpusparse_handle_, m, n, k, descrA, nnzA, csrSortedRowPtrA,
csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB,
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC,
csrSortedRowPtrC, nnzTotalDevHostPtr, info, workspace));
descrA, 0, null_ptr<int>(), null_ptr<int>(), descrC, csrSortedRowPtrC,
nnzTotalDevHostPtr, info, workspace));
return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
int m, int n, int k, const cusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
int m, int n, int k, const cusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC,
const csrgemm2Info_t info, void* workspace) {
TF_RETURN_IF_GPUSPARSE_ERROR(
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()),
descrA, nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB, csrSortedColIndB,
AsCudaComplex(null_ptr<Scalar>()),
descrA, 0, AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC,
info, workspace));
op(cusparse_handle, m, n, k, AsCudaComplex(one_ptr<Scalar>()), descrA,
nnzA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
csrSortedColIndB, AsCudaComplex(null_ptr<Scalar>()), descrA, 0,
AsCudaComplex(null_ptr<Scalar>()), null_ptr<int>(), null_ptr<int>(),
descrC, AsCudaComplex(csrSortedValC), csrSortedRowPtrC,
csrSortedColIndC, info, workspace));
return Status::OK();
}
@@ -1008,11 +997,9 @@ static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,
const cusparseAction_t copyValues) {
size_t bufferSize;
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2_bufferSize(
cusparse_handle, m, n, nnz,
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd,
dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2, &bufferSize));
cusparse_handle, m, n, nnz, AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd, dtype, copyValues,
CUSPARSE_INDEX_BASE_ZERO, CUSPARSE_CSR2CSC_ALG2, &bufferSize));

Tensor buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
@@ -1021,27 +1008,25 @@ static inline Status Csr2cscImpl(cudaDataType_t dtype, OpKernelContext* context,

DCHECK(buffer.flat<Scalar>().data() != nullptr);

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCsr2cscEx2(
cusparse_handle, m, n, nnz,
AsCudaComplex(csrVal), csrRowPtr, csrColInd,
AsCudaComplex(cscVal), cscColPtr, cscRowInd,
dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2,
buffer.flat<Scalar>().data()));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCsr2cscEx2(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
csrRowPtr, csrColInd, AsCudaComplex(cscVal), cscColPtr,
cscRowInd, dtype, copyValues, CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG2, buffer.flat<Scalar>().data()));

return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const cusparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(cudaDataType, context_, \
*gpusparse_handle_, m, n, nnz, csrVal, csrRowPtr, \
csrColInd, cscVal, cscRowInd, cscColPtr, copyValues); \
#define CSR2CSC_INSTANCE(Scalar, cudaDataType) \
template <> \
Status GpuSparse::Csr2csc<Scalar>( \
int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr, \
const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr, \
const cusparseAction_t copyValues) { \
DCHECK(initialized_); \
return Csr2cscImpl(cudaDataType, context_, *gpusparse_handle_, m, n, nnz, \
csrVal, csrRowPtr, csrColInd, cscVal, cscRowInd, \
cscColPtr, copyValues); \
}

TF_CALL_CUSPARSE_DTYPES(CSR2CSC_INSTANCE);
@@ -26,8 +26,8 @@ limitations under the License.

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
@@ -253,7 +253,6 @@ class GpuSparse {
// http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

#if CUDA_VERSION < 10020
// Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense tall
@@ -275,13 +274,14 @@ class GpuSparse {
const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
int ldc) const;
#else
// Workspace size query for sparse-dense matrix multiplication. Helper function
// for SpMM which computes y = alpha * op(A) * op(B) + beta * C, where A is
// a sparse matrix in CSR format, B and C are dense matricies in column-major
// format. Returns needed workspace size in bytes.
// Workspace size query for sparse-dense matrix multiplication. Helper
// function for SpMM which computes y = alpha * op(A) * op(B) + beta * C,
// where A is a sparse matrix in CSR format, B and C are dense matricies in
// column-major format. Returns needed workspace size in bytes.
template <typename Scalar>
Status SpMMBufferSize(gpusparseOperation_t transA, gpusparseOperation_t transB,
const Scalar* alpha, const gpusparseSpMatDescr_t matA,
Status SpMMBufferSize(gpusparseOperation_t transA,
gpusparseOperation_t transB, const Scalar* alpha,
const gpusparseSpMatDescr_t matA,
const gpusparseDnMatDescr_t matB, const Scalar* beta,
gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
size_t* bufferSize) const;
@@ -323,15 +323,14 @@ class GpuSparse {
// Computes workspace size for sparse - sparse matrix addition of matrices
// stored in CSR format.
template <typename Scalar>
Status CsrgeamBufferSizeExt(int m, int n, const Scalar* alpha,
const gpusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const gpusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, size_t* bufferSize);
Status CsrgeamBufferSizeExt(
int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const Scalar* beta,
const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);

// Computes sparse-sparse matrix addition of matrices
// stored in CSR format. This is part one: calculate nnz of the
@@ -365,15 +364,12 @@ class GpuSparse {
// Computes sparse-sparse matrix multiplication of matrices
// stored in CSR format. This is part zero: calculate required workspace
// size.
template<typename Scalar>
Status CsrgemmBufferSize(int m, int n, int k,
const gpusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA,
const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const int* csrSortedRowPtrB,
const int* csrSortedColIndB,
csrgemm2Info_t info, size_t* workspaceBytes);
template <typename Scalar>
Status CsrgemmBufferSize(
int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif

// Computes sparse-sparse matrix multiplication of matrices
@@ -419,13 +415,12 @@ class GpuSparse {
int* csrSortedColIndC);
#else
template <typename Scalar>
Status Csrgemm(int m, int n, int k,
const gpusparseMatDescr_t descrA, int nnzA,
const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
const int* csrSortedColIndA, const gpusparseMatDescr_t descrB,
int nnzB, const Scalar* csrSortedValB,
const int* csrSortedRowPtrB, const int* csrSortedColIndB,
const gpusparseMatDescr_t descrC,
Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
int nnzA, const Scalar* csrSortedValA,
const int* csrSortedRowPtrA, const int* csrSortedColIndA,
const gpusparseMatDescr_t descrB, int nnzB,
const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
Scalar* csrSortedValC, int* csrSortedRowPtrC,
int* csrSortedColIndC, const csrgemm2Info_t info,
void* workspace);
@@ -108,7 +108,7 @@ class CSRSparseMatrixAddFunctor {
set_zero(d, c_row_ptr_t.flat<int32>());

size_t maxWorkspaceSize = 0;
for (int i=0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
a.values_vec<T>(i), a_dense_shape};
ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
@@ -140,9 +140,8 @@ class CSRSparseMatrixAddFunctor {
TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
rows + 1);
int c_nnz_i;
TF_RETURN_IF_ERROR(
csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i,
workspace));
TF_RETURN_IF_ERROR(csr_geam.GetOutputStructure(
a_comp, b_comp, c_row_ptr_i, &c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}

@@ -290,8 +289,7 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
}

Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
size_t* bufferSize) {
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);

const int m = a.row_ptr.size() - 1;
@@ -310,13 +308,12 @@ struct CSRSparseMatrixAdd<GPUDevice, T>
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamBufferSizeExt(
m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T,
null_int, null_int, bufferSize));
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), null_T, null_int,
null_int, bufferSize));

return Status::OK();
}

Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
@@ -174,8 +174,7 @@ struct CSRStructureModifyingFunctor {
virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
TTypes<int32>::UnalignedVec c_row_ptr,
int* output_nnz,
void* workspace) = 0;
int* output_nnz, void* workspace) = 0;

virtual Status Compute(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b, CSRComponent<T>* c,
@@ -838,35 +838,33 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
gpusparseDnMatDescr_t matB, matC;

// NOTE: the following APIs are not available in ROCM
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(&matA, m, k, nnz,
const_cast<int*>(a.row_ptr.data()),
const_cast<int*>(a.col_ind.data()),
const_cast<T*>(a.values.data()),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDADataType<T>::type));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr(
&matA, m, k, nnz, const_cast<int*>(a.row_ptr.data()),
const_cast<int*>(a.col_ind.data()), const_cast<T*>(a.values.data()),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO,
CUDADataType<T>::type));

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(&matB, n, k, ldb,
const_cast<T*>(b.data()),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matB, n, k, ldb, const_cast<T*>(b.data()),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat(&matC, m, n, ldc, c.data(),
CUDADataType<T>::type, CUSPARSE_ORDER_COL));
TF_RETURN_IF_GPUSPARSE_ERROR(
cusparseCreateDnMat(&matC, m, n, ldc, c.data(), CUDADataType<T>::type,
CUSPARSE_ORDER_COL));

size_t bufferSize = 0;
TF_RETURN_IF_ERROR(
cuda_sparse.SpMMBufferSize(transA, transB, &alpha, matA, matB, &beta,
matC, CUSPARSE_MM_ALG_DEFAULT,
&bufferSize));
TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
transA, transB, &alpha, matA, matB, &beta, matC,
CUSPARSE_MM_ALG_DEFAULT, &bufferSize));

Tensor buffer;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(bufferSize)}), &buffer));
DCHECK(buffer.flat<int8>().data() != nullptr);

TF_RETURN_IF_ERROR(
cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, &beta,
matC, CUSPARSE_MM_ALG_DEFAULT,
buffer.flat<int8>().data()));
TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
&beta, matC, CUSPARSE_MM_ALG_DEFAULT,
buffer.flat<int8>().data()));

TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB));
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC));
@@ -897,14 +895,12 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
#endif // GOOGLE_CUDA

TF_RETURN_IF_ERROR(
cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
a.values.data(), a.row_ptr.data(), a.col_ind.data(),
b.data(), ldb, &beta, c.data(), ldc));

#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020

}

return Status::OK();
@@ -462,9 +462,9 @@ class CSRSparseMatMulGPUOp : public OpKernel {
rows + 1);

int c_nnz_i;
OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
c_row_ptr_i, &c_nnz_i,
workspace));
OP_REQUIRES_OK(ctx,
csr_gemm.GetOutputStructure(a_comp, b_comp, c_row_ptr_i,
&c_nnz_i, workspace));
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
}

@@ -604,8 +604,7 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
}

Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
const ConstCSRComponent<T>& b,
size_t* bufferSize) {
const ConstCSRComponent<T>& b, size_t* bufferSize) {
DCHECK(initialized_);
const int m =
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
@@ -624,9 +623,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));

TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmBufferSize<T>(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
b.col_ind.data(), info_, bufferSize));
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(), info_,
bufferSize));

return Status::OK();
}
@@ -663,10 +662,9 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz,
info_, workspace));
m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c_row_ptr.data(), output_nnz, info_, workspace));
#endif

if (*output_nnz < 0) {
@@ -708,11 +706,10 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
c->values.data(), c->row_ptr.data(), c->col_ind.data()));
#else
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
m, n, k, descrA_.descr(), nnzA, a.values.data(),
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
b.values.data(), b.row_ptr.data(), b.col_ind.data(),
descrC_.descr(), c->values.data(), c->row_ptr.data(), c->col_ind.data(),
info_, workspace));
m, n, k, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
a.col_ind.data(), descrB_.descr(), nnzB, b.values.data(),
b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
c->row_ptr.data(), c->col_ind.data(), info_, workspace));
#endif

// TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix