Add support for solvers based on the Nvidia cuSolver library.

Implement a GPU version of tf.cholesky as a first example.
Change: 152756373
This commit is contained in:
A. Unique TensorFlower 2017-04-10 16:17:44 -08:00 committed by TensorFlower Gardener
parent 2f5fde8dd9
commit 72c023d396
11 changed files with 546 additions and 22 deletions

View File

@ -47,7 +47,7 @@ class OperatorPDFullTest(test.TestCase):
operator = operator_pd_full.OperatorPDFull(matrix, verify_pd=True)
# Could fail inside Cholesky decomposition, or later when we test the
# diag.
with self.assertRaisesOpError("x > 0|LLT"):
with self.assertRaisesOpError("x > 0|Cholesky"):
operator.to_dense().eval()
def testNonSymmetricMatrixRaises(self):

View File

@ -339,7 +339,7 @@ class WishartCholeskyTest(test.TestCase):
chol_scale_deferred: chol_scale})
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"LLT decomposition was not successful"):
"Cholesky decomposition was not successful"):
chol_w = distributions.WishartFull(
df=df_deferred, scale=chol_scale_deferred)
# np.ones((3, 3)) is not positive definite.

View File

@ -41,6 +41,7 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
config_setting(
# Add "--define tensorflow_xsmm=1" to your build command to use libxsmm for
@ -1972,6 +1973,23 @@ cc_library(
],
)
# Kernel helper library built from cuda_solvers.{h,cc}. It links against the
# cuSolver target from the local CUDA configuration and is private to this
# package; kernels such as cholesky_op pull it in under if_cuda().
tf_kernel_library(
name = "cuda_solvers",
srcs = ["cuda_solvers.cc"],
hdrs = ["cuda_solvers.h"],
# @local_config_cuda//cuda:cusolver, //third_party/eigen3:blas,
# and //third_party/libf2c all contain various parts of BLAS, LAPACK,
# and f2c helper functions in global namespace. Tell the compiler to
# allow multiple definitions when linking this.
linkopts = ["-Wl,-zmuldefs"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@local_config_cuda//cuda:cusolver",
],
)
LINALG_DEPS = [
":linalg_ops_common",
"//third_party/eigen3",
@ -1983,7 +2001,10 @@ LINALG_DEPS = [
tf_kernel_library(
name = "cholesky_op",
prefix = "cholesky_op",
deps = LINALG_DEPS,
deps = if_cuda([
":cuda_solvers",
":matrix_band_part_op",
]) + LINALG_DEPS,
)
tf_kernel_library(

View File

@ -16,31 +16,40 @@ limitations under the License.
// See docs in ../ops/linalg_ops.cc.
// TODO(konstantinos): Enable complex inputs. This will require additional tests
// and OP_REQUIRES.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "third_party/eigen3/Eigen/Cholesky"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif
namespace tensorflow {
static const char kErrMsg[] =
"Cholesky decomposition was not successful. The input might not be valid.";
template <class Scalar>
class CholeskyOp : public LinearAlgebraOp<Scalar> {
public:
typedef LinearAlgebraOp<Scalar> Base;
INHERIT_LINALG_TYPEDEFS(Scalar);
explicit CholeskyOp(OpKernelConstruction* context) : Base(context) {}
using Matrix = typename Base::Matrix;
using MatrixMaps = typename Base::MatrixMaps;
using ConstMatrixMap = typename Base::ConstMatrixMap;
using ConstMatrixMaps = typename Base::ConstMatrixMaps;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
MatrixMaps* outputs) final {
const ConstMatrixMap& input = inputs[0];
@ -60,11 +69,74 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
outputs->at(0) = llt_decomposition.matrixL();
OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success,
errors::InvalidArgument("LLT decomposition was not successful. "
"The input might not be valid."));
errors::InvalidArgument(kErrMsg));
}
};
#if GOOGLE_CUDA
// Shorthand for the Eigen GPU device type used by the GPU kernel below.
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// For a scalar type T, declares (without defining) the explicit specialization
// of MatrixBandPart<GPUDevice, T>::Compute and marks the struct as an
// `extern template`, so this translation unit does not instantiate it.
// NOTE(review): the definitions are presumably provided by the CUDA-compiled
// part of matrix_band_part_op -- confirm against that target.
#define DECLARE_GPU_SPEC(T) \
template <> \
void MatrixBandPart<GPUDevice, T>::Compute( \
const GPUDevice& d, Eigen::DenseIndex num_lower, \
Eigen::DenseIndex num_upper, typename TTypes<T, 3>::ConstTensor input, \
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
// Only float and double are declared, matching the GPU kernel registrations.
TF_CALL_float(DECLARE_GPU_SPEC);
TF_CALL_double(DECLARE_GPU_SPEC);
} // namespace functor
// GPU implementation of the Cholesky op. The batch is first prepared in a
// single MatrixBandPart launch (BatchPreCompute); each matrix is then
// factorized in place in the output buffer through CudaSolverDN::potrf
// (ComputeMatrix).
template <class Scalar>
class CholeskyOpGpu : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit CholeskyOpGpu(OpKernelConstruction* context) : Base(context) {}

  // Copies the lower triangular part of every input matrix into the
  // corresponding output matrix and zeroes the strictly upper triangular
  // part. The existing MatrixBandPart kernel handles the whole batch at once,
  // before the per-matrix Cholesky factorization kernels are launched.
  void BatchPreCompute(OpKernelContext* context, const TensorInputs& inputs,
                       const TensorShapes& input_matrix_shapes,
                       const TensorOutputs& outputs,
                       const TensorShapes& output_matrix_shapes) final {
    const int num_rows = input_matrix_shapes[0].dim_size(0);
    // View input and output as rank-3 tensors: (batch, rows, cols).
    auto batched_input = inputs[0]->template flat_inner_dims<Scalar, 3>();
    auto batched_output = outputs[0]->template flat_inner_dims<Scalar, 3>();
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    // num_lower = num_rows keeps all sub-diagonals; num_upper = 0 drops every
    // super-diagonal.
    functor::MatrixBandPart<GPUDevice, Scalar>::Compute(
        device, num_rows, 0, batched_input, batched_output);
  }

  // Factorizes one matrix of the batch in place in outputs->at(0), which
  // BatchPreCompute has already filled. Any solver failure is logged and then
  // reported to the context as InvalidArgument(kErrMsg).
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const int matrix_size = inputs[0].rows();
    // If X is an empty matrix (0 rows, 0 col), X * X' == X, so the (empty)
    // output is already the answer.
    if (matrix_size == 0) return;

    // Launch the Cholesky kernel on the output buffer.
    CudaSolverDN solver(context);
    Scalar* const factor = outputs->at(0).data();
    const Status potrf_status =
        solver.potrf(CUBLAS_FILL_MODE_UPPER, matrix_size, factor, matrix_size,
                     /*info=*/nullptr);
    if (!potrf_status.ok()) {
      LOG(ERROR) << potrf_status.ToString();
    }
    OP_REQUIRES(context, potrf_status.ok(), errors::InvalidArgument(kErrMsg));
  }
};
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
#endif // GOOGLE_CUDA
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);

View File

@ -0,0 +1,190 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <memory>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
namespace {
// Owns a temporary rank-1 tensor of `size` elements of type Scalar obtained
// from the kernel context, and exposes its buffer as a raw pointer. The buffer
// is tied to the lifetime of this object through the held Tensor.
template <typename Scalar>
class ScratchSpace {
public:
// Allocates the temporary tensor with context->allocate_temp(). The result is
// wrapped in TF_CHECK_OK, so an allocation failure is a check failure rather
// than a Status returned to the caller.
explicit ScratchSpace(OpKernelContext* context, int size) {
TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value,
TensorShape({size}), &scratch_tensor_));
}
// Returns a pointer to the first element of the scratch buffer.
Scalar* data() { return scratch_tensor_.template flat<Scalar>().data(); }
private:
// Backing storage for the scratch buffer.
Tensor scratch_tensor_;
};
// Maps a scalar type to the type CUDA libraries use for it. The primary
// template is the identity; the specializations below map std::complex<float>
// to cuComplex and std::complex<double> to cuDoubleComplex.
template <typename T>
struct CUDAComplexT {
  using type = T;
};

template <>
struct CUDAComplexT<std::complex<float>> {
  using type = cuComplex;
};

template <>
struct CUDAComplexT<std::complex<double>> {
  using type = cuDoubleComplex;
};
// Reinterprets a pointer to T as a pointer to the matching CUDA type given by
// CUDAComplexT<T>: std::complex<> pointers become cuComplex/cuDoubleComplex
// pointers, and pointers to any other type are returned with the type
// unchanged. A const and a non-const overload are provided.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  using CudaType = typename CUDAComplexT<T>::type;
  return reinterpret_cast<const CudaType*>(p);
}

template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  using CudaType = typename CUDAComplexT<T>::type;
  return reinterpret_cast<CudaType*>(p);
}
// Builds a cuComplex / cuDoubleComplex value from the real and imaginary
// parts of a std::complex<float> / std::complex<double> value.
inline cuComplex CUDAComplexValue(std::complex<float> val) {
  const float real_part = val.real();
  const float imag_part = val.imag();
  return {real_part, imag_part};
}

inline cuDoubleComplex CUDAComplexValue(std::complex<double> val) {
  const double real_part = val.real();
  const double imag_part = val.imag();
  return {real_part, imag_part};
}
} // namespace
// Evaluates `expr` once and, if the result is not CUSOLVER_STATUS_SUCCESS,
// returns errors::Internal(msg) from the enclosing function. Falls through on
// success.
#define TF_RETURN_IF_CUSOLVER_ERROR_MSG(expr, msg) \
do { \
auto status = (expr); \
if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
return errors::Internal(msg); \
} \
} while (0)
// Same as above with a generic error message.
#define TF_RETURN_IF_CUSOLVER_ERROR(expr) \
TF_RETURN_IF_CUSOLVER_ERROR_MSG(expr, "cuSolverDN call failed.")
// Copies the solver's info value from `device_info_ptr` to the host with
// GetInfo() (so it must be expanded inside a CudaSolverDN member function),
// stores it in `*info_ptr` when `info_ptr` is not null, and then ALWAYS returns
// from the enclosing function: errors::Internal naming `method` if info is
// non-zero, Status::OK() otherwise. A failure of GetInfo() itself is
// propagated by TF_RETURN_IF_ERROR. Note that `info_ptr` is expanded without
// parentheses, so pass a plain pointer expression.
#define TF_RETURN_STATUS_FROM_INFO(method, device_info_ptr, info_ptr) \
do { \
int local_info; \
TF_RETURN_IF_ERROR(GetInfo(device_info_ptr, &local_info)); \
if (info_ptr != nullptr) *info_ptr = local_info; \
if (TF_PREDICT_FALSE(local_info != 0)) { \
return errors::Internal("cuSolverDN::" #method " returned info = ", \
local_info, ", expected info = 0"); \
} else { \
return Status::OK(); \
} \
} while (0)
// Creates a cuSolverDN handle bound to the CUDA stream of the op's device
// context. `context` is stored but not owned. Every failure in here (null
// stream pointer, handle creation, stream binding) is a CHECK failure rather
// than a recoverable error.
CudaSolverDN::CudaSolverDN(OpKernelContext* context) : context_(context) {
// Dig the raw cudaStream_t out of the StreamExecutor stream wrapped by the
// op's device context.
const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(context->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
cuda_stream_ = *cu_stream_ptr;
CHECK(cusolverDnCreate(&handle_) == CUSOLVER_STATUS_SUCCESS)
<< "Failed to create cuSolverDN instance.";
// Make all solver calls issued through handle_ run on that stream.
CHECK(cusolverDnSetStream(handle_, cuda_stream_) == CUSOLVER_STATUS_SUCCESS)
<< "Failed to set cuSolverDN stream.";
}
// Destroys the cuSolverDN handle created in the constructor. A failure to
// destroy it is a CHECK failure.
CudaSolverDN::~CudaSolverDN() {
CHECK(cusolverDnDestroy(handle_) == CUSOLVER_STATUS_SUCCESS)
<< "Failed to destroy cuSolverDN instance.";
}
// Copies the solver status word at `dev_info` (device memory) into
// `*host_info` and blocks until the copy has completed on the op's stream.
//
// Args:
//   dev_info:  device pointer to the int written by the solver; must not be
//              null.
//   host_info: host pointer that receives the value; must not be null. It is
//              written only when the copy completed in time.
//
// Returns:
//   Status::OK() on success, or errors::Internal if the copy could not be
//   enqueued or did not finish within one minute.
Status CudaSolverDN::GetInfo(const int* dev_info, int* host_info) const {
  CHECK(dev_info != nullptr);
  CHECK(host_info != nullptr);
  auto stream = context_->op_device_context()->stream();
  perftools::gputools::DeviceMemoryBase wrapped(const_cast<int*>(dev_info));

  // The copy destination and the completion barrier live in shared heap state
  // that the event-manager callback co-owns. If the wait below times out we
  // return while the memcpy and the callback are still pending; with
  // stack-allocated state they would later write to / signal objects that no
  // longer exist.
  struct PendingCopy {
    PendingCopy() : value(0), done(1) {}
    int value;
    BlockingCounter done;
  };
  auto pending = std::make_shared<PendingCopy>();

  if (!stream
           ->ThenMemcpy(&pending->value /* destination */,
                        wrapped /* source */, sizeof(int))
           .ok()) {
    return errors::Internal("Failed to copy dev_info to host.");
  }

  // The callback runs once the stream has progressed past the memcpy; holding
  // `pending` by value keeps the destination and barrier alive until then.
  context_->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
      stream, [pending]() { pending->done.DecrementCount(); });
  if (!pending->done.WaitFor(std::chrono::minutes(1))) {
    return errors::Internal("Failed to copy dev_info to host within 1 minute.");
  }

  *host_info = pending->value;
  return Status::OK();
}
// Invokes the two-argument macro `m(Scalar, lapack_prefix)` once for each of
// the 4 standard numeric types, pairing each C++ type with its LAPACK-style
// prefix letter: float/S, double/D, std::complex<float>/C,
// std::complex<double>/Z.
#define TF_CALL_LAPACK_TYPES(m) \
m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
// Macros to construct cusolver method names by token pasting, e.g.
// SOLVER_NAME(potrf, S) -> cusolverDnSpotrf and
// BUFSIZE_NAME(potrf, S) -> cusolverDnSpotrf_bufferSize.
#define SOLVER_NAME(method, lapack_prefix) cusolverDn##lapack_prefix##method
#define BUFSIZE_NAME(method, lapack_prefix) \
cusolverDn##lapack_prefix##method##_bufferSize
//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//=============================================================================
// Defines the explicit specialization CudaSolverDN::potrf<Scalar>, which
// forwards to cusolverDn<lapack_prefix>potrf:
//   1. queries the required workspace size with the matching _bufferSize call,
//   2. allocates the workspace and a one-int info buffer as scratch tensors,
//   3. launches the solver on A (factorized in place),
//   4. fetches info from the device and converts it to a Status.
// Any cuSolver call that does not return CUSOLVER_STATUS_SUCCESS yields
// errors::Internal. TF_RETURN_STATUS_FROM_INFO returns on both of its
// branches, so the trailing `return Status::OK();` is never reached.
#define POTRF_INSTANCE(Scalar, lapack_prefix) \
template <> \
Status CudaSolverDN::potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A, \
int lda, int* info) const { \
/* Get amount of workspace memory required. */ \
int lwork; \
TF_RETURN_IF_CUSOLVER_ERROR(BUFSIZE_NAME(potrf, lapack_prefix)( \
handle_, uplo, n, CUDAComplex(A), lda, &lwork)); \
\
/* Allocate device memory for workspace and info. */ \
ScratchSpace<Scalar> device_workspace(context_, lwork); \
ScratchSpace<int> device_info(context_, 1); \
\
/* Launch the solver kernel. */ \
TF_RETURN_IF_CUSOLVER_ERROR(SOLVER_NAME(potrf, lapack_prefix)( \
handle_, uplo, n, CUDAComplex(A), lda, \
CUDAComplex(device_workspace.data()), lwork, device_info.data())); \
\
/* Get info from device and return status. */ \
TF_RETURN_STATUS_FROM_INFO(potrf, device_info.data(), info); \
return Status::OK(); \
}
// Instantiate potrf for float, double, complex<float> and complex<double>.
TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,178 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
// This header declares CudaSolverDN, which contains templatized wrappers of
// linear algebra solvers in the cuSolverDN library for use in TensorFlow
// kernels. A CudaBlas counterpart wrapping cuBlas is sketched at the end of
// this file but is not implemented yet.
#ifndef TENSORFLOW_KERNELS_CUDA_SOLVERS_H_
#define TENSORFLOW_KERNELS_CUDA_SOLVERS_H_

#ifdef GOOGLE_CUDA

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// A class that provides a simplified templated API for the solver methods
// in cuSolverDN (http://docs.nvidia.com/cuda/cusolver).
// An object of this class wraps a cuSolverDN instance, and will launch
// kernels on the cuda stream wrapped by the GPU device in the OpKernelContext
// provided to the constructor. The class methods transparently fetch the output
// status of the solvers (a.k.a. the LAPACK "info" output variable) without
// having to manually synchronize the underlying Cuda stream.
class CudaSolverDN {
 public:
  // Creates a cuSolverDN handle bound to the stream of `context`'s device.
  // `context` is not owned and must outlive this object.
  explicit CudaSolverDN(OpKernelContext* context);
  virtual ~CudaSolverDN();

  // The object owns a cuSolverDN handle that is destroyed in the destructor,
  // so a copy would destroy the same handle twice. Copying is disabled.
  CudaSolverDN(const CudaSolverDN&) = delete;
  CudaSolverDN& operator=(const CudaSolverDN&) = delete;

  // ====================================================================
  // Templated wrappers for cuSolver functions start here.

  // Cholesky factorization.
  // Computes Cholesky factorization A = L * L^T.
  // Returns Status::OK(), if the Cholesky factorization was successful.
  // If info is not nullptr it is used to return the potrf info code:
  // Returns zero if success, returns -i if the
  // i-th parameter is wrong, returns i > 0, if the leading minor of order i is
  // not positive definite, see:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status potrf(cublasFillMode_t uplo, int n, Scalar* A, int lda,
               int* info) const;

  /*
  TODO(rmlarsen, volunteers): Implement the kernels below.

  // Uses Cholesky factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrs
  template <typename Scalar>
  Status potrs(cublasFillMode_t uplo, int n, int nrhs, const Scalar* A, int lda,
               Scalar* B, int ldb, int* info) const;

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status getrf(int m, int n, Scalar* A, int lda, int* devIpiv,
               int* devInfo) const;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status getrs(int n, int nrhs, const Scalar* A, int lda, const int* devIpiv,
               Scalar* B, int ldb, int* devInfo) const;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status geqrf(int m, int n, Scalar* A, int lda, Scalar* TAU, int* devInfo)
      const;

  // Multiplies by Q.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status mqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
             int k, const Scalar* A, int lda, const Scalar* tau, Scalar* C,
             int ldc, int* devInfo) const;

  // Materializes Q.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status gqr(int m, int n, int k, Scalar* A, int lda, const Scalar* tau,
             int* devInfo) const;

  // Symmetric/Hermitian Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status evd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, Scalar* A,
             int lda, Scalar* W, int* devInfo) const;

  // Singular value decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* A,
               int lda, Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt,
               int* devInfo);
  */

 private:
  // Copies dev_info status back from the device to host and uses event manager
  // to wait (with a timeout) until the copy has finished. Returns an error if
  // the copy fails to complete successfully within the timeout period.
  Status GetInfo(const int* dev_info, int* host_info) const;

  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;  // stream taken from context_'s device context.
  cusolverDnHandle_t handle_;  // created in the constructor, destroyed in the
                               // destructor.
};

/*
TODO(rmlarsen, volunteers): Implement the kernels below. These are utils and
batched solvers not currently wrapped by stream executor.

class CudaBlas {
 public:
  // Initializes a cuBlas handle that will launch kernels using the
  // cuda stream wrapped by the GPU device in context.
  explicit CudaBlas(OpKernelContext* context);
  virtual ~CudaBlas();

  // Templatized wrappers for cuBlas functions.

  // Matrix addition, copy and transposition.
  // See: http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  template <typename Scalar>
  Status geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, const Scalar* A, int lda, const Scalar* beta,
              const Scalar* B, int ldb, Scalar* C, int ldc) const;

  // Batched LU factorization.
  // See:
  http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status getrfBatched(int n, Scalar* Aarray[], int lda, int* PivotArray,
                      int* infoArray, int batchSize) const;

  // Batched linear solver using LU factorization from getrfBatched.
  // See:
  http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status getrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* Aarray[], int lda, const int* devIpiv,
                      Scalar* Barray[], int ldb, int* info, int batchSize) const;

  // Batched matrix inverse.
  // See:
  http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status getriBatched(cublasHandle_t handle, int n, Scalar* Aarray[], int lda,
                      int* PivotArray, Scalar* Carray[], int ldc, int* infoArray,
                      int batchSize);

 private:
  // Copies dev_info status back from the device to host and uses event manager
  // to wait (with a timeout) until the copy has finished. Returns an error if
  // the copy fails to complete successfully within the timeout period.
  Status GetInfo(const int* dev_info, int* host_info) const;

  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cublasHandle_t handle_;
};
*/

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_KERNELS_CUDA_SOLVERS_H_

View File

@ -95,6 +95,10 @@ void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
&output_matrix_shapes);
// Perform batch-wide pre-computations, if any.
BatchPreCompute(context, inputs, input_matrix_shapes, outputs,
output_matrix_shapes);
// Process the individual matrix problems in parallel using a threadpool.
auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
&output_matrix_shapes, context](int64 begin, int64 end) {
@ -106,6 +110,10 @@ void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
// Perform batch-wide post-computations, if any.
BatchPostCompute(context, inputs, input_matrix_shapes, outputs,
output_matrix_shapes);
}
template <typename Scalar>

View File

@ -130,12 +130,28 @@ class LinearAlgebraOp : public OpKernel {
const ConstMatrixMaps& inputs,
MatrixMaps* outputs) = 0;
private:
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
// This function maps slices (matrices) of the input and output tensors using
// Eigen::Map and calls ComputeMatrix implemented in terms of the
// Hook for doing batch-wide processing before ComputeMatrix is called
// on each individual slice.
virtual void BatchPreCompute(OpKernelContext* context,
const TensorInputs& inputs,
const TensorShapes& input_matrix_shapes,
const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes) {}
// Hook for doing batch-wide processing after ComputeMatrix is called
// on each individual slice.
virtual void BatchPostCompute(OpKernelContext* context,
const TensorInputs& inputs,
const TensorShapes& input_matrix_shapes,
const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes) {}
private:
// This function maps 2-d slices (matrices) of the input and output tensors
// using Eigen::Map and calls ComputeMatrix implemented in terms of the
// Eigen::MatrixBase API by the derived class.
//
// The 'matrix_index' parameter specifies the index of the matrix to be used
@ -176,8 +192,27 @@ extern template class LinearAlgebraOp<complex128>;
} // namespace tensorflow
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
#define INHERIT_LINALG_TYPEDEFS(Scalar) \
typedef LinearAlgebraOp<Scalar> Base; \
using Matrix = typename Base::Matrix; \
using MatrixMap = typename Base::MatrixMap; \
using MatrixMaps = typename Base::MatrixMaps; \
using ConstMatrixMap = typename Base::ConstMatrixMap; \
using ConstMatrixMaps = typename Base::ConstMatrixMaps; \
using TensorShapes = typename Base::TensorShapes; \
using TensorInputs = typename Base::TensorInputs; \
using TensorOutputs = typename Base::TensorOutputs
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
// Deprecated, use one of the device-specific macros above.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_

View File

@ -48,13 +48,15 @@ class CholeskyOpTest(test.TestCase):
def _verifyCholesky(self, x):
# Verify that LL^T == x.
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
chol = linalg_ops.cholesky(x)
verification = math_ops.matmul(chol, chol, adjoint_b=True)
self._verifyCholeskyBase(sess, x, chol, verification)
def testBasic(self):
self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
for dtype in (np.float32, np.float64):
self._verifyCholesky(
np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]).astype(dtype))
def testBatch(self):
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
@ -84,11 +86,12 @@ class CholeskyOpTest(test.TestCase):
with self.assertRaises(ValueError):
linalg_ops.cholesky(tensor3)
def testNotInvertible(self):
def testNotInvertibleCPU(self):
# The input should be invertible.
with self.test_session():
with self.assertRaisesOpError("LLT decomposition was not successful. The"
" input might not be valid."):
with self.test_session(use_gpu=False):
with self.assertRaisesOpError(
"Cholesky decomposition was not successful. The"
" input might not be valid."):
# All rows of the matrix below add to zero
self._verifyCholesky(
np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]]))

View File

@ -90,6 +90,16 @@ cc_library(
visibility = ["//visibility:public"],
)
# Publicly visible target for the cuSolver library. The library file name is
# filled in through the %{cusolver_lib} template substitution; headers come
# from the include/ directory of this repository.
cc_library(
name = "cusolver",
srcs = ["lib/%{cusolver_lib}"],
data = ["lib/%{cusolver_lib}"],
includes = ["include/"],
linkstatic = 1,
# NOTE(review): presumably cuSolver needs the GNU OpenMP runtime at link
# time -- confirm against the cuSolver linking requirements.
linkopts = ["-lgomp"],
visibility = ["//visibility:public"],
)
cc_library(
name = "cudnn",
srcs = ["lib/%{cudnn_lib}"],

View File

@ -541,6 +541,9 @@ def _find_libs(repository_ctx, cuda_config):
"cublas": _find_cuda_lib(
"cublas", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version),
"cusolver": _find_cuda_lib(
"cusolver", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version),
"curand": _find_cuda_lib(
"curand", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version),
@ -695,6 +698,7 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": _lib_name("cudart", cpu_value),
"%{cublas_lib}": _lib_name("cublas", cpu_value),
"%{cusolver_lib}": _lib_name("cusolver", cpu_value),
"%{cudnn_lib}": _lib_name("cudnn", cpu_value),
"%{cufft_lib}": _lib_name("cufft", cpu_value),
"%{curand_lib}": _lib_name("curand", cpu_value),
@ -708,6 +712,7 @@ def _create_dummy_repository(repository_ctx):
"%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
"%{cudart_lib}": _lib_name("cudart", cpu_value),
"%{cublas_lib}": _lib_name("cublas", cpu_value),
"%{cusolver_lib}": _lib_name("cusolver", cpu_value),
"%{cudnn_lib}": _lib_name("cudnn", cpu_value),
"%{cufft_lib}": _lib_name("cufft", cpu_value),
"%{curand_lib}": _lib_name("curand", cpu_value),
@ -730,6 +735,7 @@ def _create_dummy_repository(repository_ctx):
repository_ctx.file("cuda/lib/%s" % _lib_name("cudart", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cublas", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cudnn", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("curand", cpu_value))
repository_ctx.file("cuda/lib/%s" % _lib_name("cufft", cpu_value))
@ -822,6 +828,7 @@ def _create_cuda_repository(repository_ctx):
cuda_config.cpu_value),
"%{cudart_lib}": cuda_libs["cudart"].file_name,
"%{cublas_lib}": cuda_libs["cublas"].file_name,
"%{cusolver_lib}": cuda_libs["cusolver"].file_name,
"%{cudnn_lib}": cuda_libs["cudnn"].file_name,
"%{cufft_lib}": cuda_libs["cufft"].file_name,
"%{curand_lib}": cuda_libs["curand"].file_name,