From 34018f8fa7290650291bbd478534e58c128a5df4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Sep 2017 09:09:05 -0700 Subject: [PATCH] Add GPU support for QR decomposition. Remove support for on-the-fly transpose in the internal matrix_band_part functor recently added (in anticipation of using it for QR), since it turned out not to be useful. PiperOrigin-RevId: 169249336 --- tensorflow/core/kernels/BUILD | 6 +- tensorflow/core/kernels/cholesky_op.cc | 23 +- tensorflow/core/kernels/cuda_solvers.cc | 152 +++++++++++++- tensorflow/core/kernels/cuda_solvers.h | 99 ++++----- .../core/kernels/cuda_solvers_gpu.cu.cc | 5 +- .../core/kernels/matrix_band_part_op.cc | 120 +++++------ tensorflow/core/kernels/matrix_band_part_op.h | 2 +- .../kernels/matrix_band_part_op_gpu.cu.cc | 56 ++--- tensorflow/core/kernels/qr_op_complex128.cc | 4 + tensorflow/core/kernels/qr_op_complex64.cc | 4 + tensorflow/core/kernels/qr_op_double.cc | 4 + tensorflow/core/kernels/qr_op_float.cc | 4 + tensorflow/core/kernels/qr_op_impl.h | 198 ++++++++++++++++++ tensorflow/python/kernel_tests/qr_op_test.py | 82 +++++--- 14 files changed, 546 insertions(+), 213 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index cff6e30c04d..dcbbe5335d2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2334,7 +2334,11 @@ tf_kernel_library( tf_kernel_library( name = "qr_op", prefix = "qr_op", - deps = LINALG_DEPS, + deps = LINALG_DEPS + if_cuda([ + ":cuda_solvers", + ":transpose_functor", + ":matrix_band_part_op", + ]), ) tf_kernel_library( diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index 6668b0d654f..3adff530f73 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -76,14 +76,14 @@ class CholeskyOp : public LinearAlgebraOp { typedef Eigen::GpuDevice GPUDevice; namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - struct MatrixBandPartFunctor { \ - void operator()(OpKernelContext* context, const GPUDevice& device, \ - int num_upper_diags, int num_lower_diags, bool transpose, \ - typename TTypes::ConstTensor input, \ - typename TTypes::Tensor output); \ - }; \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + struct MatrixBandPartFunctor { \ + void operator()(OpKernelContext* context, const GPUDevice& device, \ + int num_upper_diags, int num_lower_diags, \ + typename TTypes::ConstTensor input, \ + typename TTypes::Tensor output); \ + }; \ extern template struct MatrixBandPartFunctor; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); @@ -132,9 +132,10 @@ class CholeskyOpGpu : public AsyncOpKernel { // before we launch each of the Cholesky factorization kernels in parallel. auto input_reshaped = input.template flat_inner_dims(); auto output_reshaped = output->template flat_inner_dims(); - functor::MatrixBandPartFunctor fn; - fn(context, context->eigen_device(), n, 0, false /* transpose */, - input_reshaped, output_reshaped); + functor::MatrixBandPartFunctor band_part; + band_part(context, context->eigen_device(), + n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped, + output_reshaped); // Launch a Cholesky kernel for each matrix in the batch.
const int64 batch_size = input_reshaped.dimension(0); diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index 43197d8cf41..85f1473c6c3 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -174,7 +174,7 @@ Status CudaSolver::CopyLapackInfoToHostAsync( } info_checker_callback(status, host_lapack_infos); }; - + auto cb = std::bind(wrapped_info_checker_callback, context_, std::move(info_checker_callback), std::move(host_lapack_infos)); @@ -363,6 +363,156 @@ static inline Status GesvdImpl(BufSizeFnT bufsize, SolverFnT solver, TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE); +template +static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver, + OpKernelContext* context, + cusolverDnHandle_t cusolver_dn_handle, int m, + int n, Scalar* A, int lda, Scalar* tau, + int* dev_lapack_info) { + /* Get amount of workspace memory required. */ + int lwork; + TF_RETURN_IF_CUSOLVER_ERROR( + bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork)); + /* Allocate device memory for workspace. */ + ScratchSpace dev_workspace(context, lwork, /* on_host */ false); + /* Launch the solver kernel. */ + TF_RETURN_IF_CUSOLVER_ERROR(solver( + cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau), + CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info)); + return Status::OK(); +} + +#define GEQRF_INSTANCE(Scalar, lapack_prefix) \ + template <> \ + Status CudaSolver::Geqrf(int m, int n, Scalar* A, int lda, \ + Scalar* tau, int* dev_lapack_info) const { \ + return GeqrfImpl(DN_BUFSIZE_FN(geqrf, lapack_prefix), \ + DN_SOLVER_FN(geqrf, lapack_prefix), context_, \ + cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \ + } + +TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE); + +template +static inline Status OrmqrImpl(BufSizeFnT bufsize, SolverFnT solver, + OpKernelContext* context, + cusolverDnHandle_t cusolver_dn_handle, + cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const Scalar* dev_a, + int lda, const Scalar* dev_tau, Scalar* dev_c, + int ldc, int* dev_lapack_info) { + /* Get amount of workspace memory required. */ + int lwork; + TF_RETURN_IF_CUSOLVER_ERROR( + bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda, + CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork)); + /* Allocate device memory for workspace. */ + ScratchSpace dev_workspace(context, lwork, /* on_host */ false); + /* Launch the solver kernel. */ + TF_RETURN_IF_CUSOLVER_ERROR(solver( + cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda, + CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, + CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info)); + return Status::OK(); +} + +// Unfortunately the LAPACK function name differs for the real and complex case +// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each +// one separately. 
+template <> +Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, + int n, int k, const float* dev_a, int lda, + const float* dev_tau, float* dev_c, int ldc, + int* dev_lapack_info) const { + return OrmqrImpl(DN_BUFSIZE_FN(ormqr, S), DN_SOLVER_FN(ormqr, S), context_, + cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda, + dev_tau, dev_c, ldc, dev_lapack_info); +} +template <> +Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, + int n, int k, const double* dev_a, int lda, + const double* dev_tau, double* dev_c, int ldc, + int* dev_lapack_info) const { + return OrmqrImpl(DN_BUFSIZE_FN(ormqr, D), DN_SOLVER_FN(ormqr, D), context_, + cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda, + dev_tau, dev_c, ldc, dev_lapack_info); +} +template <> +Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, + int n, int k, const std::complex* dev_a, + int lda, const std::complex* dev_tau, + std::complex* dev_c, int ldc, + int* dev_lapack_info) const { + return OrmqrImpl(DN_BUFSIZE_FN(unmqr, C), DN_SOLVER_FN(unmqr, C), context_, + cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda, + dev_tau, dev_c, ldc, dev_lapack_info); +} +template <> +Status CudaSolver::Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, + int n, int k, const std::complex* dev_a, + int lda, const std::complex* dev_tau, + std::complex* dev_c, int ldc, + int* dev_lapack_info) const { + return OrmqrImpl(DN_BUFSIZE_FN(unmqr, Z), DN_SOLVER_FN(unmqr, Z), context_, + cusolver_dn_handle_, side, trans, m, n, k, dev_a, lda, + dev_tau, dev_c, ldc, dev_lapack_info); +} + +template +static inline Status OrgqrImpl(BufSizeFnT bufsize, SolverFnT solver, + OpKernelContext* context, + cusolverDnHandle_t cusolver_dn_handle, int m, + int n, int k, Scalar* dev_a, int lda, + const Scalar* dev_tau, int* dev_lapack_info) { + /* Get amount of workspace memory required. */ + int lwork; + TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k, + CUDAComplex(dev_a), lda, + CUDAComplex(dev_tau), &lwork)); + /* Allocate device memory for workspace. */ + ScratchSpace dev_workspace(context, lwork, /* on_host */ false); + /* Launch the solver kernel. */ + TF_RETURN_IF_CUSOLVER_ERROR( + solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda, + CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()), + lwork, dev_lapack_info)); + return Status::OK(); +} + +// Unfortunately the LAPACK function name differs for the real and complex case +// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each +// one separately. 
+template <> +Status CudaSolver::Orgqr(int m, int n, int k, float* dev_a, int lda, + const float* dev_tau, int* dev_lapack_info) const { + return OrgqrImpl(DN_BUFSIZE_FN(orgqr, S), DN_SOLVER_FN(orgqr, S), context_, + cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau, + dev_lapack_info); +} +template <> +Status CudaSolver::Orgqr(int m, int n, int k, double* dev_a, int lda, + const double* dev_tau, int* dev_lapack_info) const { + return OrgqrImpl(DN_BUFSIZE_FN(orgqr, D), DN_SOLVER_FN(orgqr, D), context_, + cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau, + dev_lapack_info); +} +template <> +Status CudaSolver::Orgqr(int m, int n, int k, std::complex* dev_a, + int lda, const std::complex* dev_tau, + int* dev_lapack_info) const { + return OrgqrImpl(DN_BUFSIZE_FN(ungqr, C), DN_SOLVER_FN(ungqr, C), context_, + cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau, + dev_lapack_info); +} +template <> +Status CudaSolver::Orgqr(int m, int n, int k, std::complex* dev_a, + int lda, const std::complex* dev_tau, + int* dev_lapack_info) const { + return OrgqrImpl(DN_BUFSIZE_FN(ungqr, Z), DN_SOLVER_FN(ungqr, Z), context_, + cusolver_dn_handle_, m, n, k, dev_a, lda, dev_tau, + dev_lapack_info); +} + //============================================================================= // Wrappers of cuBlas computational methods begin here. // diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index 7cbdc895dde..38873a0decf 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -147,7 +147,7 @@ class CudaSolver { Status CopyLapackInfoToHostAsync( const std::vector& dev_lapack_info, std::function&)> - info_checker_callback) const; + info_checker_callback) const TF_MUST_USE_RESULT; // ==================================================================== // Wrappers for cuSolverDN and cuBlas solvers start here. @@ -166,28 +166,29 @@ class CudaSolver { const Scalar* alpha, /* host or device pointer */ const Scalar* A, int lda, const Scalar* beta, /* host or device pointer */ - const Scalar* B, int ldb, Scalar* C, int ldc) const; + const Scalar* B, int ldb, Scalar* C, + int ldc) const TF_MUST_USE_RESULT; // Computes the Cholesky factorization A = L * L^T for a single matrix. // Returns Status::OK() if the kernel was launched successfully. See: // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf template Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda, - int* dev_lapack_info) const; + int* dev_lapack_info) const TF_MUST_USE_RESULT; // LU factorization. // Computes LU factorization with partial pivoting P * A = L * U. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf template Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots, - int* dev_lapack_info) const; + int* dev_lapack_info) const TF_MUST_USE_RESULT; // Uses LU factorization to solve A * X = B. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs template Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, const int* pivots, Scalar* B, int ldb, - int* dev_lapack_info) const; + int* dev_lapack_info) const TF_MUST_USE_RESULT; // Computes partially pivoted LU factorizations for a batch of small matrices. 
// Returns Status::OK() if the kernel was launched successfully.See: @@ -195,7 +196,7 @@ class CudaSolver { template Status GetrfBatched(int n, const Scalar* host_a_dev_ptrs[], int lda, int* dev_pivots, DeviceLapackInfo* dev_lapack_info, - int batch_size) const; + int batch_size) const TF_MUST_USE_RESULT; // Batched linear solver using LU factorization from getrfBatched. // See: @@ -204,7 +205,8 @@ class CudaSolver { Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, const Scalar* dev_Aarray[], int lda, const int* devIpiv, const Scalar* dev_Barray[], int ldb, - DeviceLapackInfo* dev_lapack_info, int batch_size) const; + DeviceLapackInfo* dev_lapack_info, + int batch_size) const TF_MUST_USE_RESULT; // Computes matrix inverses for a batch of small matrices. Uses the outputs // from GetrfBatched. Returns Status::OK() if the kernel was launched @@ -214,7 +216,8 @@ class CudaSolver { Status GetriBatched(int n, const Scalar* host_a_dev_ptrs[], int lda, const int* dev_pivots, const Scalar* host_a_inverse_dev_ptrs[], int ldainv, - DeviceLapackInfo* dev_lapack_info, int batch_size) const; + DeviceLapackInfo* dev_lapack_info, + int batch_size) const TF_MUST_USE_RESULT; // Computes matrix inverses for a batch of small matrices with size n < 32. // Returns Status::OK() if the kernel was launched successfully. See: @@ -222,59 +225,58 @@ class CudaSolver { template Status MatInvBatched(int n, const Scalar* host_a_dev_ptrs[], int lda, const Scalar* host_a_inverse_dev_ptrs[], int ldainv, - DeviceLapackInfo* dev_lapack_info, int batch_size) const; - - /* - TODO(rmlarsen, volunteers): Implement the kernels below. - // Uses Cholesky factorization to solve A * X = B. - // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs - template - Status Potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* dev_A, int - lda, Scalar* dev_B, int ldb, int* dev_lapack_info) const; + DeviceLapackInfo* dev_lapack_info, + int batch_size) const TF_MUST_USE_RESULT; // QR factorization. // Computes QR factorization A = Q * R. + // Returns Status::OK() if the kernel was launched successfully. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf template - Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_TAU, int* - devInfo) const; + Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau, + int* dev_lapack_info) const TF_MUST_USE_RESULT; - // Multiplies by Q. + // Overwrite matrix C by product of C and Householder matrix Q. The + // Householder matrix Q is represented by the output from Geqrf in dev_a and + // dev_tau. + // Returns Status::OK() if the kernel was launched successfully. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr template - Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int - k, const Scalar* dev_a, int lda, const Scalar* dev_tau, Scalar* dev_c, int - ldc, int* dev_lapack_info) const; + Status Ormqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, + int k, const Scalar* dev_a, int lda, const Scalar* dev_tau, + Scalar* dev_c, int ldc, + int* dev_lapack_info) const TF_MUST_USE_RESULT; - // Generate Q. + // Overwrites QR factorization produced by Geqrf by Householder matrix Q. + // On input, the Householder matrix Q is represented by the output from Geqrf + // in dev_a and dev_tau. On output, dev_a is overwritten with the first n + // columns of Q. + // Requires m >= n >= 0. + // Returns Status::OK() if the kernel was launched successfully. 
// See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr template - Status Orgqr(int m, int n, int k, Scalar* dev_A, int lda, const Scalar* - dev_tau, int* dev_lapack_info) const; + Status Orgqr(int m, int n, int k, Scalar* dev_a, int lda, + const Scalar* dev_tau, + int* dev_lapack_info) const TF_MUST_USE_RESULT; + + // Singular value decomposition. + // Returns Status::OK() if the kernel was launched successfully. + // TODO(rmlarsen, volunteers): Add support for complex types. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd + template + Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, + int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, + int ldvt, int* dev_lapack_info) const TF_MUST_USE_RESULT; + + /* + TODO(rmlarsen, volunteers): Implement the kernels below. // Symmetric/Hermitian Eigen decomposition. // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd template Status Syevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar* - dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const; - -*/ - // Singular value decomposition. - // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd - template - Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, - int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, - int ldvt, int* dev_lapack_info) const; - /* - // Batched linear solver using LU factorization from getrfBatched. - // See: - http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched - template - Status GetrsBatched(cublasOperation_t trans, int n, int nrhs, - const Scalar* dev_Aarray[], int lda, const int* devIpiv, - Scalar* dev_Barray[], int ldb, int* info, int batch_size) - const; - */ + dev_A, int lda, Scalar* dev_W, int* dev_lapack_info) const TF_MUST_USE_RESULT; + */ private: OpKernelContext* context_; // not owned. @@ -371,7 +373,7 @@ namespace functor { template struct AdjointBatchFunctor { // We assume that the tensor sizes are correct. - void operator()(const Device& d, + void operator()(const Device& device, typename TTypes::ConstTensor input, typename TTypes::Tensor output); }; @@ -380,7 +382,8 @@ struct AdjointBatchFunctor { // in a flattened batch. template struct DeterminantFromPivotedLUFunctor { - void operator()(const Device& d, typename TTypes::Tensor lu_factor, + void operator()(const Device& device, + typename TTypes::Tensor lu_factor, const int* pivots, typename TTypes::Tensor output, int* info); }; @@ -390,7 +393,7 @@ struct DeterminantFromPivotedLUFunctor { // op. template struct EyeFunctor { - void operator()(const Device& d, + void operator()(const Device& device, typename TTypes::Tensor matrix_batch); }; diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc index af6c094d7ac..bbbe1377b25 100644 --- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc +++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc @@ -190,7 +190,6 @@ struct DeterminantFromPivotedLUFunctor { } }; -// Instantiate implementations for the 4 numeric types. 
template struct DeterminantFromPivotedLUFunctor; template struct DeterminantFromPivotedLUFunctor; template struct DeterminantFromPivotedLUFunctor>; @@ -202,7 +201,6 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m, int n, Scalar* matrix_batch_ptr) { const int matrix_size = m * n; const Scalar one = Const::make_const(1.0); - const Scalar zero = Const::make_const(0.0); CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) { if (batch >= batch_size) { break; @@ -216,7 +214,7 @@ __global__ void EyeKernel(Cuda3DLaunchConfig config, int batch_size, int m, if (col >= n) { break; } - matrix_batch_ptr[row_start + col] = row == col ? one : zero; + matrix_batch_ptr[row_start + col] = row == col ? one : Scalar(); } } } @@ -239,7 +237,6 @@ struct EyeFunctor { } }; -// Instantiate implementations for the 4 numeric types. template struct EyeFunctor; template struct EyeFunctor; template struct EyeFunctor>; diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc index 8b8accc0b3c..e5f9086dbaf 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.cc +++ b/tensorflow/core/kernels/matrix_band_part_op.cc @@ -93,7 +93,7 @@ class MatrixBandPartOp : public OpKernel { auto output_reshaped = output->flat_inner_dims(); functor::MatrixBandPartFunctor fn; fn(context, context->eigen_device(), num_lower, num_upper, - false /* transpose */, input_reshaped, output_reshaped); + input_reshaped, output_reshaped); } private: @@ -126,7 +126,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; template struct MatrixBandPartFunctor { void operator()(OpKernelContext* context, const CPUDevice& device, - int num_lower_diags, int num_upper_diags, bool transpose, + int num_lower_diags, int num_upper_diags, typename TTypes::ConstTensor input, typename TTypes::Tensor output) { const int64 b = input.dimension(0); @@ -137,72 +137,46 @@ struct MatrixBandPartFunctor { const int64 total_rows = b * m; const int64 row_cost = 10 * n; const bool in_place = input.data() == output.data(); - CHECK(!(transpose && in_place)); - if (!transpose) { - auto compute_shard = [=, &input, &output](int64 begin, int64 end) { - if (!in_place) { - std::fill(output.data() + begin * n, output.data() + end * n, - Scalar()); - } - const int64 batch_begin = begin / m; - const int64 batch_end = (end + m - 1) / m; - for (int64 batch = batch_begin; batch < batch_end; ++batch) { - const int64 row_begin = begin > batch * m ? begin % m : 0; - const int64 row_end = end < (batch + 1) * m ? end % m : m; - for (int64 row = row_begin; row < row_end; ++row) { - const int64 band_start = - num_lower_diags < 0 - ? 0 - : std::min(n, std::max(0ll, row - num_lower_diags)); - const int64 band_end = num_upper_diags < 0 - ? 
n - : std::min(static_cast(n), - row + num_upper_diags + 1); - if (in_place) { - if (band_start > 0) { - std::fill(&output(batch, row, 0), - &output(batch, row, band_start), Scalar()); - } - if (band_end < n) { - std::fill(&output(batch, row, band_end), &output(batch, row, n), - Scalar()); - } - } else { - if (band_start < band_end) { - const Eigen::DSizes indices(batch, row, - band_start); - const Eigen::DSizes sizes( - 1, 1, band_end - band_start); - output.slice(indices, sizes) = input.slice(indices, sizes); - } + auto compute_shard = [=, &input, &output](int64 begin, int64 end) { + if (!in_place) { + std::fill(output.data() + begin * n, output.data() + end * n, Scalar()); + } + const int64 batch_begin = begin / m; + const int64 batch_end = (end + m - 1) / m; + for (int64 batch = batch_begin; batch < batch_end; ++batch) { + const int64 row_begin = begin > batch * m ? begin % m : 0; + const int64 row_end = end < (batch + 1) * m ? end % m : m; + for (int64 row = row_begin; row < row_end; ++row) { + const int64 band_start = + num_lower_diags < 0 + ? 0 + : std::min(n, std::max(0ll, row - num_lower_diags)); + const int64 band_end = + num_upper_diags < 0 + ? n + : std::min(static_cast(n), row + num_upper_diags + 1); + if (in_place) { + if (band_start > 0) { + std::fill(&output(batch, row, 0), &output(batch, row, band_start), + Scalar()); + } + if (band_end < n) { + std::fill(&output(batch, row, band_end), &output(batch, row, n), + Scalar()); + } + } else { + if (band_start < band_end) { + const Eigen::DSizes indices(batch, row, + band_start); + const Eigen::DSizes sizes( + 1, 1, band_end - band_start); + output.slice(indices, sizes) = input.slice(indices, sizes); } } } - }; - thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard)); - } else { - output.device(device) = output.constant(Scalar()); - auto compute_shard = [=, &input, &output](int64 begin, int64 end) { - const int64 batch_begin = begin / m; - const int64 batch_end = (end + m - 1) / m; - for (int64 batch = batch_begin; batch < batch_end; ++batch) { - const int64 row_begin = begin > batch * m ? begin % m : 0; - const int64 row_end = end < (batch + 1) * m ? end % m : m; - for (int64 row = row_begin; row < row_end; ++row) { - const int64 band_start = - num_lower_diags < 0 ? 0 : std::max(0ll, row - num_lower_diags); - const int64 band_end = num_upper_diags < 0 - ? n - : std::min(static_cast(n), - row + num_upper_diags + 1); - for (int64 col = band_start; col < band_end; ++col) { - output(batch, col, row) = input(batch, row, col); - } - } - } - }; - thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard)); - } + } + }; + thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard)); } }; @@ -216,14 +190,14 @@ TF_CALL_POD_TYPES(DEFINE_CPU_SPEC); // Forward declarations of the functor specializations for GPU. 
namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - struct MatrixBandPartFunctor { \ - void operator()(OpKernelContext* context, const GPUDevice& device, \ - int num_upper_diags, int num_lower_diags, bool transpose, \ - typename TTypes::ConstTensor input, \ - typename TTypes::Tensor output); \ - }; \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + struct MatrixBandPartFunctor { \ + void operator()(OpKernelContext* context, const GPUDevice& device, \ + int num_upper_diags, int num_lower_diags, \ + typename TTypes::ConstTensor input, \ + typename TTypes::Tensor output); \ + }; \ extern template struct MatrixBandPartFunctor; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h index 43b6724dae2..97cc9507932 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.h +++ b/tensorflow/core/kernels/matrix_band_part_op.h @@ -26,7 +26,7 @@ namespace functor { template struct MatrixBandPartFunctor { void operator()(OpKernelContext* context, const Device& device, - int num_upper_diags, int num_lower_diags, bool transpose, + int num_upper_diags, int num_lower_diags, typename TTypes::ConstTensor input, typename TTypes::Tensor output); }; diff --git a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc index afebdacdca9..41b2f5c0efb 100644 --- a/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_band_part_op_gpu.cu.cc @@ -28,41 +28,22 @@ namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -template +template __global__ void MatrixBandPartKernel(const int num_threads, const int batch_size, const int m, const int n, const int num_lower_diags, const int num_upper_diags, const Scalar* input_ptr, Scalar* output_ptr) { - if (!transpose) { - CUDA_1D_KERNEL_LOOP(index, num_threads) { - const int col = index % n; - const int row = (index / n) % m; - const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags); - const int band_end = - (num_upper_diags < 0 ? n : row + num_upper_diags + 1); - if (col < band_start || col >= band_end) { - output_ptr[index] = Scalar(); - } else { - output_ptr[index] = input_ptr[index]; - } - } - } else { - const int matrix_size = m * n; - CUDA_1D_KERNEL_LOOP(index, num_threads) { - const int col = index % n; - const int row = (index / n) % m; - const int batch = index / matrix_size; - const int transpose_index = batch * matrix_size + n * col + row; - const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags); - const int band_end = - (num_upper_diags < 0 ? n : row + num_upper_diags + 1); - if (col < band_start || col >= band_end) { - output_ptr[transpose_index] = Scalar(); - } else { - output_ptr[transpose_index] = input_ptr[index]; - } + CUDA_1D_KERNEL_LOOP(index, num_threads) { + const int col = index % n; + const int row = (index / n) % m; + const int band_start = (num_lower_diags < 0 ? 0 : row - num_lower_diags); + const int band_end = (num_upper_diags < 0 ? 
n : row + num_upper_diags + 1); + if (col < band_start || col >= band_end) { + output_ptr[index] = Scalar(); + } else { + output_ptr[index] = input_ptr[index]; } } } @@ -70,7 +51,7 @@ __global__ void MatrixBandPartKernel(const int num_threads, template struct MatrixBandPartFunctor { void operator()(OpKernelContext* context, const GPUDevice& device, - int num_lower_diags, int num_upper_diags, bool transpose, + int num_lower_diags, int num_upper_diags, typename TTypes::ConstTensor input, typename TTypes::Tensor output) { using CudaType = typename CUDAComplexT::type; @@ -80,17 +61,10 @@ struct MatrixBandPartFunctor { const CudaType* input_ptr = reinterpret_cast(input.data()); CudaType* output_ptr = reinterpret_cast(output.data()); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * m * n, device); - if (transpose) { - MatrixBandPartKernel - <<>>( - config.virtual_thread_count, batch_size, m, n, num_lower_diags, - num_upper_diags, input_ptr, output_ptr); - } else { - MatrixBandPartKernel - <<>>( - config.virtual_thread_count, batch_size, m, n, num_lower_diags, - num_upper_diags, input_ptr, output_ptr); - } + MatrixBandPartKernel<<>>( + config.virtual_thread_count, batch_size, m, n, num_lower_diags, + num_upper_diags, input_ptr, output_ptr); } }; diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc index f22bdf0d219..c5b73139bb1 100644 --- a/tensorflow/core/kernels/qr_op_complex128.cc +++ b/tensorflow/core/kernels/qr_op_complex128.cc @@ -19,4 +19,8 @@ namespace tensorflow { REGISTER_LINALG_OP("Qr", (QrOp), complex128); +#if GOOGLE_CUDA +REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu), complex128); +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/qr_op_complex64.cc b/tensorflow/core/kernels/qr_op_complex64.cc index 2d99a856a38..4e14f2639c2 100644 --- a/tensorflow/core/kernels/qr_op_complex64.cc +++ b/tensorflow/core/kernels/qr_op_complex64.cc @@ -19,4 +19,8 @@ namespace tensorflow { REGISTER_LINALG_OP("Qr", (QrOp), complex64); +#if GOOGLE_CUDA +REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu), complex64); +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc index 3873d7fbcf8..51885eb3557 100644 --- a/tensorflow/core/kernels/qr_op_double.cc +++ b/tensorflow/core/kernels/qr_op_double.cc @@ -19,4 +19,8 @@ namespace tensorflow { REGISTER_LINALG_OP("Qr", (QrOp), double); +#if GOOGLE_CUDA +REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu), double); +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc index e23cd5a0d99..d0a1dd42048 100644 --- a/tensorflow/core/kernels/qr_op_float.cc +++ b/tensorflow/core/kernels/qr_op_float.cc @@ -19,4 +19,8 @@ namespace tensorflow { REGISTER_LINALG_OP("Qr", (QrOp), float); +#if GOOGLE_CUDA +REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu), float); +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h index 029ef834808..aea0c552de4 100644 --- a/tensorflow/core/kernels/qr_op_impl.h +++ b/tensorflow/core/kernels/qr_op_impl.h @@ -19,10 +19,16 @@ limitations under the License. // individual kernels. A separate file is used for each instantiated kernel to // improve compilation times. 
#include +#include + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif #include "third_party/eigen3/Eigen/QR" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/linalg_ops_common.h" #include "tensorflow/core/lib/core/errors.h" @@ -30,6 +36,13 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/kernels/matrix_band_part_op.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#endif + namespace tensorflow { template @@ -107,4 +120,189 @@ class QrOp : public LinearAlgebraOp { TF_DISALLOW_COPY_AND_ASSIGN(QrOp); }; +#if GOOGLE_CUDA + +typedef Eigen::GpuDevice GPUDevice; + +template +class QrOpGpu : public AsyncOpKernel { + public: + explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) final { + const Tensor& input = context->input(0); + const int ndims = input.dims(); + const int64 m = input.dim_size(ndims - 2); + const int64 n = input.dim_size(ndims - 1); + const int64 min_size = std::min(m, n); + const int64 batch_size = + input.template flat_inner_dims().dimension(0); + + // Validate inputs. + OP_REQUIRES_ASYNC( + context, ndims >= 2, + errors::InvalidArgument("Input must have rank >= 2, got ", ndims), + done); + + // Allocate output. + // If full_matrices_ is true then Q is m x m and R is m x n. + // Otherwise, Q is m x min(m, n), and R is min(m, n) x n. + Tensor* q; + TensorShape q_shape = input.shape(); + q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size); + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q), + done); + Tensor* r; + TensorShape r_shape = input.shape(); + r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size); + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r), + done); + + if (input.NumElements() == 0) { + done(); + return; + } + + // Allocate temporaries. + Tensor input_transposed; + TensorShape transposed_shape = input.shape(); + transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1)); + transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2)); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, transposed_shape, + &input_transposed), + done); + + Tensor tau; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({batch_size, min_size}), &tau), + done); + + // Transpose input, since cuSolver uses column-major, while TensorFlow uses + // row-major storage. + std::vector perm(ndims); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[ndims - 2], perm[ndims - 1]); + const GPUDevice& device = context->eigen_device(); + OP_REQUIRES_OK_ASYNC( + context, DoTranspose(device, input, perm, &input_transposed), done); + + // Compute QR decomposition in-place in input_transposed. 
+ CudaSolver solver(context); + std::vector dev_info; + dev_info.emplace_back(context, batch_size, "geqrf"); + auto input_transposed_reshaped = + input_transposed.flat_inner_dims(); + auto tau_matrix = tau.matrix(); + auto r_reshaped = r->flat_inner_dims(); + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver.Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), + dev_info.back().mutable_data() + batch), + done); + } + + // Generate R. R is equal to the upper triangle of the decomposition + // stored in input_transposed. Crop, transpose (to get back to row-major) + // and copy it to the output buffer. + if (full_matrices_ || m == n) { + OP_REQUIRES_OK_ASYNC( + context, DoTranspose(device, input_transposed, perm, r), done); + } else { + const Scalar alpha(1); + const Scalar beta(0); + const Scalar* dummy = nullptr; + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver.Geam(CUBLAS_OP_T, CUBLAS_OP_N, n, + full_matrices_ ? m : min_size, &alpha, + &input_transposed_reshaped(batch, 0, 0), m, &beta, + dummy, n, &r_reshaped(batch, 0, 0), n), + done); + } + } + // Extract the upper triangle of r (i.e. zero out the strictly lower + // triangle). + functor::MatrixBandPartFunctor band_part; + auto r_reshaped_const = + const_cast(r)->flat_inner_dims(); + band_part(context, device, 0 /* num_lower_diags */, + -1 /* num_upper_diags */, r_reshaped_const, r_reshaped); + + // Generate Q from the decomposition in input_transposed. + if (m != n && (full_matrices_ || m < n)) { + context->CtxFailure( + errors::Unimplemented("The case m != n && (full_matrices_ || m < " + "n) is not currently supported on GPU.")); + done(); + return; + + /* TODO(rmlarsen): FIXME. This branch fails with info != 0 (both + positive and negative) error statuses from ORMQR. + + // Generate full m x m matrix Q by computing the product Q^T * I + // (transpose to get back to row-major form). + functor::EyeFunctor eye; + auto q_reshaped = q->flat_inner_dims(); + eye(device, q_reshaped); + dev_info.emplace_back(context, batch_size, "ormqr"); + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver.Ormqr(CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, m, min_size, + &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m, + dev_info.back().mutable_data() + batch), + done); + } + */ + } else { + // Generate m x n matrix Q. In this case we can use the more efficient + // algorithm in Orgqr to generate Q in place. + dev_info.emplace_back(context, batch_size, "orgqr"); + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver.Orgqr( + m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch), + done); + } + OP_REQUIRES_OK_ASYNC( + context, DoTranspose(device, input_transposed, perm, q), done); + } + + // Asynchronously check return status from cuSolver kernels. 
+ TensorReference input_transposed_ref(input_transposed); + TensorReference tau_ref(tau); + auto info_checker = [context, dev_info, input_transposed_ref, tau_ref, + done](const Status& status, + const std::vector& host_infos) { + input_transposed_ref.Unref(); + tau_ref.Unref(); + OP_REQUIRES_OK_ASYNC(context, status, done); + done(); + }; + OP_REQUIRES_OK_ASYNC( + context, + solver.CopyLapackInfoToHostAsync(dev_info, std::move(info_checker)), + done); + } + + private: + bool full_matrices_; + + TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu); +}; + +#endif + } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index 7867e0e42d3..6b5becef60b 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -27,6 +27,13 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +def _AddTest(test_class, op_name, testcase_name, fn): + test_name = "_".join(["test", op_name, testcase_name]) + if hasattr(test_class, test_name): + raise RuntimeError("Test %s defined more than once" % test_name) + setattr(test_class, test_name, fn) + + class QrOpTest(test.TestCase): def testWrongDimensions(self): @@ -41,7 +48,7 @@ class QrOpTest(test.TestCase): linalg_ops.qr(vector) -def _GetQrOpTest(dtype_, shape_, use_static_shape_): +def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): is_complex = dtype_ in (np.complex64, np.complex128) is_single = dtype_ in (np.float32, np.complex64) @@ -95,36 +102,41 @@ def _GetQrOpTest(dtype_, shape_, use_static_shape_): low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) - for full_matrices in False, True: - with self.test_session() as sess: - if use_static_shape_: - x_tf = constant_op.constant(x_np) - else: - x_tf = array_ops.placeholder(dtype_) - q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) + # TODO(rmlarsen): Debug failure due to invalid parameter to ORMQR. 
+ rows_ = shape_[-2] + cols_ = shape_[-1] + use_gpu = False if rows_ < cols_ or (full_matrices_ and + rows_ != cols_) else True - if use_static_shape_: - q_tf_val, r_tf_val = sess.run([q_tf, r_tf]) - else: - q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) + with self.test_session(use_gpu=use_gpu) as sess: + if use_static_shape_: + x_tf = constant_op.constant(x_np) + else: + x_tf = array_ops.placeholder(dtype_) + q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_) - q_dims = q_tf_val.shape - np_q = np.ndarray(q_dims, dtype_) - np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) - new_first_dim = np_q_reshape.shape[0] + if use_static_shape_: + q_tf_val, r_tf_val = sess.run([q_tf, r_tf]) + else: + q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) - x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) - for i in range(new_first_dim): - if full_matrices: - np_q_reshape[i,:,:], _ = \ + q_dims = q_tf_val.shape + np_q = np.ndarray(q_dims, dtype_) + np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) + new_first_dim = np_q_reshape.shape[0] + + x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) + for i in range(new_first_dim): + if full_matrices_: + np_q_reshape[i,:,:], _ = \ np.linalg.qr(x_reshape[i,:,:], mode="complete") - else: - np_q_reshape[i,:,:], _ = \ + else: + np_q_reshape[i,:,:], _ = \ np.linalg.qr(x_reshape[i,:,:], mode="reduced") - np_q = np.reshape(np_q_reshape, q_dims) - CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) - CheckApproximation(self, x_np, q_tf_val, r_tf_val) - CheckUnitary(self, q_tf_val) + np_q = np.reshape(np_q_reshape, q_dims) + CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) + CheckApproximation(self, x_np, q_tf_val, r_tf_val) + CheckUnitary(self, q_tf_val) return Test @@ -133,11 +145,15 @@ if __name__ == "__main__": for dtype in np.float32, np.float64, np.complex64, np.complex128: for rows in 1, 2, 5, 10, 32, 100: for cols in 1, 2, 5, 10, 32, 100: - for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): - shape = batch_dims + (rows, cols) - for use_static_shape in True, False: - name = "%s_%s_%s" % (dtype.__name__, "_".join(map(str, shape)), - use_static_shape) - setattr(QrOpTest, "testQr_" + name, - _GetQrOpTest(dtype, shape, use_static_shape)) + for full_matrices in False, True: + for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): + for use_static_shape in True, False: + shape = batch_dims + (rows, cols) + name = "%s_%s_full_%s_static_%s" % (dtype.__name__, + "_".join(map(str, shape)), + full_matrices, + use_static_shape) + _AddTest(QrOpTest, "Qr", name, + _GetQrOpTest(dtype, shape, full_matrices, + use_static_shape)) test.main()
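The GPU kernel added in qr_op_impl.h follows the standard two-step LAPACK QR sequence: geqrf computes the factorization in packed form (Householder vectors below the diagonal plus the tau scalars), R is read off the upper triangle, and orgqr (ungqr for complex types) expands the packed form into an explicit Q. The sketch below mirrors that call pattern on the CPU through SciPy's LAPACK bindings; it is an illustration only, not part of the change, it assumes SciPy is installed, and the helper name is made up. Like cuSolver, LAPACK is column-major, which is why QrOpGpu transposes its input before the factorization and transposes the results back.

import numpy as np
from scipy.linalg import lapack

def qr_via_geqrf_orgqr(a):
    """Reduced QR of an m x n matrix (m >= n) via the geqrf/orgqr sequence."""
    a = np.asfortranarray(a, dtype=np.float64)  # LAPACK expects column-major
    m, n = a.shape
    assert m >= n, "this sketch only covers the reduced (m >= n) case"
    # Step 1: packed factorization. 'qr' holds R in its upper triangle and the
    # Householder vectors below the diagonal; 'tau' holds their scalar factors.
    qr, tau, _, info = lapack.dgeqrf(a)
    assert info == 0
    # Step 2: R is the upper triangle of the leading n rows. QrOpGpu extracts it
    # with MatrixBandPartFunctor (0 lower, unlimited upper diagonals) after
    # transposing back to row-major.
    r = np.triu(qr[:n, :])
    # Step 3: expand the packed form into the explicit m x n matrix Q in place,
    # as QrOpGpu does with Orgqr/Ungqr in the supported (reduced or square) case.
    q, _, info = lapack.dorgqr(qr, tau)
    assert info == 0
    return q, r

a = np.random.randn(5, 3)
q, r = qr_via_geqrf_orgqr(a)
print(np.allclose(q @ r, a), np.allclose(q.T @ q, np.eye(3)))  # True True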
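The matrix_band_part functor, which this change simplifies by removing the unused on-the-fly transpose, keeps num_lower_diags sub-diagonals and num_upper_diags super-diagonals of each matrix and zeroes everything else, with a negative count meaning "keep all". QrOpGpu relies on the (0 lower, -1 upper) setting to zero the strictly lower triangle of R. A small NumPy sketch of that indexing rule, matching the band_start/band_end computation in the CPU and GPU kernels above (the function name here is illustrative):

import numpy as np

def matrix_band_part(x, num_lower, num_upper):
    """Keep a band of each matrix in a batch; negative counts mean 'keep all'."""
    m, n = x.shape[-2], x.shape[-1]
    rows = np.arange(m)[:, None]
    cols = np.arange(n)[None, :]
    in_band = np.ones((m, n), dtype=bool)
    if num_lower >= 0:
        in_band &= cols >= rows - num_lower     # band_start = row - num_lower_diags
    if num_upper >= 0:
        in_band &= cols < rows + num_upper + 1  # band_end = row + num_upper_diags + 1
    return np.where(in_band, x, 0)

a = np.arange(16.0).reshape(4, 4)
print(matrix_band_part(a, 0, -1))  # upper triangle, as used to clean up R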
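Finally, a minimal end-to-end usage sketch of the new GPU path, written against the TF 1.x Session API used in the test above. It assumes a CUDA-enabled build that contains this change; the device string, shapes, and expected tolerance are illustrative, and allow_soft_placement lets the script fall back to the existing CPU kernel on a CPU-only build.

import numpy as np
import tensorflow as tf

# A square batch keeps both full_matrices settings on the supported GPU path
# (the test above skips the GPU when rows < cols, or full_matrices with rows != cols).
x_np = np.random.uniform(-1.0, 1.0, size=(3, 32, 32)).astype(np.float32)

with tf.device("/gpu:0"):
  x = tf.constant(x_np)
  q, r = tf.qr(x, full_matrices=False)  # dispatches to QrOpGpu when built with CUDA
  reconstruction_error = tf.reduce_max(tf.abs(tf.matmul(q, r) - x))

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
  print("max |Q*R - A| =", sess.run(reconstruction_error))  # roughly 1e-5 for float32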