Add GPU support for QR decomposition.

Remove support support for on-the-fly transpose in internal matrix_band_part functor recently added (in anticipation of using it for QR), since it turned out to not be useful.

PiperOrigin-RevId: 169249336
This commit is contained in:
A. Unique TensorFlower 2017-09-19 09:09:05 -07:00 committed by TensorFlower Gardener
parent ec962ff638
commit 34018f8fa7
14 changed files with 546 additions and 213 deletions

View File

@ -2334,7 +2334,11 @@ tf_kernel_library(
tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
deps = LINALG_DEPS,
deps = LINALG_DEPS + if_cuda([
":cuda_solvers",
":transpose_functor",
":matrix_band_part_op",
]),
)
tf_kernel_library(

View File

@ -76,14 +76,14 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@ -132,9 +132,10 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in paralle.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
input_reshaped, output_reshaped);
functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
band_part(context, context->eigen_device<GPUDevice>(),
n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped,
output_reshaped);
// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);

View File

@ -174,7 +174,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync(
}
info_checker_callback(status, host_lapack_infos);
};
auto cb =
std::bind(wrapped_info_checker_callback, context_,
std::move(info_checker_callback), std::move(host_lapack_infos));
@ -363,6 +363,156 @@ static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,
TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, Scalar* tau,
int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(
cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
return Status::OK();
}
#define GEQRF_INSTANCE(Scalar, lapack_prefix) \
template <> \
Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \
Scalar* tau, int* dev_lapack_info) const { \
return GeqrfImpl(DN_BUFSIZE_FN(geqrf, lapack_prefix), \
DN_SOLVER_FN(geqrf, lapack_prefix), context_, \
cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
}
TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status OrmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle,
cublasSideMode_t side, cublasOperation_t trans,
int m, int n, int k, const Scalar* dev_a,
int lda, const Scalar* dev_tau, Scalar* dev_c,
int ldc, int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(
cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
return Status::OK();
}
// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const float* dev_a, int lda,
const float* dev_tau, float* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(ormqr, S), DN_SOLVER_FN(ormqr, S), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const double* dev_a, int lda,
const double* dev_tau, double* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(ormqr, D), DN_SOLVER_FN(ormqr, D), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<float>* dev_a,
int lda, const std::complex<float>* dev_tau,
std::complex<float>* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(unmqr, C), DN_SOLVER_FN(unmqr, C), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<double>* dev_a,
int lda, const std::complex<double>* dev_tau,
std::complex<double>* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(unmqr, Z), DN_SOLVER_FN(unmqr, Z), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status OrgqrImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau, int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(
solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
lwork, dev_lapack_info));
return Status::OK();
}
// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
template <>
Status CudaSolver::Orgqr(int m, int n, int k, float* dev_a, int lda,
const float* dev_tau, int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(orgqr, S), DN_SOLVER_FN(orgqr, S), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, double* dev_a, int lda,
const double* dev_tau, int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(orgqr, D), DN_SOLVER_FN(orgqr, D), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, std::complex<float>* dev_a,
int lda, const std::complex<float>* dev_tau,
int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(ungqr, C), DN_SOLVER_FN(ungqr, C), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, std::complex<double>* dev_a,
int lda, const std::complex<double>* dev_tau,
int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(ungqr, Z), DN_SOLVER_FN(ungqr, Z), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//

View File

@ -147,7 +147,7 @@ class CudaSolver {
Status CopyLapackInfoToHostAsync(
const std::vector<DeviceLapackInfo>& dev_lapack_info,
std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
info_checker_callback) const;
info_checker_callback) const TF_MUST_USE_RESULT;
// ====================================================================
// Wrappers for cuSolverDN and cuBlas solvers start here.
@ -166,28 +166,29 @@ class CudaSolver {
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */
const Scalar* B, int ldb, Scalar* C, int ldc) const;
const Scalar* B, int ldb, Scalar* C,
int ldc) const TF_MUST_USE_RESULT;
// Computes the Cholesky factorization A = L * L^T for a single matrix.
// Returns Status::OK() if the kernel was launched successfully. See:
// http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
template <typename Scalar>
Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// LU factorization.
// Computes LU factorization with partial pivoting P * A = L * U.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
template <typename Scalar>
Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Uses LU factorization to solve A * X = B.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
template <typename Scalar>
Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
int lda, const int* pivots, Scalar* B, int ldb,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Computes partially pivoted LU factorizations for a batch of small matrices.
// Returns Status::OK() if the kernel was launched successfully.See:
@ -195,7 +196,7 @@ class CudaSolver {
template <typename Scalar>
Status GetrfBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
int batch_size) const;
int batch_size) const TF_MUST_USE_RESULT;
// Batched linear solver using LU factorization from getrfBatched.
// See:
@ -204,7 +205,8 @@ class CudaSolver {
Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
const Scalar* dev_Aarray[], int lda, const int* devIpiv,
const Scalar* dev_Barray[], int ldb,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;
// Computes matrix inverses for a batch of small matrices. Uses the outputs
// from GetrfBatched. Returns Status::OK() if the kernel was launched
@ -214,7 +216,8 @@ class CudaSolver {
Status GetriBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const int* dev_pivots,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;
// Computes matrix inverses for a batch of small matrices with size n < 32.
// Returns Status::OK() if the kernel was launched successfully. See:
@ -222,59 +225,58 @@ class CudaSolver {
template <typename Scalar>
Status MatInvBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;
/*
TODO(rmlarsen, volunteers): Implement the kernels below.
// Uses Cholesky factorization to solve A * X = B.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs
template <typename Scalar>
Status Potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* dev_A, int
lda, Scalar* dev_B, int ldb, int* dev_lapack_info) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;
// QR factorization.
// Computes QR factorization A = Q * R.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
template <typename Scalar>
Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_TAU, int*
devInfo) const;
Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Multiplies by Q.
// Overwrite matrix C by product of C and Householder matrix Q. The
// Householder matrix Q is represented by the output from Geqrf in dev_a and
// dev_tau.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
template <typename Scalar>
Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int
k, const Scalar* dev_a, int lda, const Scalar* dev_tau, Scalar* dev_c, int
ldc, int* dev_lapack_info) const;
Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
Scalar* dev_c, int ldc,
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Generate Q.
// Overwrites QR factorization produced by Geqrf by Householder matrix Q.
// On input, the Householder matrix Q is represented by the output from Geqrf
// in dev_a and dev_tau. On output, dev_a is overwritten with the first n
// columns of Q.
// Requires m >= n >= 0.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
template <typename Scalar>
Status Orgqr(int m, int n, int k, Scalar* dev_A, int lda, const Scalar*
dev_tau, int* dev_lapack_info) const;
Status Orgqr(int m, int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau,
int* dev_lapack_info) const TF_MUST_USE_RESULT;
// Singular value decomposition.
// Returns Status::OK() if the kernel was launched successfully.
// TODO(rmlarsen, volunteers): Add support for complex types.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
int ldvt, int* dev_lapack_info) const TF_MUST_USE_RESULT;
/*
TODO(rmlarsen, volunteers): Implement the kernels below.
// Symmetric/Hermitian Eigen decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
template <typename Scalar>
Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const;
*/
// Singular value decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
int ldvt, int* dev_lapack_info) const;
/*
// Batched linear solver using LU factorization from getrfBatched.
// See:
http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
template <typename Scalar>
Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
const Scalar* dev_Aarray[], int lda, const int* devIpiv,
Scalar* dev_Barray[], int ldb, int* info, int batch_size)
const;
*/
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const TF_MUST_USE_RESULT;
*/
private:
OpKernelContext* context_; // not owned.
@ -371,7 +373,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct AdjointBatchFunctor {
// We assume that the tensor sizes are correct.
void operator()(const Device& d,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};
@ -380,7 +382,8 @@ struct AdjointBatchFunctor {
// in a flattened batch.
template <typename Device, typename Scalar>
struct DeterminantFromPivotedLUFunctor {
void operator()(const Device& d, typename TTypes<Scalar, 3>::Tensor lu_factor,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::Tensor lu_factor,
const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
int* info);
};
@ -390,7 +393,7 @@ struct DeterminantFromPivotedLUFunctor {
// op.
template <typename Device, typename Scalar>
struct EyeFunctor {
void operator()(const Device& d,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::Tensor matrix_batch);
};

View File

@ -190,7 +190,6 @@ struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
}
};
// Instantiate implementations for the 4 numeric types.
template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, std::complex<float>>;
@ -202,7 +201,6 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
int n, Scalar* matrix_batch_ptr) {
const int matrix_size = m * n;
const Scalar one = Const<Scalar>::make_const(1.0);
const Scalar zero = Const<Scalar>::make_const(0.0);
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
if (batch >= batch_size) {
break;
@ -216,7 +214,7 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
if (col >= n) {
break;
}
matrix_batch_ptr[row_start + col] = row == col ? one : zero;
matrix_batch_ptr[row_start + col] = row == col ? one : Scalar();
}
}
}
@ -239,7 +237,6 @@ struct EyeFunctor<GPUDevice, Scalar> {
}
};
// Instantiate implementations for the 4 numeric types.
template struct EyeFunctor<GPUDevice, float>;
template struct EyeFunctor<GPUDevice, double>;
template struct EyeFunctor<GPUDevice, std::complex<float>>;

View File

@ -93,7 +93,7 @@ class MatrixBandPartOp : public OpKernel {
auto output_reshaped = output->flat_inner_dims<T, 3>();
functor::MatrixBandPartFunctor<Device, T> fn;
fn(context, context->eigen_device<Device>(), num_lower, num_upper,
false /* transpose */, input_reshaped, output_reshaped);
input_reshaped, output_reshaped);
}
private:
@ -126,7 +126,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
void operator()(OpKernelContext* context, const CPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
const int64 b = input.dimension(0);
@ -137,72 +137,46 @@ struct MatrixBandPartFunctor<CPUDevice, Scalar> {
const int64 total_rows = b * m;
const int64 row_cost = 10 * n;
const bool in_place = input.data() == output.data();
CHECK(!(transpose && in_place));
if (!transpose) {
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
if (!in_place) {
std::fill(output.data() + begin * n, output.data() + end * n,
Scalar());
}
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0
? 0
: std::min(n, std::max(0ll, row - num_lower_diags));
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
if (in_place) {
if (band_start > 0) {
std::fill(&output(batch, row, 0),
&output(batch, row, band_start), Scalar());
}
if (band_end < n) {
std::fill(&output(batch, row, band_end), &output(batch, row, n),
Scalar());
}
} else {
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
}
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
if (!in_place) {
std::fill(output.data() + begin * n, output.data() + end * n, Scalar());
}
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0
? 0
: std::min(n, std::max(0ll, row - num_lower_diags));
const int64 band_end =
num_upper_diags < 0
? n
: std::min(static_cast<int64>(n), row + num_upper_diags + 1);
if (in_place) {
if (band_start > 0) {
std::fill(&output(batch, row, 0), &output(batch, row, band_start),
Scalar());
}
if (band_end < n) {
std::fill(&output(batch, row, band_end), &output(batch, row, n),
Scalar());
}
} else {
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
output.device(device) = output.constant(Scalar());
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
for (int64 col = band_start; col < band_end; ++col) {
output(batch, col, row) = input(batch, row, col);
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
};
@ -216,14 +190,14 @@ TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);

View File

@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct MatrixBandPartFunctor {
void operator()(OpKernelContext* context, const Device& device,
int num_upper_diags, int num_lower_diags, bool transpose,
int num_upper_diags, int num_lower_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};

View File

@ -28,41 +28,22 @@ namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
template <bool transpose, typename Scalar>
template <typename Scalar>
__global__ void MatrixBandPartKernel(const int num_threads,
const int batch_size, const int m,
const int n, const int num_lower_diags,
const int num_upper_diags,
const Scalar* input_ptr,
Scalar* output_ptr) {
if (!transpose) {
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[index] = Scalar();
} else {
output_ptr[index] = input_ptr[index];
}
}
} else {
const int matrix_size = m * n;
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int batch = index / matrix_size;
const int transpose_index = batch * matrix_size + n * col + row;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[transpose_index] = Scalar();
} else {
output_ptr[transpose_index] = input_ptr[index];
}
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end = (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[index] = Scalar();
} else {
output_ptr[index] = input_ptr[index];
}
}
}
@ -70,7 +51,7 @@ __global__ void MatrixBandPartKernel(const int num_threads,
template <typename Scalar>
struct MatrixBandPartFunctor<GPUDevice, Scalar> {
void operator()(OpKernelContext* context, const GPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
using CudaType = typename CUDAComplexT<Scalar>::type;
@ -80,17 +61,10 @@ struct MatrixBandPartFunctor<GPUDevice, Scalar> {
const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
if (transpose) {
MatrixBandPartKernel<true>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
} else {
MatrixBandPartKernel<false>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
}
MatrixBandPartKernel<<<config.block_count, config.thread_per_block, 0,
device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
}
};

View File

@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
#endif
} // namespace tensorflow

View File

@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
#endif
} // namespace tensorflow

View File

@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
#endif
} // namespace tensorflow

View File

@ -19,4 +19,8 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
#endif
} // namespace tensorflow

View File

@ -19,10 +19,16 @@ limitations under the License.
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
#include <numeric>
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif
#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
@ -30,6 +36,13 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif
namespace tensorflow {
template <class Scalar>
@ -107,4 +120,189 @@ class QrOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
template <class Scalar>
class QrOpGpu : public AsyncOpKernel {
public:
explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
}
void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
const Tensor& input = context->input(0);
const int ndims = input.dims();
const int64 m = input.dim_size(ndims - 2);
const int64 n = input.dim_size(ndims - 1);
const int64 min_size = std::min(m, n);
const int64 batch_size =
input.template flat_inner_dims<Scalar, 3>().dimension(0);
// Validate inputs.
OP_REQUIRES_ASYNC(
context, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
done);
// Allocate output.
// If full_matrices_ is true then Q is m x m and R is m x n.
// Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
Tensor* q;
TensorShape q_shape = input.shape();
q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
done);
Tensor* r;
TensorShape r_shape = input.shape();
r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
done);
if (input.NumElements() == 0) {
done();
return;
}
// Allocate temporaries.
Tensor input_transposed;
TensorShape transposed_shape = input.shape();
transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<Scalar>::value, transposed_shape,
&input_transposed),
done);
Tensor tau;
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<Scalar>::value,
TensorShape({batch_size, min_size}), &tau),
done);
// Transpose input, since cuSolver uses column-major, while TensorFlow uses
// row-major storage.
std::vector<int> perm(ndims);
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[ndims - 2], perm[ndims - 1]);
const GPUDevice& device = context->eigen_device<GPUDevice>();
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input, perm, &input_transposed), done);
// Compute QR decomposition in-place in input_transposed.
CudaSolver solver(context);
std::vector<DeviceLapackInfo> dev_info;
dev_info.emplace_back(context, batch_size, "geqrf");
auto input_transposed_reshaped =
input_transposed.flat_inner_dims<Scalar, 3>();
auto tau_matrix = tau.matrix<Scalar>();
auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0),
dev_info.back().mutable_data() + batch),
done);
}
// Generate R. R is equal to the upper triangle of the decomposition
// stored in input_transposed. Crop, transpose (to get back to row-major)
// and copy it to the output buffer.
if (full_matrices_ || m == n) {
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input_transposed, perm, r), done);
} else {
const Scalar alpha(1);
const Scalar beta(0);
const Scalar* dummy = nullptr;
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Geam(CUBLAS_OP_T, CUBLAS_OP_N, n,
full_matrices_ ? m : min_size, &alpha,
&input_transposed_reshaped(batch, 0, 0), m, &beta,
dummy, n, &r_reshaped(batch, 0, 0), n),
done);
}
}
// Extract the upper triangle of r (i.e. zero out the strictly lower
// triangle).
functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
auto r_reshaped_const =
const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
band_part(context, device, 0 /* num_lower_diags */,
-1 /* num_upper_diags */, r_reshaped_const, r_reshaped);
// Generate Q from the decomposition in input_transposed.
if (m != n && (full_matrices_ || m < n)) {
context->CtxFailure(
errors::Unimplemented("The case m != n && (full_matrices_ || m < "
"n) is not currently supported on GPU."));
done();
return;
/* TODO(rmlarsen): FIXME. This branch fails with info != 0 (both
positive and negative) error statuses from ORMQR.
// Generate full m x m matrix Q by computing the product Q^T * I
// (transpose to get back to row-major form).
functor::EyeFunctor<GPUDevice, Scalar> eye;
auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
eye(device, q_reshaped);
dev_info.emplace_back(context, batch_size, "ormqr");
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Ormqr(CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, m, min_size,
&input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
dev_info.back().mutable_data() + batch),
done);
}
*/
} else {
// Generate m x n matrix Q. In this case we can use the more efficient
// algorithm in Orgqr to generate Q in place.
dev_info.emplace_back(context, batch_size, "orgqr");
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Orgqr(
m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
done);
}
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input_transposed, perm, q), done);
}
// Asynchronously check return status from cuSolver kernels.
TensorReference input_transposed_ref(input_transposed);
TensorReference tau_ref(tau);
auto info_checker = [context, dev_info, input_transposed_ref, tau_ref,
done](const Status& status,
const std::vector<HostLapackInfo>& host_infos) {
input_transposed_ref.Unref();
tau_ref.Unref();
OP_REQUIRES_OK_ASYNC(context, status, done);
done();
};
OP_REQUIRES_OK_ASYNC(
context,
solver.CopyLapackInfoToHostAsync(dev_info, std::move(info_checker)),
done);
}
private:
bool full_matrices_;
TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};
#endif
} // namespace tensorflow

View File

@ -27,6 +27,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _AddTest(test_class, op_name, testcase_name, fn):
test_name = "_".join(["test", op_name, testcase_name])
if hasattr(test_class, test_name):
raise RuntimeError("Test %s defined more than once" % test_name)
setattr(test_class, test_name, fn)
class QrOpTest(test.TestCase):
def testWrongDimensions(self):
@ -41,7 +48,7 @@ class QrOpTest(test.TestCase):
linalg_ops.qr(vector)
def _GetQrOpTest(dtype_, shape_, use_static_shape_):
def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
@ -95,36 +102,41 @@ def _GetQrOpTest(dtype_, shape_, use_static_shape_):
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
for full_matrices in False, True:
with self.test_session() as sess:
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
# TODO(rmlarsen): Debug failure due to invalid parameter to ORMQR.
rows_ = shape_[-2]
cols_ = shape_[-1]
use_gpu = False if rows_ < cols_ or (full_matrices_ and
rows_ != cols_) else True
if use_static_shape_:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
with self.test_session(use_gpu=use_gpu) as sess:
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)
q_dims = q_tf_val.shape
np_q = np.ndarray(q_dims, dtype_)
np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
new_first_dim = np_q_reshape.shape[0]
if use_static_shape_:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
for i in range(new_first_dim):
if full_matrices:
np_q_reshape[i,:,:], _ = \
q_dims = q_tf_val.shape
np_q = np.ndarray(q_dims, dtype_)
np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
new_first_dim = np_q_reshape.shape[0]
x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
for i in range(new_first_dim):
if full_matrices_:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="complete")
else:
np_q_reshape[i,:,:], _ = \
else:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="reduced")
np_q = np.reshape(np_q_reshape, q_dims)
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
CheckUnitary(self, q_tf_val)
np_q = np.reshape(np_q_reshape, q_dims)
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
CheckUnitary(self, q_tf_val)
return Test
@ -133,11 +145,15 @@ if __name__ == "__main__":
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 1, 2, 5, 10, 32, 100:
for cols in 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
for use_static_shape in True, False:
name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
use_static_shape)
setattr(QrOpTest, "testQr_" + name,
_GetQrOpTest(dtype, shape, use_static_shape))
for full_matrices in False, True:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
for use_static_shape in True, False:
shape = batch_dims + (rows, cols)
name = "%s_%s_full_%s_static_%s" % (dtype.__name__,
"_".join(map(str, shape)),
full_matrices,
use_static_shape)
_AddTest(QrOpTest, "Qr", name,
_GetQrOpTest(dtype, shape, full_matrices,
use_static_shape))
test.main()