Enable fast GPU code path for solving many small linear systems in matrix_solve.
PiperOrigin-RevId: 228382682
commit 495b3eeef0 (parent a3d639a5a4)
@@ -692,8 +692,8 @@ static inline Status GetrsBatchedImpl(
     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
-    const Scalar* const host_b_dev_ptrs[], int ldb,
-    DeviceLapackInfo* dev_lapack_info, int batch_size) {
+    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
+    int batch_size) {
   mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
@@ -714,7 +714,7 @@ static inline Status GetrsBatchedImpl(
       cublas_handle, trans, n, nrhs,
       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
-      ldb, dev_lapack_info->mutable_data(), batch_size));
+      ldb, host_lapack_info, batch_size));
   return Status::OK();
 }
 
@@ -723,13 +723,13 @@ static inline Status GetrsBatchedImpl(
   Status CudaSolver::GetrsBatched(                                            \
       cublasOperation_t trans, int n, int nrhs,                               \
       const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,  \
-      const Scalar* const host_b_dev_ptrs[], int ldb,                         \
-      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
+      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,  \
+      int batch_size) {                                                       \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
-                           ldb, dev_lapack_info, batch_size);                 \
+                           ldb, host_lapack_info, batch_size);                \
   }
 
 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
@@ -235,13 +235,14 @@ class CudaSolver {
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Batched linear solver using LU factorization from getrfBatched.
-  // See:
+  // Notice that lapack_info is returned on the host, as opposed to
+  // most of the other functions that return it on the device. See:
   // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
   template <typename Scalar>
   Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                       const Scalar* const dev_Aarray[], int lda,
                       const int* devIpiv, const Scalar* const dev_Barray[],
-                      int ldb, DeviceLapackInfo* dev_lapack_info,
+                      int ldb, int* host_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Computes matrix inverses for a batch of small matrices. Uses the outputs
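Aside on the host-returned lapack_info mentioned in the new header comment: cuBLAS writes the getrfBatched info array to device memory, but getrsBatched reports its status through a plain host integer, which is why the new GetrsBatched signature takes int* host_lapack_info instead of a DeviceLapackInfo*. Below is a minimal standalone sketch of that raw cuBLAS convention; it bypasses the TensorFlow CudaSolver wrapper entirely, the 2x2 single-precision system and batch size of 1 are made-up example values, and CUDA/cuBLAS status codes are not checked for brevity.

// Build (for example) with: nvcc getrs_batched_info_demo.cu -lcublas
#include <cstdio>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
  cublasHandle_t handle;
  cublasCreate(&handle);

  // One column-major 2x2 system A x = b with A = [[4, 1], [2, 3]], b = [1, 2].
  const float h_A[4] = {4.f, 2.f, 1.f, 3.f};
  const float h_b[2] = {1.f, 2.f};

  float *d_A, *d_b;
  int *d_pivots, *d_getrf_info;
  cudaMalloc(&d_A, sizeof(h_A));
  cudaMalloc(&d_b, sizeof(h_b));
  cudaMalloc(&d_pivots, 2 * sizeof(int));
  cudaMalloc(&d_getrf_info, sizeof(int));
  cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b, sizeof(h_b), cudaMemcpyHostToDevice);

  // The batched APIs take a device array of device pointers, one per matrix.
  float **d_A_ptrs, **d_b_ptrs;
  cudaMalloc(&d_A_ptrs, sizeof(float*));
  cudaMalloc(&d_b_ptrs, sizeof(float*));
  cudaMemcpy(d_A_ptrs, &d_A, sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b_ptrs, &d_b, sizeof(float*), cudaMemcpyHostToDevice);

  // LU factorization: the per-matrix info array lives in *device* memory.
  cublasSgetrfBatched(handle, /*n=*/2, d_A_ptrs, /*lda=*/2, d_pivots,
                      d_getrf_info, /*batchSize=*/1);

  // Solve: here the status is written to a *host* integer instead.
  int host_info = 0;
  cublasSgetrsBatched(handle, CUBLAS_OP_N, /*n=*/2, /*nrhs=*/1, d_A_ptrs,
                      /*lda=*/2, d_pivots, d_b_ptrs, /*ldb=*/2, &host_info,
                      /*batchSize=*/1);
  if (host_info != 0) {
    std::printf("argument %d to getrsBatched had an illegal value\n",
                -host_info);
  }

  float h_x[2];
  cudaMemcpy(h_x, d_b, sizeof(h_x), cudaMemcpyDeviceToHost);
  std::printf("x = [%g, %g]\n", h_x[0], h_x[1]);  // expected [0.1, 0.6]

  cudaFree(d_A); cudaFree(d_b); cudaFree(d_pivots); cudaFree(d_getrf_info);
  cudaFree(d_A_ptrs); cudaFree(d_b_ptrs);
  cublasDestroy(handle);
  return 0;
}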
@@ -214,9 +214,12 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
         sizeof(Scalar*) * batch_size, "input_copt_ptrs",
         /* on_host */ true);
-    if (n / batch_size <= 128) {
-      // For small matrices or large batch sizes, we use the batched
-      // interface from cuBlas.
+    const int kMaxMatrixSizeToBatchSizeRatio = 128;
+    const bool use_batched_solver =
+        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+    if (use_batched_solver) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuBlas.
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
       for (int batch = 0; batch < batch_size; ++batch) {
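The branch condition introduced above replaces the old truncating integer-division test (n / batch_size <= 128) with an explicit constant and the equivalent multiply form, and the same use_batched_solver flag is reused further down when choosing the solve path. A tiny standalone sketch of just this selection heuristic follows; the helper name and the sample shapes are hypothetical, made up for illustration.

#include <cstdio>

// Hypothetical helper mirroring the selection logic in MatrixSolveOpGpu:
// prefer the batched cuBLAS path when the matrices are small relative to the
// batch size, otherwise fall back to the per-matrix cuSolver path.
bool UseBatchedSolver(int n, int batch_size) {
  const int kMaxMatrixSizeToBatchSizeRatio = 128;  // same constant as the diff
  return n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
}

int main() {
  // Many tiny systems -> batched cuBLAS interface.
  std::printf("n=4, batch=1000 -> batched: %d\n", UseBatchedSolver(4, 1000));
  // A single large system -> non-batched cuSolver interface.
  std::printf("n=4096, batch=1 -> batched: %d\n", UseBatchedSolver(4096, 1));
  return 0;
}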
@@ -230,8 +233,8 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                              &dev_info.back(), batch_size),
           done);
     } else {
-      // For small batch sizes we use the non-batched interface from cuSolver,
-      // which is much faster for large matrices.
+      // For small batch sizes or large matrices, we use the non-batched
+      // interface from cuSolver, which is much faster for large matrices.
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
       for (int batch = 0; batch < batch_size; ++batch) {
         OP_REQUIRES_OK_ASYNC(
@@ -279,11 +282,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         /* on_host */ true);
     auto transposed_rhs_reshaped =
         transposed_rhs.template flat_inner_dims<Scalar, 3>();
-    // TODO(rmlarsen): Enable the following branch when I figure
-    // out why it causes a segfault.
-    if (false && n / batch_size <= 128) {
-      dev_info.push_back(
-          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
+    if (use_batched_solver) {
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
       const Scalar** transposed_rhs_ptrs_base =
@@ -293,13 +292,20 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
         transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
       }
+      int host_info = 0;
       OP_REQUIRES_OK_ASYNC(
           context,
           solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                                input_copy_ptrs_base, n, pivots_mat.data(),
-                               transposed_rhs_ptrs_base, n, &dev_info.back(),
+                               transposed_rhs_ptrs_base, n, &host_info,
                                batch_size),
           done);
+      OP_REQUIRES_ASYNC(
+          context, host_info == 0,
+          errors::InvalidArgument("The ", -host_info,
+                                  "'th argument to cublas*getrsBatched had "
+                                  "an illegal value."),
+          done);
     } else {
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
       for (int batch = 0; batch < batch_size; ++batch) {