Remove usages of cusparse gtsv*
PiperOrigin-RevId: 302470219 Change-Id: Idaa6bfaefa7f29f92525109f5170315b2d312901
This commit is contained in:
parent
06ca7fc73c
commit
9478afb61c
tensorflow/core/kernels
@ -200,66 +200,6 @@ Status GpuSparse::Initialize() {
|
||||
// Check the actual declarations in the cusparse.h header file.
|
||||
//=============================================================================
|
||||
|
||||
template <typename Scalar, typename SparseFn>
|
||||
static inline Status GtsvImpl(SparseFn op, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, const Scalar* dl, const Scalar* d,
|
||||
const Scalar* du, Scalar* B, int ldb) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, n, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(B), ldb));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::Gtsv<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, Scalar* B, \
|
||||
int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv, sparse_prefix), *gpusparse_handle_, m, n, \
|
||||
dl, d, du, B, ldb); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV_INSTANCE);
|
||||
|
||||
#define GTSV_NO_PIVOT_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::GtsvNoPivot<Scalar>(int m, int n, const Scalar* dl, \
|
||||
const Scalar* d, const Scalar* du, \
|
||||
Scalar* B, int ldb) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvImpl(SPARSE_FN(gtsv_nopivot, sparse_prefix), \
|
||||
*gpusparse_handle_, m, n, dl, d, du, B, ldb); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV_NO_PIVOT_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SparseFn>
|
||||
static inline Status GtsvStridedBatchImpl(SparseFn op,
|
||||
cusparseHandle_t cusparse_handle,
|
||||
int m, const Scalar* dl,
|
||||
const Scalar* d, const Scalar* du,
|
||||
Scalar* x, int batchCount,
|
||||
int batchStride) {
|
||||
TF_RETURN_IF_GPUSPARSE_ERROR(op(cusparse_handle, m, AsCudaComplex(dl),
|
||||
AsCudaComplex(d), AsCudaComplex(du),
|
||||
AsCudaComplex(x), batchCount, batchStride));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define GTSV_STRIDED_BATCH_INSTANCE(Scalar, sparse_prefix) \
|
||||
template <> \
|
||||
Status GpuSparse::GtsvStridedBatch<Scalar>( \
|
||||
int m, const Scalar* dl, const Scalar* d, const Scalar* du, Scalar* x, \
|
||||
int batchCount, int batchStride) const { \
|
||||
DCHECK(initialized_); \
|
||||
return GtsvStridedBatchImpl(SPARSE_FN(gtsvStridedBatch, sparse_prefix), \
|
||||
*gpusparse_handle_, m, dl, d, du, x, \
|
||||
batchCount, batchStride); \
|
||||
}
|
||||
|
||||
TF_CALL_LAPACK_TYPES(GTSV_STRIDED_BATCH_INSTANCE);
|
||||
|
||||
template <typename Scalar, typename SparseFn>
|
||||
static inline Status Gtsv2Impl(SparseFn op, cusparseHandle_t cusparse_handle,
|
||||
int m, int n, const Scalar* dl, const Scalar* d,
|
||||
|
@ -190,37 +190,6 @@ class GpuSparse {
|
||||
// Wrappers for cuSparse start here.
|
||||
//
|
||||
|
||||
// Solves tridiagonal system of equations.
|
||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2 routine. gtsv will be
|
||||
// removed in Cuda Toolkit 11.0.
|
||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv
|
||||
// Returns Status::OK() if the kernel was launched successfully.
|
||||
template <typename Scalar>
|
||||
Status Gtsv(int m, int n, const Scalar *dl, const Scalar *d, const Scalar *du,
|
||||
Scalar *B, int ldb) const;
|
||||
|
||||
// Solves tridiagonal system of equations without pivoting.
|
||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2_nopivot routine.
|
||||
// gtsv_nopivot will be removed in Cuda Toolkit 11.0.
|
||||
// See:
|
||||
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsv_nopivot
|
||||
// Returns Status::OK() if the kernel was launched successfully.
|
||||
template <typename Scalar>
|
||||
Status GtsvNoPivot(int m, int n, const Scalar *dl, const Scalar *d,
|
||||
const Scalar *du, Scalar *B, int ldb) const;
|
||||
|
||||
// Solves a batch of tridiagonal systems of equations. Doesn't support
|
||||
// multiple right-hand sides per each system. Doesn't do pivoting.
|
||||
// Note: Cuda Toolkit 9.0+ has better-performing gtsv2StridedBatch routine.
|
||||
// gtsvStridedBatch will be removed in Cuda Toolkit 11.0.
|
||||
// See:
|
||||
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-gtsvstridedbatch
|
||||
// Returns Status::OK() if the kernel was launched successfully.
|
||||
template <typename Scalar>
|
||||
Status GtsvStridedBatch(int m, const Scalar *dl, const Scalar *d,
|
||||
const Scalar *du, Scalar *x, int batchCount,
|
||||
int batchStride) const;
|
||||
|
||||
// Solves tridiagonal system of equations.
|
||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
|
||||
template <typename Scalar>
|
||||
|
@ -200,13 +200,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
||||
const Scalar* superdiag, const Scalar* diag,
|
||||
const Scalar* subdiag, Scalar* rhs, const int num_eqs,
|
||||
const int num_rhs) const {
|
||||
#if CUDA_VERSION < 9000
|
||||
auto function =
|
||||
pivoting_ ? &GpuSparse::Gtsv<Scalar> : &GpuSparse::GtsvNoPivot<Scalar>;
|
||||
OP_REQUIRES_OK(
|
||||
context, (cusparse_solver.get()->*function)(
|
||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs, num_eqs));
|
||||
#else
|
||||
auto buffer_function = pivoting_
|
||||
? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
|
||||
: &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
|
||||
@ -225,7 +218,6 @@ class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
|
||||
OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
|
||||
num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
|
||||
num_eqs, buffer));
|
||||
#endif // CUDA_VERSION < 9000
|
||||
}
|
||||
|
||||
void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
|
||||
@ -318,11 +310,7 @@ class TridiagonalSolveOpGpu : public OpKernel {
|
||||
std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
|
||||
|
||||
OP_REQUIRES_OK(context, cusparse_solver->Initialize());
|
||||
#if CUDA_VERSION < 9000
|
||||
OP_REQUIRES_OK(context, cusparse_solver->GtsvStridedBatch(
|
||||
matrix_size, subdiag, diag, superdiag, x,
|
||||
batch_size, matrix_size));
|
||||
#else
|
||||
|
||||
size_t buffer_size;
|
||||
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
|
||||
matrix_size, subdiag, diag, superdiag, x,
|
||||
@ -335,7 +323,6 @@ class TridiagonalSolveOpGpu : public OpKernel {
|
||||
OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
|
||||
matrix_size, subdiag, diag, superdiag, x,
|
||||
batch_size, matrix_size, buffer));
|
||||
#endif // CUDA_VERSION < 9000
|
||||
}
|
||||
|
||||
void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
|
||||
|
Loading…
Reference in New Issue
Block a user