Remove usages of cusparse gtsv*
PiperOrigin-RevId: 302470219 Change-Id: Idaa6bfaefa7f29f92525109f5170315b2d312901
This commit is contained in:
parent
06ca7fc73c
commit
9478afb61c
@ -200,66 +200,6 @@ Status GpuSparse::Initialize() {
|
|||||||
// Check the actual declarations in the cusparse.h header file.
|
// Check the actual declarations in the cusparse.h header file.
|
||||||
//=============================================================================
|
//=============================================================================
|
||||||
|
|
||||||
template <typename Scalar, typename SparseFn>
|
|
||||||
static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
|
|
||||||
int m, int n, const Scalar* dl, const Scalar* d,
|
|
||||||
const Scalar* du, Scalar* B, int ldb) {
|
|
||||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
|
||||||
AsCudaComplex(d), AsCudaComplex(du),
|
|
||||||
AsCudaComplex(B), ldb));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
#define GTSV_INSTANCE(Scalar, sparse_prefix) \
|
|
||||||
template <> \
|
|
||||||
Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
|
|
||||||
const Scalar* d, const Scalar* du, Scalar* B, \
|
|
||||||
int ldb) const { \
|
|
||||||
DCHECK(initialized_); \
|
|
||||||
return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \
|
|
||||||
dl, d, du, B, ldb); \
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);
|
|
||||||
|
|
||||||
#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
|
||||||
template <> \
|
|
||||||
Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
|
|
||||||
const Scalar* d, const Scalar* du, \
|
|
||||||
Scalar* B, int ldb) const { \
|
|
||||||
DCHECK(initialized_); \
|
|
||||||
return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \
|
|
||||||
*gpusparse_handle_, m, n, dl, d, du, B, ldb); \
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);
|
|
||||||
|
|
||||||
template <typename Scalar, typename SparseFn>
|
|
||||||
static inline Status GtsvStridedBatchImpl(SparseFn op,
|
|
||||||
cusparseHandle_t cusparse_handle,
|
|
||||||
int m, const Scalar* dl,
|
|
||||||
const Scalar* d, const Scalar* du,
|
|
||||||
Scalar* x, int batchCount,
|
|
||||||
int batchStride) {
|
|
||||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
|
||||||
AsCudaComplex(d), AsCudaComplex(du),
|
|
||||||
AsCudaComplex(x), batchCount, batchStride));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
|
|
||||||
template <> \
|
|
||||||
Status GpuSparse::GtsvStridedBatch<Scalar>( \
|
|
||||||
int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
|
|
||||||
int batchCount, int batchStride) const { \
|
|
||||||
DCHECK(initialized_); \
|
|
||||||
return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \
|
|
||||||
*gpusparse_handle_, m, dl, d, du, x, \
|
|
||||||
batchCount, batchStride); \
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_CALL_LAPACK_TYPES(GTSV_STRIDED_BATCH_INSTANCE);
|
|
||||||
|
|
||||||
template <typename Scalar, typename SparseFn>
|
template <typename Scalar, typename SparseFn>
|
||||||
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
|
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
|
||||||
int m, int n, const Scalar* dl, const Scalar* d,
|
int m, int n, const Scalar* dl, const Scalar* d,
|
||||||
|
@ -190,37 +190,6 @@ class GpuSparse {
|
|||||||
// Wrappers for cuSparse start here.
|
// Wrappers for cuSparse start here.
|
||||||
//
|
//
|
||||||
|
|
||||||
// Solves tridiagonal system of equations.
|
|
||||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2 routine. gtsv will be
|
|
||||||
// removed in Cuda Toolkit 11.0.
|
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv
|
|
||||||
// Returns Status::OK() if the kernel was launched successfully.
|
|
||||||
template <typename Scalar>
|
|
||||||
Status Gtsv(int m, int n, const Scalar *dl, const Scalar *d, const Scalar *du,
|
|
||||||
Scalar *B, int ldb) const;
|
|
||||||
|
|
||||||
// Solves tridiagonal system of equations without pivoting.
|
|
||||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2_nopivot routine.
|
|
||||||
// gtsv_nopivot will be removed in Cuda Toolkit 11.0.
|
|
||||||
// See:
|
|
||||||
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv_nopivot
|
|
||||||
// Returns Status::OK() if the kernel was launched successfully.
|
|
||||||
template <typename Scalar>
|
|
||||||
Status GtsvNoPivot(int m, int n, const Scalar *dl, const Scalar *d,
|
|
||||||
const Scalar *du, Scalar *B, int ldb) const;
|
|
||||||
|
|
||||||
// Solves a batch of tridiagonal systems of equations. Doesn't support
|
|
||||||
// multiple right-hand sides per each system. Doesn't do pivoting.
|
|
||||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2StridedBatch routine.
|
|
||||||
// gtsvStridedBatch will be removed in Cuda Toolkit 11.0.
|
|
||||||
// See:
|
|
||||||
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsvstridedbatch
|
|
||||||
// Returns Status::OK() if the kernel was launched successfully.
|
|
||||||
template <typename Scalar>
|
|
||||||
Status GtsvStridedBatch(int m, const Scalar *dl, const Scalar *d,
|
|
||||||
const Scalar *du, Scalar *x, int batchCount,
|
|
||||||
int batchStride) const;
|
|
||||||
|
|
||||||
// Solves tridiagonal system of equations.
|
// Solves tridiagonal system of equations.
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
|
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
|
@ -200,13 +200,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
|||||||
const Scalar* superdiag, const Scalar* diag,
|
const Scalar* superdiag, const Scalar* diag,
|
||||||
const Scalar* subdiag, Scalar* rhs, const int num_eqs,
|
const Scalar* subdiag, Scalar* rhs, const int num_eqs,
|
||||||
const int num_rhs) const {
|
const int num_rhs) const {
|
||||||
#if CUDA_VERSION < 9000
|
|
||||||
auto function =
|
|
||||||
pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
context, (cusparse_solver.get()->*function)(
|
|
||||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
|
|
||||||
#else
|
|
||||||
auto buffer_function = pivoting_
|
auto buffer_function = pivoting_
|
||||||
? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
|
? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
|
||||||
: &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
|
: &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
|
||||||
@ -225,7 +218,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
|||||||
OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
|
OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
|
||||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
|
num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
|
||||||
num_eqs, buffer));
|
num_eqs, buffer));
|
||||||
#endif // CUDA_VERSION < 9000
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
|
void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
|
||||||
@ -318,11 +310,7 @@ class TridiagonalSolveOpGpu : public OpKernel {
|
|||||||
std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
|
std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
|
||||||
|
|
||||||
OP_REQUIRES_OK(context, cusparse_solver->Initialize());
|
OP_REQUIRES_OK(context, cusparse_solver->Initialize());
|
||||||
#if CUDA_VERSION < 9000
|
|
||||||
OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch(
|
|
||||||
matrix_size, subdiag, diag, superdiag, x,
|
|
||||||
batch_size, matrix_size));
|
|
||||||
#else
|
|
||||||
size_t buffer_size;
|
size_t buffer_size;
|
||||||
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
|
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
|
||||||
matrix_size, subdiag, diag, superdiag, x,
|
matrix_size, subdiag, diag, superdiag, x,
|
||||||
@ -335,7 +323,6 @@ class TridiagonalSolveOpGpu : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
|
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
|
||||||
matrix_size, subdiag, diag, superdiag, x,
|
matrix_size, subdiag, diag, superdiag, x,
|
||||||
batch_size, matrix_size, buffer));
|
batch_size, matrix_size, buffer));
|
||||||
#endif // CUDA_VERSION < 9000
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
|
void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user