Enable fast GPU code path for solving many small linear systems in matrix_solve.
PiperOrigin-RevId: 228382682
This commit is contained in: parent a3d639a5a4, commit 495b3eeef0
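In brief (as read from the diff below): GetrsBatchedImpl and CudaSolver::GetrsBatched now report the cuBLAS status through a host-side int* host_lapack_info instead of a DeviceLapackInfo buffer, and MatrixSolveOpGpu replaces the previously disabled if (false && n / batch_size <= 128) branch with a use_batched_solver flag derived from kMaxMatrixSizeToBatchSizeRatio, so small matrices in large batches are solved through cuBLAS getrsBatched, with the returned host info checked for illegal arguments.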
@@ -692,8 +692,8 @@ static inline Status GetrsBatchedImpl(
     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
-    const Scalar* const host_b_dev_ptrs[], int ldb,
-    DeviceLapackInfo* dev_lapack_info, int batch_size) {
+    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
+    int batch_size) {
   mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
@@ -714,7 +714,7 @@ static inline Status GetrsBatchedImpl(
       cublas_handle, trans, n, nrhs,
       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
-      ldb, dev_lapack_info->mutable_data(), batch_size));
+      ldb, host_lapack_info, batch_size));
   return Status::OK();
 }
 
@@ -723,13 +723,13 @@ static inline Status GetrsBatchedImpl(
   Status CudaSolver::GetrsBatched( \
       cublasOperation_t trans, int n, int nrhs, \
       const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \
-      const Scalar* const host_b_dev_ptrs[], int ldb, \
-      DeviceLapackInfo* dev_lapack_info, int batch_size) { \
+      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \
+      int batch_size) { \
     return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \
                                 BLAS_SOLVER_FN(getrsBatched, type_prefix)), \
                             this, context_, cublas_handle_, trans, n, nrhs, \
                             host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
-                            ldb, dev_lapack_info, batch_size); \
+                            ldb, host_lapack_info, batch_size); \
   }
 
 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
@@ -235,13 +235,14 @@ class CudaSolver {
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Batched linear solver using LU factorization from getrfBatched.
-  // See:
+  // Notice that lapack_info is returned on the host, as opposed to
+  // most of the other functions that return it on the device. See:
   // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
   template <typename Scalar>
   Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                       const Scalar* const dev_Aarray[], int lda,
                       const int* devIpiv, const Scalar* const dev_Barray[],
-                      int ldb, DeviceLapackInfo* dev_lapack_info,
+                      int ldb, int* host_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Computes matrix inverses for a batch of small matrices. Uses the outputs
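Note (not part of the commit): the new comment above points at a cuBLAS asymmetry that motivates the signature change. cublas<t>getrfBatched writes its per-matrix info array to device memory, while cublas<t>getrsBatched reports a single status through a host pointer, which is why GetrsBatched now takes int* host_lapack_info. The following is a minimal, self-contained sketch of that raw cuBLAS calling pattern in double precision; it is illustrative only (all names are made up, it is not TensorFlow code, and CUDA/cuBLAS error checking is mostly omitted).

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
  const int n = 4, nrhs = 1, batch = 3;

  // Host data: `batch` diagonal, well-conditioned matrices and all-ones RHS.
  std::vector<double> hA(batch * n * n, 0.0), hB(batch * n * nrhs, 1.0);
  for (int b = 0; b < batch; ++b)
    for (int i = 0; i < n; ++i) hA[b * n * n + i * n + i] = 2.0 + b;

  // Device copies of the matrices and right-hand sides.
  double *dA, *dB;
  cudaMalloc(&dA, hA.size() * sizeof(double));
  cudaMalloc(&dB, hB.size() * sizeof(double));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(double), cudaMemcpyHostToDevice);

  // cuBLAS batched routines take arrays of device pointers that themselves
  // live in device memory.
  std::vector<double*> hAptrs(batch), hBptrs(batch);
  for (int b = 0; b < batch; ++b) {
    hAptrs[b] = dA + b * n * n;
    hBptrs[b] = dB + b * n * nrhs;
  }
  double **dAptrs, **dBptrs;
  cudaMalloc(&dAptrs, batch * sizeof(double*));
  cudaMalloc(&dBptrs, batch * sizeof(double*));
  cudaMemcpy(dAptrs, hAptrs.data(), batch * sizeof(double*), cudaMemcpyHostToDevice);
  cudaMemcpy(dBptrs, hBptrs.data(), batch * sizeof(double*), cudaMemcpyHostToDevice);

  int *d_pivots, *d_getrf_info;  // getrfBatched reports its info array on the DEVICE.
  cudaMalloc(&d_pivots, batch * n * sizeof(int));
  cudaMalloc(&d_getrf_info, batch * sizeof(int));

  cublasHandle_t handle;
  cublasCreate(&handle);

  // LU-factorize all matrices in one call.
  cublasDgetrfBatched(handle, n, dAptrs, n, d_pivots, d_getrf_info, batch);

  // Solve all systems in one call; the info argument is a HOST pointer here.
  int host_info = 0;
  cublasDgetrsBatched(handle, CUBLAS_OP_N, n, nrhs,
                      dAptrs, n, d_pivots, dBptrs, n, &host_info, batch);
  if (host_info != 0)
    std::printf("argument %d to getrsBatched was illegal\n", -host_info);

  cudaMemcpy(hB.data(), dB, hB.size() * sizeof(double), cudaMemcpyDeviceToHost);
  std::printf("x[0] of batch 0 = %f (expect 0.5)\n", hB[0]);

  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dAptrs); cudaFree(dBptrs);
  cudaFree(d_pivots); cudaFree(d_getrf_info);
  return 0;
}

Building the sketch requires linking against cuBLAS and the CUDA runtime (for example, nvcc sketch.cu -lcublas).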
@@ -214,9 +214,12 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
         sizeof(Scalar*) * batch_size, "input_copt_ptrs",
         /* on_host */ true);
-    if (n / batch_size <= 128) {
-      // For small matrices or large batch sizes, we use the batched
-      // interface from cuBlas.
+    const int kMaxMatrixSizeToBatchSizeRatio = 128;
+    const bool use_batched_solver =
+        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+    if (use_batched_solver) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuBlas.
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
       for (int batch = 0; batch < batch_size; ++batch) {
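To illustrate the threshold above with concrete numbers (example values, not from the commit): with kMaxMatrixSizeToBatchSizeRatio = 128, a batch of 1000 matrices of size 4 x 4 satisfies 4 <= 128 * 1000, so the batched cuBLAS interface is used, whereas a single 2048 x 2048 matrix fails 2048 <= 128 * 1 and falls through to the per-matrix cuSolver path in the else branch of the next hunk.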
@@ -230,8 +233,8 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                                &dev_info.back(), batch_size),
           done);
     } else {
-      // For small batch sizes we use the non-batched interface from cuSolver,
-      // which is much faster for large matrices.
+      // For small batch sizes or large matrices, we use the non-batched
+      // interface from cuSolver, which is much faster for large matrices.
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
       for (int batch = 0; batch < batch_size; ++batch) {
         OP_REQUIRES_OK_ASYNC(
@@ -279,11 +282,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         /* on_host */ true);
     auto transposed_rhs_reshaped =
         transposed_rhs.template flat_inner_dims<Scalar, 3>();
-    // TODO(rmlarsen): Enable the following branch when I figure
-    // out why it causes a segfault.
-    if (false && n / batch_size <= 128) {
-      dev_info.push_back(
-          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
+    if (use_batched_solver) {
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
       const Scalar** transposed_rhs_ptrs_base =
@@ -293,13 +292,20 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
         transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
       }
+      int host_info = 0;
       OP_REQUIRES_OK_ASYNC(
           context,
           solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                                input_copy_ptrs_base, n, pivots_mat.data(),
-                               transposed_rhs_ptrs_base, n, &dev_info.back(),
+                               transposed_rhs_ptrs_base, n, &host_info,
                                batch_size),
           done);
+      OP_REQUIRES_ASYNC(
+          context, host_info == 0,
+          errors::InvalidArgument("The ", -host_info,
+                                  "'th argument to cublas*getrsBatched had "
+                                  "an illegal value."),
+          done);
     } else {
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
       for (int batch = 0; batch < batch_size; ++batch) {
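The host_info check added in this hunk relies on the standard LAPACK/cuBLAS getrs convention: a returned value of 0 means success, while a negative value -i indicates that the i-th argument of the underlying cublas*getrsBatched call was illegal, which is why the error message reports -host_info.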