From 495b3eeef0386b1b89b7aa9df42f2cf438de6ebc Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 8 Jan 2019 12:55:56 -0800
Subject: [PATCH] Enable fast GPU code path for solving many small linear
 systems in matrix_solve.

PiperOrigin-RevId: 228382682
---
 tensorflow/core/kernels/cuda_solvers.cc    | 12 +++++-----
 tensorflow/core/kernels/cuda_solvers.h     |  5 ++--
 tensorflow/core/kernels/matrix_solve_op.cc | 28 +++++++++++++---------
 3 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index a59baaa96fc..39d0a998fdc 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -692,8 +692,8 @@ static inline Status GetrsBatchedImpl(
     SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
     cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
     const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
-    const Scalar* const host_b_dev_ptrs[], int ldb,
-    DeviceLapackInfo* dev_lapack_info, int batch_size) {
+    const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
+    int batch_size) {
   mutex_lock lock(handle_map_mutex);
   using CudaScalar = typename CUDAComplexT<Scalar>::type;
   ScratchSpace<uint8> dev_a_dev_ptrs =
@@ -714,7 +714,7 @@ static inline Status GetrsBatchedImpl(
       cublas_handle, trans, n, nrhs,
       reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
       dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
-      ldb, dev_lapack_info->mutable_data(), batch_size));
+      ldb, host_lapack_info, batch_size));
   return Status::OK();
 }
 
@@ -723,13 +723,13 @@
   Status CudaSolver::GetrsBatched( \
       cublasOperation_t trans, int n, int nrhs, \
       const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \
-      const Scalar* const host_b_dev_ptrs[], int ldb, \
-      DeviceLapackInfo* dev_lapack_info, int batch_size) { \
+      const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \
+      int batch_size) { \
     return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \
                                 BLAS_SOLVER_FN(getrsBatched, type_prefix)), \
                             this, context_, cublas_handle_, trans, n, nrhs, \
                             host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
-                            ldb, dev_lapack_info, batch_size); \
+                            ldb, host_lapack_info, batch_size); \
   }
 
 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 2c30d036df7..1fc344731c2 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -235,13 +235,14 @@ class CudaSolver {
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Batched linear solver using LU factorization from getrfBatched.
-  // See:
+  // Notice that lapack_info is returned on the host, as opposed to
+  // most of the other functions that return it on the device. See:
   // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
   template <typename Scalar>
   Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                       const Scalar* const dev_Aarray[], int lda,
                       const int* devIpiv, const Scalar* const dev_Barray[],
-                      int ldb, DeviceLapackInfo* dev_lapack_info,
+                      int ldb, int* host_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;
 
   // Computes matrix inverses for a batch of small matrices. Uses the outputs
diff --git a/tensorflow/core/kernels/matrix_solve_op.cc b/tensorflow/core/kernels/matrix_solve_op.cc
index 169f3dae76d..f3919a16aa5 100644
--- a/tensorflow/core/kernels/matrix_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_op.cc
@@ -214,9 +214,12 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
         sizeof(Scalar*) * batch_size, "input_copt_ptrs",
         /* on_host */ true);
-    if (n / batch_size <= 128) {
-      // For small matrices or large batch sizes, we use the batched
-      // interface from cuBlas.
+    const int kMaxMatrixSizeToBatchSizeRatio = 128;
+    const bool use_batched_solver =
+        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
+    if (use_batched_solver) {
+      // For small matrices or large batch sizes, we use the batched interface
+      // from cuBlas.
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
       for (int batch = 0; batch < batch_size; ++batch) {
@@ -230,8 +233,8 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                              &dev_info.back(), batch_size),
           done);
     } else {
-      // For small batch sizes we use the non-batched interface from cuSolver,
-      // which is much faster for large matrices.
+      // For small batch sizes or large matrices, we use the non-batched
+      // interface from cuSolver, which is much faster for large matrices.
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
       for (int batch = 0; batch < batch_size; ++batch) {
         OP_REQUIRES_OK_ASYNC(
@@ -279,11 +282,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                                /* on_host */ true);
     auto transposed_rhs_reshaped =
         transposed_rhs.template flat_inner_dims<Scalar, 3>();
-    // TODO(rmlarsen): Enable the following branch when I figure
-    // out why it causes a segfault.
-    if (false && n / batch_size <= 128) {
-      dev_info.push_back(
-          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
+    if (use_batched_solver) {
       const Scalar** input_copy_ptrs_base =
           reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
       const Scalar** transposed_rhs_ptrs_base =
@@ -293,13 +292,20 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
         transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
       }
+      int host_info = 0;
       OP_REQUIRES_OK_ASYNC(
           context,
           solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                                input_copy_ptrs_base, n, pivots_mat.data(),
-                               transposed_rhs_ptrs_base, n, &dev_info.back(),
+                               transposed_rhs_ptrs_base, n, &host_info,
                                batch_size),
           done);
+      OP_REQUIRES_ASYNC(
+          context, host_info == 0,
+          errors::InvalidArgument("The ", -host_info,
+                                  "'th argument to cublas*getrsBatched had "
+                                  "an illegal value."),
+          done);
     } else {
       dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
       for (int batch = 0; batch < batch_size; ++batch) {
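
Note (not part of the patch): the change hinges on a cuBLAS detail. cublas<T>getrfBatched writes its per-matrix info array to device memory, while cublas<T>getrsBatched writes a single argument-error code through a host pointer. That is why GetrsBatched now takes int* host_lapack_info and the kernel can validate it immediately with OP_REQUIRES_ASYNC instead of scheduling a device-side info check. The standalone sketch below illustrates that pattern directly against the cuBLAS API; the file name, the toy 2x2 systems, and the error handling are illustrative assumptions, not TensorFlow code, while the cublasDgetrfBatched/cublasDgetrsBatched calls are the real cuBLAS entry points.

// solve_batched_sketch.cu -- illustrative only, not part of TensorFlow.
// Build (assumption): nvcc solve_batched_sketch.cu -lcublas -o solve_batched_sketch
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Solves a batch of small n x n systems A_i * x_i = b_i with batched LU.
// getrfBatched reports per-matrix info in *device* memory; getrsBatched
// reports a single argument-error code through a *host* pointer.
int main() {
  const int n = 2, nrhs = 1, batch = 2;
  // Column-major data for two diagonal systems:
  //   diag(2, 2) x = (2, 4)^T  ->  x = (1, 2)^T
  //   diag(1, 4) x = (3, 8)^T  ->  x = (3, 2)^T
  std::vector<double> hA = {2, 0, 0, 2, /* second matrix */ 1, 0, 0, 4};
  std::vector<double> hB = {2, 4, /* second rhs */ 3, 8};

  double *dA, *dB;
  cudaMalloc(&dA, hA.size() * sizeof(double));
  cudaMalloc(&dB, hB.size() * sizeof(double));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(double), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(double), cudaMemcpyHostToDevice);

  // The batched API takes device arrays of per-matrix device pointers.
  std::vector<double*> hAptrs = {dA, dA + n * n};
  std::vector<double*> hBptrs = {dB, dB + n * nrhs};
  double **dAptrs, **dBptrs;
  cudaMalloc(&dAptrs, batch * sizeof(double*));
  cudaMalloc(&dBptrs, batch * sizeof(double*));
  cudaMemcpy(dAptrs, hAptrs.data(), batch * sizeof(double*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(dBptrs, hBptrs.data(), batch * sizeof(double*),
             cudaMemcpyHostToDevice);

  int *dPivots, *dInfo;  // factorization outputs live on the device
  cudaMalloc(&dPivots, batch * n * sizeof(int));
  cudaMalloc(&dInfo, batch * sizeof(int));

  cublasHandle_t handle;
  cublasCreate(&handle);

  // In-place LU factorization of all matrices; the info array is a device pointer.
  cublasDgetrfBatched(handle, n, dAptrs, n, dPivots, dInfo, batch);

  // Batched triangular solves; info is a plain host int, valid on return.
  int host_info = 0;
  cublasDgetrsBatched(handle, CUBLAS_OP_N, n, nrhs, dAptrs, n, dPivots, dBptrs,
                      n, &host_info, batch);
  if (host_info != 0) {
    std::printf("argument %d to getrsBatched had an illegal value\n",
                -host_info);
  }

  // The solutions overwrite B.
  std::vector<double> hX(hB.size());
  cudaMemcpy(hX.data(), dB, hB.size() * sizeof(double), cudaMemcpyDeviceToHost);
  std::printf("x0 = (%g, %g), x1 = (%g, %g)\n", hX[0], hX[1], hX[2], hX[3]);

  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dAptrs); cudaFree(dBptrs);
  cudaFree(dPivots); cudaFree(dInfo);
  return 0;
}

In the patch itself, dev_info (a DeviceLapackInfo) still collects the getrfBatched results for the usual asynchronous device-side check, while the new int host_info mirrors the host-side getrsBatched check shown above.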