Add GPU support for QR decomposition.
Also remove support for the on-the-fly transpose recently added to the internal matrix_band_part functor in anticipation of using it for QR, since it turned out not to be useful.

PiperOrigin-RevId: 169249336
parent ec962ff638
commit 34018f8fa7
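For context, here is a minimal sketch (not part of this change) of how the new GPU kernel can be exercised from Python through the TensorFlow 1.x session API; the dtype and the full_matrices argument match the registrations and tests below, while the device string and tolerance are illustrative assumptions:

    import numpy as np
    import tensorflow as tf

    # Random batch of square matrices; float32 is one of the dtypes registered
    # for the GPU kernel in this change.
    x_np = np.random.randn(4, 32, 32).astype(np.float32)

    with tf.device("/gpu:0"):  # assumes a CUDA build and a visible GPU
      x = tf.constant(x_np)
      q, r = tf.qr(x, full_matrices=False)

    with tf.Session() as sess:
      q_val, r_val = sess.run([q, r])

    # Q * R should reconstruct the input up to numerical error.
    print(np.allclose(np.matmul(q_val, r_val), x_np, atol=1e-4))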
@@ -2334,7 +2334,11 @@ tf_kernel_library(
tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
deps = LINALG_DEPS,
deps = LINALG_DEPS + if_cuda([
":cuda_solvers",
":transpose_functor",
":matrix_band_part_op",
]),
)

tf_kernel_library(
@@ -76,14 +76,14 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);

@@ -132,9 +132,10 @@ class CholeskyOpGpu : public AsyncOpKernel {
// before we launch each of the Cholesky factorization kernels in paralle.
auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
functor::MatrixBandPartFunctor<GPUDevice, Scalar> fn;
fn(context, context->eigen_device<GPUDevice>(), n, 0, false /* transpose */,
input_reshaped, output_reshaped);
functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
band_part(context, context->eigen_device<GPUDevice>(),
n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped,
output_reshaped);

// Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0);
@@ -174,7 +174,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync(
}
info_checker_callback(status, host_lapack_infos);
};

auto cb =
std::bind(wrapped_info_checker_callback, context_,
std::move(info_checker_callback), std::move(host_lapack_infos));

@@ -363,6 +363,156 @@ static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver,

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, Scalar* tau,
int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(
cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, lapack_prefix) \
template <> \
Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \
Scalar* tau, int* dev_lapack_info) const { \
return GeqrfImpl(DN_BUFSIZE_FN(geqrf, lapack_prefix), \
DN_SOLVER_FN(geqrf, lapack_prefix), context_, \
cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
}

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status OrmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle,
cublasSideMode_t side, cublasOperation_t trans,
int m, int n, int k, const Scalar* dev_a,
int lda, const Scalar* dev_tau, Scalar* dev_c,
int ldc, int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(
cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const float* dev_a, int lda,
const float* dev_tau, float* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(ormqr, S), DN_SOLVER_FN(ormqr, S), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const double* dev_a, int lda,
const double* dev_tau, double* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(ormqr, D), DN_SOLVER_FN(ormqr, D), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<float>* dev_a,
int lda, const std::complex<float>* dev_tau,
std::complex<float>* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(unmqr, C), DN_SOLVER_FN(unmqr, C), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}
template <>
Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m,
int n, int k, const std::complex<double>* dev_a,
int lda, const std::complex<double>* dev_tau,
std::complex<double>* dev_c, int ldc,
int* dev_lapack_info) const {
return OrmqrImpl(DN_BUFSIZE_FN(unmqr, Z), DN_SOLVER_FN(unmqr, Z), context_,
cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda,
dev_tau, dev_c, ldc, dev_lapack_info);
}

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status OrgqrImpl(BufSizeFnT bufsize, SolverFnT solver,
OpKernelContext* context,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau, int* dev_lapack_info) {
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), &lwork));
/* Allocate device memory for workspace. */
ScratchSpace<Scalar> dev_workspace(context, lwork, /* on_host */ false);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(
solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
lwork, dev_lapack_info));
return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
template <>
Status CudaSolver::Orgqr(int m, int n, int k, float* dev_a, int lda,
const float* dev_tau, int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(orgqr, S), DN_SOLVER_FN(orgqr, S), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, double* dev_a, int lda,
const double* dev_tau, int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(orgqr, D), DN_SOLVER_FN(orgqr, D), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, std::complex<float>* dev_a,
int lda, const std::complex<float>* dev_tau,
int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(ungqr, C), DN_SOLVER_FN(ungqr, C), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}
template <>
Status CudaSolver::Orgqr(int m, int n, int k, std::complex<double>* dev_a,
int lda, const std::complex<double>* dev_tau,
int* dev_lapack_info) const {
return OrgqrImpl(DN_BUFSIZE_FN(ungqr, Z), DN_SOLVER_FN(ungqr, Z), context_,
cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau,
dev_lapack_info);
}

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
@@ -147,7 +147,7 @@ class CudaSolver {
Status CopyLapackInfoToHostAsync(
const std::vector<DeviceLapackInfo>& dev_lapack_info,
std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
info_checker_callback) const;
info_checker_callback) const TF_MUST_USE_RESULT;

// ====================================================================
// Wrappers for cuSolverDN and cuBlas solvers start here.

@@ -166,28 +166,29 @@ class CudaSolver {
const Scalar* alpha, /* host or device pointer */
const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */
const Scalar* B, int ldb, Scalar* C, int ldc) const;
const Scalar* B, int ldb, Scalar* C,
int ldc) const TF_MUST_USE_RESULT;

// Computes the Cholesky factorization A = L * L^T for a single matrix.
// Returns Status::OK() if the kernel was launched successfully. See:
// http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
template <typename Scalar>
Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// LU factorization.
// Computes LU factorization with partial pivoting P * A = L * U.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
template <typename Scalar>
Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// Uses LU factorization to solve A * X = B.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
template <typename Scalar>
Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
int lda, const int* pivots, Scalar* B, int ldb,
int* dev_lapack_info) const;
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// Computes partially pivoted LU factorizations for a batch of small matrices.
// Returns Status::OK() if the kernel was launched successfully.See:

@@ -195,7 +196,7 @@ class CudaSolver {
template <typename Scalar>
Status GetrfBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
int batch_size) const;
int batch_size) const TF_MUST_USE_RESULT;

// Batched linear solver using LU factorization from getrfBatched.
// See:

@@ -204,7 +205,8 @@ class CudaSolver {
Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
const Scalar* dev_Aarray[], int lda, const int* devIpiv,
const Scalar* dev_Barray[], int ldb,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;

// Computes matrix inverses for a batch of small matrices. Uses the outputs
// from GetrfBatched. Returns Status::OK() if the kernel was launched

@@ -214,7 +216,8 @@ class CudaSolver {
Status GetriBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const int* dev_pivots,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;

// Computes matrix inverses for a batch of small matrices with size n < 32.
// Returns Status::OK() if the kernel was launched successfully. See:
@@ -222,59 +225,58 @@ class CudaSolver {
template <typename Scalar>
Status MatInvBatched(int n, const Scalar* host_a_dev_ptrs[], int lda,
const Scalar* host_a_inverse_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) const;

/*
TODO(rmlarsen, volunteers): Implement the kernels below.
// Uses Cholesky factorization to solve A * X = B.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs
template <typename Scalar>
Status Potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* dev_A, int
lda, Scalar* dev_B, int ldb, int* dev_lapack_info) const;
DeviceLapackInfo* dev_lapack_info,
int batch_size) const TF_MUST_USE_RESULT;

// QR factorization.
// Computes QR factorization A = Q * R.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
template <typename Scalar>
Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_TAU, int*
devInfo) const;
Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// Multiplies by Q.
// Overwrite matrix C by product of C and Householder matrix Q. The
// Householder matrix Q is represented by the output from Geqrf in dev_a and
// dev_tau.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
template <typename Scalar>
Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int
k, const Scalar* dev_a, int lda, const Scalar* dev_tau, Scalar* dev_c, int
ldc, int* dev_lapack_info) const;
Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
Scalar* dev_c, int ldc,
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// Generate Q.
// Overwrites QR factorization produced by Geqrf by Householder matrix Q.
// On input, the Householder matrix Q is represented by the output from Geqrf
// in dev_a and dev_tau. On output, dev_a is overwritten with the first n
// columns of Q.
// Requires m >= n >= 0.
// Returns Status::OK() if the kernel was launched successfully.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
template <typename Scalar>
Status Orgqr(int m, int n, int k, Scalar* dev_A, int lda, const Scalar*
dev_tau, int* dev_lapack_info) const;
Status Orgqr(int m, int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau,
int* dev_lapack_info) const TF_MUST_USE_RESULT;

// Singular value decomposition.
// Returns Status::OK() if the kernel was launched successfully.
// TODO(rmlarsen, volunteers): Add support for complex types.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
int ldvt, int* dev_lapack_info) const TF_MUST_USE_RESULT;

/*
TODO(rmlarsen, volunteers): Implement the kernels below.

// Symmetric/Hermitian Eigen decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
template <typename Scalar>
Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar*
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const;

*/
// Singular value decomposition.
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
template <typename Scalar>
Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
int ldvt, int* dev_lapack_info) const;
/*
// Batched linear solver using LU factorization from getrfBatched.
// See:
http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
template <typename Scalar>
Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
const Scalar* dev_Aarray[], int lda, const int* devIpiv,
Scalar* dev_Barray[], int ldb, int* info, int batch_size)
const;
*/
dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const TF_MUST_USE_RESULT;
*/

private:
OpKernelContext* context_; // not owned.

@@ -371,7 +373,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct AdjointBatchFunctor {
// We assume that the tensor sizes are correct.
void operator()(const Device& d,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};

@@ -380,7 +382,8 @@ struct AdjointBatchFunctor {
// in a flattened batch.
template <typename Device, typename Scalar>
struct DeterminantFromPivotedLUFunctor {
void operator()(const Device& d, typename TTypes<Scalar, 3>::Tensor lu_factor,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::Tensor lu_factor,
const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
int* info);
};

@@ -390,7 +393,7 @@ struct DeterminantFromPivotedLUFunctor {
// op.
template <typename Device, typename Scalar>
struct EyeFunctor {
void operator()(const Device& d,
void operator()(const Device& device,
typename TTypes<Scalar, 3>::Tensor matrix_batch);
};
@@ -190,7 +190,6 @@ struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
}
};

// Instantiate implementations for the 4 numeric types.
template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, std::complex<float>>;

@@ -202,7 +201,6 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
int n, Scalar* matrix_batch_ptr) {
const int matrix_size = m * n;
const Scalar one = Const<Scalar>::make_const(1.0);
const Scalar zero = Const<Scalar>::make_const(0.0);
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
if (batch >= batch_size) {
break;

@@ -216,7 +214,7 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m,
if (col >= n) {
break;
}
matrix_batch_ptr[row_start + col] = row == col ? one : zero;
matrix_batch_ptr[row_start + col] = row == col ? one : Scalar();
}
}
}

@@ -239,7 +237,6 @@ struct EyeFunctor<GPUDevice, Scalar> {
}
};

// Instantiate implementations for the 4 numeric types.
template struct EyeFunctor<GPUDevice, float>;
template struct EyeFunctor<GPUDevice, double>;
template struct EyeFunctor<GPUDevice, std::complex<float>>;
@@ -93,7 +93,7 @@ class MatrixBandPartOp : public OpKernel {
auto output_reshaped = output->flat_inner_dims<T, 3>();
functor::MatrixBandPartFunctor<Device, T> fn;
fn(context, context->eigen_device<Device>(), num_lower, num_upper,
false /* transpose */, input_reshaped, output_reshaped);
input_reshaped, output_reshaped);
}

private:

@@ -126,7 +126,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
void operator()(OpKernelContext* context, const CPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
const int64 b = input.dimension(0);
@@ -137,72 +137,46 @@ struct MatrixBandPartFunctor<CPUDevice, Scalar> {
const int64 total_rows = b * m;
const int64 row_cost = 10 * n;
const bool in_place = input.data() == output.data();
CHECK(!(transpose && in_place));
if (!transpose) {
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
if (!in_place) {
std::fill(output.data() + begin * n, output.data() + end * n,
Scalar());
}
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0
? 0
: std::min(n, std::max(0ll, row - num_lower_diags));
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
if (in_place) {
if (band_start > 0) {
std::fill(&output(batch, row, 0),
&output(batch, row, band_start), Scalar());
}
if (band_end < n) {
std::fill(&output(batch, row, band_end), &output(batch, row, n),
Scalar());
}
} else {
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
}
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
if (!in_place) {
std::fill(output.data() + begin * n, output.data() + end * n, Scalar());
}
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0
? 0
: std::min(n, std::max(0ll, row - num_lower_diags));
const int64 band_end =
num_upper_diags < 0
? n
: std::min(static_cast<int64>(n), row + num_upper_diags + 1);
if (in_place) {
if (band_start > 0) {
std::fill(&output(batch, row, 0), &output(batch, row, band_start),
Scalar());
}
if (band_end < n) {
std::fill(&output(batch, row, band_end), &output(batch, row, n),
Scalar());
}
} else {
if (band_start < band_end) {
const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
band_start);
const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
1, 1, band_end - band_start);
output.slice(indices, sizes) = input.slice(indices, sizes);
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
} else {
output.device(device) = output.constant(Scalar());
auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
const int64 batch_begin = begin / m;
const int64 batch_end = (end + m - 1) / m;
for (int64 batch = batch_begin; batch < batch_end; ++batch) {
const int64 row_begin = begin > batch * m ? begin % m : 0;
const int64 row_end = end < (batch + 1) * m ? end % m : m;
for (int64 row = row_begin; row < row_end; ++row) {
const int64 band_start =
num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags);
const int64 band_end = num_upper_diags < 0
? n
: std::min(static_cast<int64>(n),
row + num_upper_diags + 1);
for (int64 col = band_start; col < band_end; ++col) {
output(batch, col, row) = input(batch, row, col);
}
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
}
};
thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
}
};
@@ -216,14 +190,14 @@ TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, bool transpose, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
#define DECLARE_GPU_SPEC(T) \
template <> \
struct MatrixBandPartFunctor<GPUDevice, T> { \
void operator()(OpKernelContext* context, const GPUDevice& device, \
int num_upper_diags, int num_lower_diags, \
typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
}; \
extern template struct MatrixBandPartFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@@ -26,7 +26,7 @@ namespace functor {
template <typename Device, typename Scalar>
struct MatrixBandPartFunctor {
void operator()(OpKernelContext* context, const Device& device,
int num_upper_diags, int num_lower_diags, bool transpose,
int num_upper_diags, int num_lower_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output);
};
@@ -28,41 +28,22 @@ namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;

template <bool transpose, typename Scalar>
template <typename Scalar>
__global__ void MatrixBandPartKernel(const int num_threads,
const int batch_size, const int m,
const int n, const int num_lower_diags,
const int num_upper_diags,
const Scalar* input_ptr,
Scalar* output_ptr) {
if (!transpose) {
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[index] = Scalar();
} else {
output_ptr[index] = input_ptr[index];
}
}
} else {
const int matrix_size = m * n;
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int batch = index / matrix_size;
const int transpose_index = batch * matrix_size + n * col + row;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end =
(num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[transpose_index] = Scalar();
} else {
output_ptr[transpose_index] = input_ptr[index];
}
CUDA_1D_KERNEL_LOOP(index, num_threads) {
const int col = index % n;
const int row = (index / n) % m;
const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags);
const int band_end = (num_upper_diags < 0 ? n : row + num_upper_diags + 1);
if (col < band_start || col >= band_end) {
output_ptr[index] = Scalar();
} else {
output_ptr[index] = input_ptr[index];
}
}
}
@@ -70,7 +51,7 @@ __global__ void MatrixBandPartKernel(const int num_threads,
template <typename Scalar>
struct MatrixBandPartFunctor<GPUDevice, Scalar> {
void operator()(OpKernelContext* context, const GPUDevice& device,
int num_lower_diags, int num_upper_diags, bool transpose,
int num_lower_diags, int num_upper_diags,
typename TTypes<Scalar, 3>::ConstTensor input,
typename TTypes<Scalar, 3>::Tensor output) {
using CudaType = typename CUDAComplexT<Scalar>::type;

@@ -80,17 +61,10 @@ struct MatrixBandPartFunctor<GPUDevice, Scalar> {
const CudaType* input_ptr = reinterpret_cast<const CudaType*>(input.data());
CudaType* output_ptr = reinterpret_cast<CudaType*>(output.data());
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device);
if (transpose) {
MatrixBandPartKernel<true>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
} else {
MatrixBandPartKernel<false>
<<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
}
MatrixBandPartKernel<<<config.block_count, config.thread_per_block, 0,
device.stream()>>>(
config.virtual_thread_count, batch_size, m, n, num_lower_diags,
num_upper_diags, input_ptr, output_ptr);
}
};
@@ -19,4 +19,8 @@ namespace tensorflow {

REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);

#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
#endif

} // namespace tensorflow

@@ -19,4 +19,8 @@ namespace tensorflow {

REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);

#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
#endif

} // namespace tensorflow

@@ -19,4 +19,8 @@ namespace tensorflow {

REGISTER_LINALG_OP("Qr", (QrOp<double>), double);

#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
#endif

} // namespace tensorflow

@@ -19,4 +19,8 @@ namespace tensorflow {

REGISTER_LINALG_OP("Qr", (QrOp<float>), float);

#if GOOGLE_CUDA
REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
#endif

} // namespace tensorflow
@@ -19,10 +19,16 @@ limitations under the License.
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
#include <algorithm>
#include <numeric>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"

@@ -30,6 +36,13 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

template <class Scalar>
@@ -107,4 +120,189 @@ class QrOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class QrOpGpu : public AsyncOpKernel {
public:
explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
}

void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
const Tensor& input = context->input(0);
const int ndims = input.dims();
const int64 m = input.dim_size(ndims - 2);
const int64 n = input.dim_size(ndims - 1);
const int64 min_size = std::min(m, n);
const int64 batch_size =
input.template flat_inner_dims<Scalar, 3>().dimension(0);

// Validate inputs.
OP_REQUIRES_ASYNC(
context, ndims >= 2,
errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
done);

// Allocate output.
// If full_matrices_ is true then Q is m x m and R is m x n.
// Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
Tensor* q;
TensorShape q_shape = input.shape();
q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
done);
Tensor* r;
TensorShape r_shape = input.shape();
r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
done);

if (input.NumElements() == 0) {
done();
return;
}

// Allocate temporaries.
Tensor input_transposed;
TensorShape transposed_shape = input.shape();
transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<Scalar>::value, transposed_shape,
&input_transposed),
done);

Tensor tau;
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<Scalar>::value,
TensorShape({batch_size, min_size}), &tau),
done);

// Transpose input, since cuSolver uses column-major, while TensorFlow uses
// row-major storage.
std::vector<int> perm(ndims);
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[ndims - 2], perm[ndims - 1]);
const GPUDevice& device = context->eigen_device<GPUDevice>();
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input, perm, &input_transposed), done);

// Compute QR decomposition in-place in input_transposed.
CudaSolver solver(context);
std::vector<DeviceLapackInfo> dev_info;
dev_info.emplace_back(context, batch_size, "geqrf");
auto input_transposed_reshaped =
input_transposed.flat_inner_dims<Scalar, 3>();
auto tau_matrix = tau.matrix<Scalar>();
auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0),
dev_info.back().mutable_data() + batch),
done);
}
// Generate R. R is equal to the upper triangle of the decomposition
// stored in input_transposed. Crop, transpose (to get back to row-major)
// and copy it to the output buffer.
if (full_matrices_ || m == n) {
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input_transposed, perm, r), done);
} else {
const Scalar alpha(1);
const Scalar beta(0);
const Scalar* dummy = nullptr;
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Geam(CUBLAS_OP_T, CUBLAS_OP_N, n,
full_matrices_ ? m : min_size, &alpha,
&input_transposed_reshaped(batch, 0, 0), m, &beta,
dummy, n, &r_reshaped(batch, 0, 0), n),
done);
}
}
// Extract the upper triangle of r (i.e. zero out the strictly lower
// triangle).
functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
auto r_reshaped_const =
const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
band_part(context, device, 0 /* num_lower_diags */,
-1 /* num_upper_diags */, r_reshaped_const, r_reshaped);

// Generate Q from the decomposition in input_transposed.
if (m != n && (full_matrices_ || m < n)) {
context->CtxFailure(
errors::Unimplemented("The case m != n && (full_matrices_ || m < "
"n) is not currently supported on GPU."));
done();
return;

/* TODO(rmlarsen): FIXME. This branch fails with info != 0 (both
positive and negative) error statuses from ORMQR.

// Generate full m x m matrix Q by computing the product Q^T * I
// (transpose to get back to row-major form).
functor::EyeFunctor<GPUDevice, Scalar> eye;
auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
eye(device, q_reshaped);
dev_info.emplace_back(context, batch_size, "ormqr");
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Ormqr(CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, m, min_size,
&input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
dev_info.back().mutable_data() + batch),
done);
}
*/
} else {
// Generate m x n matrix Q. In this case we can use the more efficient
// algorithm in Orgqr to generate Q in place.
dev_info.emplace_back(context, batch_size, "orgqr");
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver.Orgqr(
m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
&tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
done);
}
OP_REQUIRES_OK_ASYNC(
context, DoTranspose(device, input_transposed, perm, q), done);
}

// Asynchronously check return status from cuSolver kernels.
TensorReference input_transposed_ref(input_transposed);
TensorReference tau_ref(tau);
auto info_checker = [context, dev_info, input_transposed_ref, tau_ref,
done](const Status& status,
const std::vector<HostLapackInfo>& host_infos) {
input_transposed_ref.Unref();
tau_ref.Unref();
OP_REQUIRES_OK_ASYNC(context, status, done);
done();
};
OP_REQUIRES_OK_ASYNC(
context,
solver.CopyLapackInfoToHostAsync(dev_info, std::move(info_checker)),
done);
}

private:
bool full_matrices_;

TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};

#endif

} // namespace tensorflow
@@ -27,6 +27,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def _AddTest(test_class, op_name, testcase_name, fn):
test_name = "_".join(["test", op_name, testcase_name])
if hasattr(test_class, test_name):
raise RuntimeError("Test %s defined more than once" % test_name)
setattr(test_class, test_name, fn)


class QrOpTest(test.TestCase):

def testWrongDimensions(self):

@@ -41,7 +48,7 @@ class QrOpTest(test.TestCase):
linalg_ops.qr(vector)


def _GetQrOpTest(dtype_, shape_, use_static_shape_):
def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):

is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
@@ -95,36 +102,41 @@ def _GetQrOpTest(dtype_, shape_, use_static_shape_):
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)

for full_matrices in False, True:
with self.test_session() as sess:
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
# TODO(rmlarsen): Debug failure due to invalid parameter to ORMQR.
rows_ = shape_[-2]
cols_ = shape_[-1]
use_gpu = False if rows_ < cols_ or (full_matrices_ and
rows_ != cols_) else True

if use_static_shape_:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
with self.test_session(use_gpu=use_gpu) as sess:
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)

q_dims = q_tf_val.shape
np_q = np.ndarray(q_dims, dtype_)
np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
new_first_dim = np_q_reshape.shape[0]
if use_static_shape_:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
for i in range(new_first_dim):
if full_matrices:
np_q_reshape[i,:,:], _ = \
q_dims = q_tf_val.shape
np_q = np.ndarray(q_dims, dtype_)
np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
new_first_dim = np_q_reshape.shape[0]

x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
for i in range(new_first_dim):
if full_matrices_:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="complete")
else:
np_q_reshape[i,:,:], _ = \
else:
np_q_reshape[i,:,:], _ = \
np.linalg.qr(x_reshape[i,:,:], mode="reduced")
np_q = np.reshape(np_q_reshape, q_dims)
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
CheckUnitary(self, q_tf_val)
np_q = np.reshape(np_q_reshape, q_dims)
CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
CheckApproximation(self, x_np, q_tf_val, r_tf_val)
CheckUnitary(self, q_tf_val)

return Test
@@ -133,11 +145,15 @@ if __name__ == "__main__":
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 1, 2, 5, 10, 32, 100:
for cols in 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
for use_static_shape in True, False:
name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)),
use_static_shape)
setattr(QrOpTest, "testQr_" + name,
_GetQrOpTest(dtype, shape, use_static_shape))
for full_matrices in False, True:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
for use_static_shape in True, False:
shape = batch_dims + (rows, cols)
name = "%s_%s_full_%s_static_%s" % (dtype.__name__,
"_".join(map(str, shape)),
full_matrices,
use_static_shape)
_AddTest(QrOpTest, "Qr", name,
_GetQrOpTest(dtype, shape, full_matrices,
use_static_shape))
test.main()