From 72c023d3967a3218cd3d830ce6e57f7c4d87a18c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Apr 2017 16:17:44 -0800 Subject: [PATCH] Add support for solvers based on the Nvidia cuSolver library. Implement a GPU version of tf.cholesky as a first example. Change: 152756373 --- .../kernel_tests/operator_pd_full_test.py | 2 +- .../python/kernel_tests/wishart_test.py | 2 +- tensorflow/core/kernels/BUILD | 23 ++- tensorflow/core/kernels/cholesky_op.cc | 88 +++++++- tensorflow/core/kernels/cuda_solvers.cc | 190 ++++++++++++++++++ tensorflow/core/kernels/cuda_solvers.h | 178 ++++++++++++++++ tensorflow/core/kernels/linalg_ops_common.cc | 8 + tensorflow/core/kernels/linalg_ops_common.h | 45 ++++- .../python/kernel_tests/cholesky_op_test.py | 15 +- third_party/gpus/cuda/BUILD.tpl | 10 + third_party/gpus/cuda_configure.bzl | 7 + 11 files changed, 546 insertions(+), 22 deletions(-) create mode 100644 tensorflow/core/kernels/cuda_solvers.cc create mode 100644 tensorflow/core/kernels/cuda_solvers.h diff --git a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py index dd59c649e10..35a7c7e6039 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/operator_pd_full_test.py @@ -47,7 +47,7 @@ class OperatorPDFullTest(test.TestCase): operator = operator_pd_full.OperatorPDFull(matrix, verify_pd=True) # Could fail inside Cholesky decomposition, or later when we test the # diag. 
- with self.assertRaisesOpError("x > 0|LLT"): + with self.assertRaisesOpError("x > 0|Cholesky"): operator.to_dense().eval() def testNonSymmetricMatrixRaises(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 1fa6ca0906d..d9dc978f23d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -339,7 +339,7 @@ class WishartCholeskyTest(test.TestCase): chol_scale_deferred: chol_scale}) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "LLT decomposition was not successful"): + "Cholesky decomposition was not successful"): chol_w = distributions.WishartFull( df=df_deferred, scale=chol_scale_deferred) # np.ones((3, 3)) is not positive, definite. diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0778821cc84..72136c3ae99 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -41,6 +41,7 @@ load( "//third_party/mkl:build_defs.bzl", "if_mkl", ) +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") config_setting( # Add "--define tensorflow_xsmm=1" to your build command to use libxsmm for @@ -1972,6 +1973,23 @@ cc_library( ], ) +tf_kernel_library( + name = "cuda_solvers", + srcs = ["cuda_solvers.cc"], + hdrs = ["cuda_solvers.h"], + # @local_config_cuda//cuda:cusolver, //third_party/eigen3:blas, + # and //third_party/libf2c all contain various parts of BLAS, LAPACK, + # and f2c helper functions in global namespace. Tell the compiler to + # allow multiple definitions when linking this. 
+ linkopts = ["-Wl,-zmuldefs"], + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@local_config_cuda//cuda:cusolver", + ], +) + LINALG_DEPS = [ ":linalg_ops_common", "//third_party/eigen3", @@ -1983,7 +2001,10 @@ LINALG_DEPS = [ tf_kernel_library( name = "cholesky_op", prefix = "cholesky_op", - deps = LINALG_DEPS, + deps = if_cuda([ + ":cuda_solvers", + ":matrix_band_part_op", + ]) + LINALG_DEPS, ) tf_kernel_library( diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index e5bf164cfaa..2e33170d27b 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -16,31 +16,40 @@ limitations under the License. // See docs in ../ops/linalg_ops.cc. // TODO(konstantinos): Enable complex inputs. This will require additional tests // and OP_REQUIRES. +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA #include "third_party/eigen3/Eigen/Cholesky" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/linalg_ops_common.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/kernels/matrix_band_part_op.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif + namespace tensorflow { +static const char kErrMsg[] = + "Cholesky decomposition was not successful. 
The input might not be valid."; + template class CholeskyOp : public LinearAlgebraOp { public: - typedef LinearAlgebraOp Base; + INHERIT_LINALG_TYPEDEFS(Scalar); explicit CholeskyOp(OpKernelConstruction* context) : Base(context) {} - using Matrix = typename Base::Matrix; - using MatrixMaps = typename Base::MatrixMaps; - using ConstMatrixMap = typename Base::ConstMatrixMap; - using ConstMatrixMaps = typename Base::ConstMatrixMaps; - void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, MatrixMaps* outputs) final { const ConstMatrixMap& input = inputs[0]; @@ -60,11 +69,74 @@ class CholeskyOp : public LinearAlgebraOp { outputs->at(0) = llt_decomposition.matrixL(); OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success, - errors::InvalidArgument("LLT decomposition was not successful. " - "The input might not be valid.")); + errors::InvalidArgument(kErrMsg)); } }; +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void MatrixBandPart::Compute( \ + const GPUDevice& d, Eigen::DenseIndex num_lower, \ + Eigen::DenseIndex num_upper, typename TTypes::ConstTensor input, \ + typename TTypes::Tensor output); \ + extern template struct MatrixBandPart; + +TF_CALL_float(DECLARE_GPU_SPEC); +TF_CALL_double(DECLARE_GPU_SPEC); +} // namespace functor + +template +class CholeskyOpGpu : public LinearAlgebraOp { + public: + INHERIT_LINALG_TYPEDEFS(Scalar); + + explicit CholeskyOpGpu(OpKernelConstruction* context) : Base(context) {} + + // Copy the lower triangular part of the input matrices to the output and + // set the strictly upper triangular part to zero. We use a pre-existing + // kernel MatrixBandPart to do this for all matrices in the batch at once, + // before we launch each of the Cholesky factorization kernels in parallel. 
+ void BatchPreCompute(OpKernelContext* context, const TensorInputs& inputs, + const TensorShapes& input_matrix_shapes, + const TensorOutputs& outputs, + const TensorShapes& output_matrix_shapes) final { + const int n = input_matrix_shapes[0].dim_size(0); + auto input_reshaped = inputs[0]->template flat_inner_dims(); + auto output_reshaped = outputs[0]->template flat_inner_dims(); + functor::MatrixBandPart::Compute( + context->eigen_device(), n, 0, input_reshaped, + output_reshaped); + } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const ConstMatrixMap& input = inputs[0]; + const int n = input.rows(); + if (n == 0) { + // If X is an empty matrix (0 rows, 0 col), X * X' == X. + // Therefore, we return X. + return; + } + // Launch the Cholesky kernel. + CudaSolverDN cusolver(context); + const Status status = cusolver.potrf(CUBLAS_FILL_MODE_UPPER, n, + outputs->at(0).data(), n, nullptr); + if (!status.ok()) { + LOG(ERROR) << status.ToString(); + } + OP_REQUIRES(context, status.ok(), errors::InvalidArgument(kErrMsg)); + } +}; + +REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu), float); +REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu), double); + +#endif // GOOGLE_CUDA + REGISTER_LINALG_OP("Cholesky", (CholeskyOp), float); REGISTER_LINALG_OP("Cholesky", (CholeskyOp), double); REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp), float); diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc new file mode 100644 index 00000000000..3e59fb0a5ab --- /dev/null +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -0,0 +1,190 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +#ifdef GOOGLE_CUDA +#include "tensorflow/core/kernels/cuda_solvers.h" + +#include +#include + +#include "cuda/include/cublas_v2.h" +#include "cuda/include/cusolverDn.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +namespace { + +template +class ScratchSpace { + public: + explicit ScratchSpace(OpKernelContext* context, int size) { + TF_CHECK_OK(context->allocate_temp(DataTypeToEnum::value, + TensorShape({size}), &scratch_tensor_)); + } + Scalar* data() { return scratch_tensor_.template flat().data(); } + + private: + Tensor scratch_tensor_; +}; + +// Type traits to get CUDA complex types from std::complex<>. + +template +struct CUDAComplexT { + typedef T type; +}; + +template <> +struct CUDAComplexT> { + typedef cuComplex type; +}; + +template <> +struct CUDAComplexT> { + typedef cuDoubleComplex type; +}; + +// Converts pointers of std::complex<> to pointers of +// cuComplex/cuDoubleComplex. No type conversion for non-complex types. 
+ +template +inline const typename CUDAComplexT::type* CUDAComplex(const T* p) { + return reinterpret_cast::type*>(p); +} + +template +inline typename CUDAComplexT::type* CUDAComplex(T* p) { + return reinterpret_cast::type*>(p); +} + +// Converts values of std::complex to values of +// cuComplex/cuDoubleComplex. +inline cuComplex CUDAComplexValue(std::complex val) { + return {val.real(), val.imag()}; +} + +inline cuDoubleComplex CUDAComplexValue(std::complex val) { + return {val.real(), val.imag()}; +} +} // namespace + +#define TF_RETURN_IF_CUSOLVER_ERROR_MSG(expr, msg) \ + do { \ + auto status = (expr); \ + if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \ + return errors::Internal(msg); \ + } \ + } while (0) + +#define TF_RETURN_IF_CUSOLVER_ERROR(expr) \ + TF_RETURN_IF_CUSOLVER_ERROR_MSG(expr, "cuSolverDN call failed.") + +#define TF_RETURN_STATUS_FROM_INFO(method, device_info_ptr, info_ptr) \ + do { \ + int local_info; \ + TF_RETURN_IF_ERROR(GetInfo(device_info_ptr, &local_info)); \ + if (info_ptr != nullptr) *info_ptr = local_info; \ + if (TF_PREDICT_FALSE(local_info != 0)) { \ + return errors::Internal("cuSolverDN::" #method " returned info = ", \ + local_info, ", expected info = 0"); \ + } else { \ + return Status::OK(); \ + } \ + } while (0) + +CudaSolverDN::CudaSolverDN(OpKernelContext* context) : context_(context) { + const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + cuda_stream_ = *cu_stream_ptr; + CHECK(cusolverDnCreate(&handle_) == CUSOLVER_STATUS_SUCCESS) + << "Failed to create cuSolverDN instance."; + CHECK(cusolverDnSetStream(handle_, cuda_stream_) == CUSOLVER_STATUS_SUCCESS) + << "Failed to set cuSolverDN stream."; +} + +CudaSolverDN::~CudaSolverDN() { + CHECK(cusolverDnDestroy(handle_) == CUSOLVER_STATUS_SUCCESS) + << "Failed to destroy cuSolverDN instance."; +} + +Status CudaSolverDN::GetInfo(const int* dev_info, int* 
host_info) const { + CHECK(dev_info != nullptr); + CHECK(host_info != nullptr); + auto stream = context_->op_device_context()->stream(); + perftools::gputools::DeviceMemoryBase wrapped(const_cast(dev_info)); + if (!stream + ->ThenMemcpy(host_info /* destination */, wrapped /* source */, + sizeof(int)) + .ok()) { + return errors::Internal("Failed to copy dev_info to host."); + } + BlockingCounter barrier(1); + context_->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, [&barrier]() { barrier.DecrementCount(); }); + if (!barrier.WaitFor(std::chrono::minutes(1))) { + return errors::Internal("Failed to copy dev_info to host within 1 minute."); + } + return Status::OK(); +} + +// Macro that specializes a solver method for all 4 standard +// numeric types. +#define TF_CALL_LAPACK_TYPES(m) \ + m(float, S) m(double, D) m(std::complex, C) m(std::complex, Z) + +// Macros to construct cusolver method names. +#define SOLVER_NAME(method, lapack_prefix) cusolverDn##lapack_prefix##method +#define BUFSIZE_NAME(method, lapack_prefix) \ + cusolverDn##lapack_prefix##method##_bufferSize + +//============================================================================= +// Wrappers of cuSolverDN computational methods begin here. +//============================================================================= +#define POTRF_INSTANCE(Scalar, lapack_prefix) \ + template <> \ + Status CudaSolverDN::potrf(cublasFillMode_t uplo, int n, Scalar* A, \ + int lda, int* info) const { \ + /* Get amount of workspace memory required. */ \ + int lwork; \ + TF_RETURN_IF_CUSOLVER_ERROR(BUFSIZE_NAME(potrf, lapack_prefix)( \ + handle_, uplo, n, CUDAComplex(A), lda, &lwork)); \ + \ + /* Allocate device memory for workspace and info. */ \ + ScratchSpace device_workspace(context_, lwork); \ + ScratchSpace device_info(context_, 1); \ + \ + /* Launch the solver kernel. 
*/ \ + TF_RETURN_IF_CUSOLVER_ERROR(SOLVER_NAME(potrf, lapack_prefix)( \ + handle_, uplo, n, CUDAComplex(A), lda, \ + CUDAComplex(device_workspace.data()), lwork, device_info.data())); \ + \ + /* Get info from device and return status. */ \ + TF_RETURN_STATUS_FROM_INFO(potrf, device_info.data(), info); \ + return Status::OK(); \ + } + +TF_CALL_LAPACK_TYPES(POTRF_INSTANCE); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h new file mode 100644 index 00000000000..eeb179cfa66 --- /dev/null +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ + +// This header implements CudaSolverDN and CuBlas, which contain templatized +// wrappers of linear algebra solvers in the cuBlas and cuSolverDN libraries +// for use in TensorFlow kernels. + +#ifdef GOOGLE_CUDA + +#include "cuda/include/cublas_v2.h" +#include "cuda/include/cusolverDn.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A class that provides a simplified templated API for the solver methods +// in cuSolverDN (http://docs.nvidia.com/cuda/cusolver). 
+// An object of this class wraps a cuSolverDN instance, and will launch +// kernels on the cuda stream wrapped by the GPU device in the OpKernelContext +// provided to the constructor. The class methods transparently fetch the output +// status of the solvers (a.k.a. the LAPACK "info" output variable) without +// having to manually synchronize the underlying Cuda stream. +class CudaSolverDN { + public: + explicit CudaSolverDN(OpKernelContext* context); + virtual ~CudaSolverDN(); + + // ==================================================================== + // Templated wrappers for cuSolver functions start here. + + // Cholesky factorization. + // Computes Cholesky factorization A = L * L^T. + // Returns Status::OK(), if the Cholesky factorization was successful. + // If info is not nullptr it is used to return the potrf info code: + // Returns zero if success, returns -i if the + // i-th parameter is wrong, returns i > 0, if the leading minor of order i is + // not positive definite, see: + // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf + template + Status potrf(cublasFillMode_t uplo, int n, Scalar* A, int lda, + int* info) const; + + /* + TODO(rmlarsen, volunteers): Implement the kernels below. + // Uses Cholesky factorization to solve A * X = B. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs + template + Status potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* A, int lda, + Scalar* B, int ldb, int* info) const; + + // LU factorization. + // Computes LU factorization with partial pivoting P * A = L * U. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf + template + Status getrf(int m, int n, Scalar* A, int lda, int* devIpiv, + int* devInfo) const; + + // Uses LU factorization to solve A * X = B. 
+ // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs + template + Status getrs(int n, int nrhs, const Scalar* A, int lda, const int* devIpiv, + Scalar* B, int ldb, int* devInfo) const; + + // QR factorization. + // Computes QR factorization A = Q * R. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf + template + Status geqrf(int m, int n, Scalar* A, int lda, Scalar* TAU, int* devInfo) + const; + + // Multiplies by Q. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr + template + Status mqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n, int + k, const Scalar* A, int lda, const Scalar* tau, Scalar* C, int ldc, int* + devInfo const); + + // Materializes Q. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr + template + Status gqr(int m, int n, int k, Scalar* A, int lda, const Scalar* tau, + int* devInfo) const; + + // Symmetric/Hermitian Eigen decomposition. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd + template + Status evd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar* A, + int lda, Scalar* W, int* devInfo) const; + + // Singular value decomposition. + // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd + template + Status gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* A, + int lda, Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, + int* devInfo); +*/ + + private: + // Copies dev_info status back from the device to host and uses event manager + // to wait (with a timeout) until the copy has finished. Returns an error if + // the copy fails to complete successfully within the timeout period. + Status GetInfo(const int* dev_info, int* host_info) const; + + OpKernelContext* context_; // not owned. + cudaStream_t cuda_stream_; + cusolverDnHandle_t handle_; +}; + +/* + TODO(rmlarsen, volunteers): Implement the kernels below. These are utils and +batched solvers not currently wrapped by stream executor. 
class CudaBlas { + public: + // Initializes a cuBlas handle that will launch kernels using the + // cuda stream wrapped by the GPU device in context. + explicit CudaBlas(OpKernelContext* context); + virtual ~CudaBlas(); + + // Templatized wrappers for cuBlas functions. + + // Matrix addition, copy and transposition. + // See: http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam + template + Status geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n, + const Scalar* alpha, const Scalar* A, int lda, const Scalar* beta, + const Scalar* B, int ldb, Scalar* C, int ldc) const; + + // Batched LU factorization. + // See: +http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched + template + Status getrfBatched(int n, Scalar* Aarray[], int lda, int* PivotArray, + int* infoArray, int batchSize) const; + + // Batched linear solver using LU factorization from getrfBatched. + // See: +http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched + template + Status getrsBatched(cublasOperation_t trans, int n, int nrhs, + const Scalar* Aarray[], int lda, const int* devIpiv, + Scalar* Barray[], int ldb, int* info, int batchSize) const; + + // Batched matrix inverse. + // See: +http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched + template + Status getriBatched(cublasHandle_t handle, int n, Scalar* Aarray[], int lda, + int* PivotArray, Scalar* Carray[], int ldc, int* infoArray, + int batchSize); + + private: + // Copies dev_info status back from the device to host and uses event manager + // to wait (with a timeout) until the copy has finished. Returns an error if + // the copy fails to complete successfully within the timeout period. + Status GetInfo(const int* dev_info, int* host_info) const; + + OpKernelContext* context_; // not owned. 
+ cudaStream_t cuda_stream_; + cublasHandle_t handle_; +}; +*/ + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc index dc001857d7b..a38ddf76eb6 100644 --- a/tensorflow/core/kernels/linalg_ops_common.cc +++ b/tensorflow/core/kernels/linalg_ops_common.cc @@ -95,6 +95,10 @@ void LinearAlgebraOp::Compute(OpKernelContext* context) { PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs, &output_matrix_shapes); + // Perform batch-wide pre-computations, if any. + BatchPreCompute(context, inputs, input_matrix_shapes, outputs, + output_matrix_shapes); + // Process the individual matrix problems in parallel using a threadpool. auto shard = [this, &inputs, &input_matrix_shapes, &outputs, &output_matrix_shapes, context](int64 begin, int64 end) { @@ -106,6 +110,10 @@ void LinearAlgebraOp::Compute(OpKernelContext* context) { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard); + + // Perform batch-wide post-computations, if any. 
+ BatchPostCompute(context, inputs, input_matrix_shapes, outputs, + output_matrix_shapes); } template diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h index ab4142ac932..fcb735bfdb9 100644 --- a/tensorflow/core/kernels/linalg_ops_common.h +++ b/tensorflow/core/kernels/linalg_ops_common.h @@ -130,12 +130,28 @@ class LinearAlgebraOp : public OpKernel { const ConstMatrixMaps& inputs, MatrixMaps* outputs) = 0; - private: using TensorInputs = gtl::InlinedVector; using TensorOutputs = gtl::InlinedVector; - // This function maps slices (matrices) of the input and output tensors using - // Eigen::Map and calls ComputeMatrix implemented in terms of the + // Hook for doing batch-wide processing before ComputeMatrix is called + // on each individual slice. + virtual void BatchPreCompute(OpKernelContext* context, + const TensorInputs& inputs, + const TensorShapes& input_matrix_shapes, + const TensorOutputs& outputs, + const TensorShapes& output_matrix_shapes) {} + + // Hook for doing batch-wide processing after ComputeMatrix is called + // on each individual slice. + virtual void BatchPostCompute(OpKernelContext* context, + const TensorInputs& inputs, + const TensorShapes& input_matrix_shapes, + const TensorOutputs& outputs, + const TensorShapes& output_matrix_shapes) {} + + private: + // This function maps 2-d slices (matrices) of the input and output tensors + // using Eigen::Map and calls ComputeMatrix implemented in terms of the // Eigen::MatrixBase API by the derived class. 
// // The 'matrix_index' parameter specifies the index of the matrix to be used @@ -176,8 +192,27 @@ extern template class LinearAlgebraOp; } // namespace tensorflow -#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ - REGISTER_KERNEL_BUILDER( \ +#define INHERIT_LINALG_TYPEDEFS(Scalar) \ + typedef LinearAlgebraOp Base; \ + using Matrix = typename Base::Matrix; \ + using MatrixMap = typename Base::MatrixMap; \ + using MatrixMaps = typename Base::MatrixMaps; \ + using ConstMatrixMap = typename Base::ConstMatrixMap; \ + using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ + using TensorShapes = typename Base::TensorShapes; \ + using TensorInputs = typename Base::TensorInputs; \ + using TensorOutputs = typename Base::TensorOutputs + +#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ + REGISTER_KERNEL_BUILDER( \ Name(OpName).Device(DEVICE_CPU).TypeConstraint("T"), OpClass) +#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \ + REGISTER_KERNEL_BUILDER( \ + Name(OpName).Device(DEVICE_GPU).TypeConstraint("T"), OpClass) + +// Deprecated, use one of the device-specific macros above. +#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ + REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) + #endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_ diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index bbe1d052f03..d95200ec92a 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -48,13 +48,15 @@ class CholeskyOpTest(test.TestCase): def _verifyCholesky(self, x): # Verify that LL^T == x. 
- with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: chol = linalg_ops.cholesky(x) verification = math_ops.matmul(chol, chol, adjoint_b=True) self._verifyCholeskyBase(sess, x, chol, verification) def testBasic(self): - self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])) + for dtype in (np.float32, np.float64): + self._verifyCholesky( + np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]).astype(dtype)) def testBatch(self): simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2) @@ -84,11 +86,12 @@ class CholeskyOpTest(test.TestCase): with self.assertRaises(ValueError): linalg_ops.cholesky(tensor3) - def testNotInvertible(self): + def testNotInvertibleCPU(self): # The input should be invertible. - with self.test_session(): - with self.assertRaisesOpError("LLT decomposition was not successful. The" - " input might not be valid."): + with self.test_session(use_gpu=False): + with self.assertRaisesOpError( + "Cholesky decomposition was not successful. 
The" + " input might not be valid."): # All rows of the matrix below add to zero self._verifyCholesky( np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]])) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index e101f9fbd84..474c972a4c4 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -90,6 +90,16 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "cusolver", + srcs = ["lib/%{cusolver_lib}"], + data = ["lib/%{cusolver_lib}"], + includes = ["include/"], + linkstatic = 1, + linkopts = ["-lgomp"], + visibility = ["//visibility:public"], +) + cc_library( name = "cudnn", srcs = ["lib/%{cudnn_lib}"], diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 05ff584be02..e06092ab4e7 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -541,6 +541,9 @@ def _find_libs(repository_ctx, cuda_config): "cublas": _find_cuda_lib( "cublas", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path, cuda_config.cuda_version), + "cusolver": _find_cuda_lib( + "cusolver", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path, + cuda_config.cuda_version), "curand": _find_cuda_lib( "curand", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path, cuda_config.cuda_version), @@ -695,6 +698,7 @@ def _create_dummy_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_lib}": _lib_name("cudart", cpu_value), "%{cublas_lib}": _lib_name("cublas", cpu_value), + "%{cusolver_lib}": _lib_name("cusolver", cpu_value), "%{cudnn_lib}": _lib_name("cudnn", cpu_value), "%{cufft_lib}": _lib_name("cufft", cpu_value), "%{curand_lib}": _lib_name("curand", cpu_value), @@ -708,6 +712,7 @@ def _create_dummy_repository(repository_ctx): "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), "%{cudart_lib}": _lib_name("cudart", cpu_value), "%{cublas_lib}": _lib_name("cublas", cpu_value), + 
"%{cusolver_lib}": _lib_name("cusolver", cpu_value), "%{cudnn_lib}": _lib_name("cudnn", cpu_value), "%{cufft_lib}": _lib_name("cufft", cpu_value), "%{curand_lib}": _lib_name("curand", cpu_value), @@ -730,6 +735,7 @@ def _create_dummy_repository(repository_ctx): repository_ctx.file("cuda/lib/%s" % _lib_name("cudart", cpu_value)) repository_ctx.file("cuda/lib/%s" % _lib_name("cudart_static", cpu_value)) repository_ctx.file("cuda/lib/%s" % _lib_name("cublas", cpu_value)) + repository_ctx.file("cuda/lib/%s" % _lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/lib/%s" % _lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/lib/%s" % _lib_name("curand", cpu_value)) repository_ctx.file("cuda/lib/%s" % _lib_name("cufft", cpu_value)) @@ -822,6 +828,7 @@ def _create_cuda_repository(repository_ctx): cuda_config.cpu_value), "%{cudart_lib}": cuda_libs["cudart"].file_name, "%{cublas_lib}": cuda_libs["cublas"].file_name, + "%{cusolver_lib}": cuda_libs["cusolver"].file_name, "%{cudnn_lib}": cuda_libs["cudnn"].file_name, "%{cufft_lib}": cuda_libs["cufft"].file_name, "%{curand_lib}": cuda_libs["curand"].file_name,