Enable fast GPU code path for solving many small linear systems in matrix_solve.

PiperOrigin-RevId: 228382682
Authored by A. Unique TensorFlower on 2019-01-08 12:55:56 -08:00; committed by TensorFlower Gardener
parent a3d639a5a4
commit 495b3eeef0
3 changed files with 26 additions and 19 deletions


@@ -692,8 +692,8 @@ static inline Status GetrsBatchedImpl(
     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
-    const Scalar* const host_b_dev_ptrs[], int ldb,
-    DeviceLapackInfo* dev_lapack_info, int batch_size) {
+    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
+    int batch_size) {
   mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
@@ -714,7 +714,7 @@ static inline Status GetrsBatchedImpl(
       cublas_handle, trans, n, nrhs,
       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
-      ldb, dev_lapack_info->mutable_data(), batch_size));
+      ldb, host_lapack_info, batch_size));
   return Status::OK();
 }
@@ -723,13 +723,13 @@ static inline Status GetrsBatchedImpl(
   Status CudaSolver::GetrsBatched( \
       cublasOperation_t trans, int n, int nrhs, \
       const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \
-      const Scalar* const host_b_dev_ptrs[], int ldb, \
-      DeviceLapackInfo* dev_lapack_info, int batch_size) { \
+      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \
+      int batch_size) { \
     return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \
                                 BLAS_SOLVER_FN(getrsBatched, type_prefix)), \
                             this, context_, cublas_handle_, trans, n, nrhs, \
                             host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
-                            ldb, dev_lapack_info, batch_size); \
+                            ldb, host_lapack_info, batch_size); \
   }
 
 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
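
For context, here is a standalone sketch (not TensorFlow code; the helper name and error handling are illustrative) of the raw cuBLAS call underneath this change. The point of the new signature is that cublas<t>getrsBatched reports its status through a host pointer, unlike cublas<t>getrfBatched, whose infoArray lives in device memory, so threading a DeviceLapackInfo through here was never necessary:

#include <cstdio>
#include <cublas_v2.h>

// Solves op(A[i]) * X[i] = B[i], i = 0..batch-1, for n x n systems whose LU
// factors and pivots were produced earlier by cublasSgetrfBatched.
// dev_a_ptrs / dev_b_ptrs are device arrays of device pointers; dev_pivots
// is a device array of n * batch pivot indices.
bool SolveFactoredBatch(cublasHandle_t handle, int n, int nrhs,
                        const float* const* dev_a_ptrs, const int* dev_pivots,
                        float* const* dev_b_ptrs, int batch) {
  int host_info = 0;  // Host output; no device round trip required.
  cublasStatus_t status =
      cublasSgetrsBatched(handle, CUBLAS_OP_N, n, nrhs, dev_a_ptrs, /*lda=*/n,
                          dev_pivots, dev_b_ptrs, /*ldb=*/n, &host_info, batch);
  if (status != CUBLAS_STATUS_SUCCESS) return false;
  if (host_info != 0) {
    // info = -j means the j-th parameter was illegal, which is what the
    // OP_REQUIRES_ASYNC check added to MatrixSolveOpGpu below reports.
    std::fprintf(stderr, "argument %d to getrsBatched was illegal\n",
                 -host_info);
    return false;
  }
  return true;
}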


@@ -235,13 +235,14 @@ class CudaSolver {
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Batched linear solver using LU factorization from getrfBatched.
-  // See:
+  // Notice that lapack_info is returned on the host, as opposed to
+  // most of the other functions that return it on the device. See:
   // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
   template <typename Scalar>
   Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                       const Scalar* const dev_Aarray[], int lda,
                       const int* devIpiv, const Scalar* const dev_Barray[],
-                      int ldb, DeviceLapackInfo* dev_lapack_info,
+                      int ldb, int* host_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Computes matrix inverses for a batch of small matrices. Uses the outputs
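
To make the host-side contract concrete, a hedged caller sketch built only from the declaration above (SolveFromLu is an illustrative name; the error text mirrors the check added to MatrixSolveOpGpu below):

template <typename Scalar>
Status SolveFromLu(CudaSolver* solver, int n, int nrhs,
                   const Scalar* const dev_Aarray[], const int* devIpiv,
                   const Scalar* const dev_Barray[], int batch_size) {
  int host_lapack_info = 0;  // Plain host int; nothing to copy back from GPU.
  TF_RETURN_IF_ERROR(solver->GetrsBatched(CUBLAS_OP_N, n, nrhs, dev_Aarray,
                                          /*lda=*/n, devIpiv, dev_Barray,
                                          /*ldb=*/n, &host_lapack_info,
                                          batch_size));
  if (host_lapack_info != 0) {
    return errors::InvalidArgument("The ", -host_lapack_info,
                                   "'th argument to cublas*getrsBatched had "
                                   "an illegal value.");
  }
  return Status::OK();
}

The factorization step, by contrast, still reports through a DeviceLapackInfo that is checked asynchronously, as in the getrf branches of MatrixSolveOpGpu below.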


@@ -214,9 +214,12 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
         sizeof(Scalar*) * batch_size, "input_copt_ptrs",
         /* on_host */ true);
-    if (n / batch_size <= 128) {
-      // For small matrices or large batch sizes, we use the batched
-      // interface from cuBlas.
+    const int kMaxMatrixSizeToBatchSizeRatio = 128;
+    const bool use_batched_solver =
+        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+    if (use_batched_solver) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuBlas.
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
       for (int batch = 0; batch < batch_size; ++batch) {
@@ -230,8 +233,8 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                                &dev_info.back(), batch_size),
           done);
     } else {
-      // For small batch sizes we use the non-batched interface from cuSolver,
-      // which is much faster for large matrices.
+      // For small batch sizes or large matrices, we use the non-batched
+      // interface from cuSolver, which is much faster for large matrices.
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
       for (int batch = 0; batch < batch_size; ++batch) {
         OP_REQUIRES_OK_ASYNC(
@@ -279,11 +282,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         /* on_host */ true);
     auto transposed_rhs_reshaped =
         transposed_rhs.template flat_inner_dims<Scalar, 3>();
-    // TODO(rmlarsen): Enable the following branch when I figure
-    // out why it causes a segfault.
-    if (false && n / batch_size <= 128) {
-      dev_info.push_back(
-          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
+    if (use_batched_solver) {
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
       const Scalar** transposed_rhs_ptrs_base =
@@ -293,13 +292,20 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
         transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
       }
+      int host_info = 0;
       OP_REQUIRES_OK_ASYNC(
           context,
           solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                                input_copy_ptrs_base, n, pivots_mat.data(),
-                               transposed_rhs_ptrs_base, n, &dev_info.back(),
+                               transposed_rhs_ptrs_base, n, &host_info,
                                batch_size),
           done);
+      OP_REQUIRES_ASYNC(
+          context, host_info == 0,
+          errors::InvalidArgument("The ", -host_info,
+                                  "'th argument to cublas*getrsBatched had "
+                                  "an illegal value."),
+          done);
     } else {
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
       for (int batch = 0; batch < batch_size; ++batch) {
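
The same use_batched_solver flag now gates both the factorization (getrf) and the solve (getrs) stages, so a batch takes either the batched cuBLAS path or the per-matrix cuSolver path end to end. Restated in isolation (the helper name is ours; the constant and comparison come straight from the first hunk), the dispatch rule is:

// Dispatch heuristic from MatrixSolveOpGpu. Writing the test as a
// multiplication rather than the old "n / batch_size <= 128" also avoids
// integer-division truncation at the boundary: n = 257, batch_size = 2
// previously chose the batched path (257 / 2 == 128) but now does not.
bool UseBatchedSolver(int n, int batch_size) {
  constexpr int kMaxMatrixSizeToBatchSizeRatio = 128;
  // Many small systems: batched cuBLAS amortizes kernel-launch overhead.
  // Few large systems: per-matrix cuSolver calls win on raw throughput.
  return n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
}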