Add macros to avoid frequent ROCm ifdefs. Clean up the ifdefs in core/kernels/sparse.
PiperOrigin-RevId: 310004168 Change-Id: I1e48a64d6f7895c7f031397f1e231e231c4568cd
This commit is contained in:
parent
1f02731775
commit
50840b9587
@ -35,6 +35,9 @@ using gpusparseAction_t = cusparseAction_t;
|
||||
using gpusparseHandle_t = cusparseHandle_t;
|
||||
using gpuStream_t = cudaStream_t;
|
||||
|
||||
#define GPUSPARSE(postfix) CUSPARSE_##postfix
|
||||
#define gpusparse(postfix) cusparse##postfix
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "rocm/include/hipsparse/hipsparse.h"
|
||||
@ -46,6 +49,9 @@ using gpusparseAction_t = hipsparseAction_t;
|
||||
using gpusparseHandle_t = hipsparseHandle_t;
|
||||
using gpuStream_t = hipStream_t;
|
||||
|
||||
#define GPUSPARSE(postfix) HIPSPARSE_##postfix
|
||||
#define gpusparse(postfix) hipsparse##postfix
|
||||
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
@ -748,34 +748,19 @@ class CSRSparseMatrixMatMul<GPUDevice, T> {
|
||||
|
||||
// transA must be non-transpose if transB is transpose (cusparse
|
||||
// limitation).
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#endif
|
||||
const gpusparseOperation_t transA = GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
|
||||
// transB: b is row-major, and cusparse requires col-major b (or
|
||||
// equivalently transB == transpose). this version is actually more
|
||||
// efficient.
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;
|
||||
const gpusparseOperation_t transB = GPUSPARSE(OPERATION_TRANSPOSE);
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif
|
||||
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
|
||||
|
||||
// A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
|
||||
const int k = b.dimension(0);
|
||||
@ -838,19 +823,11 @@ class CSRSparseMatrixMatVec<GPUDevice, T> {
|
||||
const T beta = 0;
|
||||
|
||||
gpusparseMatDescr_t descrA;
|
||||
#if GOOGLE_CUDA
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(gpusparse(CreateMatDescr)(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
gpusparse(SetMatType)(descrA, GPUSPARSE(MATRIX_TYPE_GENERAL)));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(hipsparseCreateMatDescr(&descrA));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL));
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(
|
||||
hipsparseSetMatIndexBase(descrA, HIPSPARSE_INDEX_BASE_ZERO));
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
gpusparse(SetMatIndexBase)(descrA, GPUSPARSE(INDEX_BASE_ZERO)));
|
||||
|
||||
const int m = a.dense_shape_host(0);
|
||||
const int n = a.dense_shape_host(1);
|
||||
|
@ -529,20 +529,12 @@ struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||
adjoint_a_(adjoint_a),
|
||||
transpose_b_(transpose_b) {
|
||||
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
|
||||
#if GOOGLE_CUDA
|
||||
transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
|
||||
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
transA_ = transpose_a
|
||||
? (adjoint_a ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE)
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
transB_ = transpose_b ? HIPSPARSE_OPERATION_TRANSPOSE
|
||||
: HIPSPARSE_OPERATION_NON_TRANSPOSE;
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
? (adjoint_a ? GPUSPARSE(OPERATION_TRANSPOSE)
|
||||
: GPUSPARSE(OPERATION_CONJUGATE_TRANSPOSE))
|
||||
: GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
transB_ = transpose_b ? GPUSPARSE(OPERATION_TRANSPOSE)
|
||||
: GPUSPARSE(OPERATION_NON_TRANSPOSE);
|
||||
}
|
||||
|
||||
Status Initialize() {
|
||||
|
@ -262,11 +262,7 @@ struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
|
||||
TF_RETURN_IF_ERROR(ValidateTransposeInputs(x, *y));
|
||||
GpuSparse cuda_sparse(ctx);
|
||||
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||
#if GOOGLE_CUDA
|
||||
const gpusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
const gpusparseAction_t copyValues = HIPSPARSE_ACTION_NUMERIC;
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
const gpusparseAction_t copyValues = GPUSPARSE(ACTION_NUMERIC);
|
||||
const int rank = x.dense_shape_host.size();
|
||||
const int m = x.row_ptr.size() - 1;
|
||||
const int n = x.dense_shape_host(rank - 1);
|
||||
|
Loading…
Reference in New Issue
Block a user