Use inplace Cholesky factorization and solves to speed up and reduce memory usage in matrix_solve_ls.

Check succes before copying outputs in cholesky_op.

PiperOrigin-RevId: 157887564
This commit is contained in:
A. Unique TensorFlower 2017-06-02 16:06:51 -07:00 committed by TensorFlower Gardener
parent a4caeb2ea4
commit 0c92dada6a
2 changed files with 10 additions and 9 deletions

View File

@ -64,11 +64,11 @@ class CholeskyOp : public LinearAlgebraOp<Scalar> {
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
llt_decomposition(input);
// Output the lower triangular in a dense form.
outputs->at(0) = llt_decomposition.matrixL();
OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success,
errors::InvalidArgument(kErrMsg));
// Output the lower triangular in a dense form.
outputs->at(0) = llt_decomposition.matrixL();
}
};

View File

@ -105,18 +105,19 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
// using Cholesky decomposition.
Matrix gramian(cols, cols);
gramian.template triangularView<Eigen::Lower>() =
matrix.transpose() * matrix;
matrix.adjoint() * matrix;
if (l2_regularizer > 0) {
gramian +=
(Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal();
}
const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
OP_REQUIRES(
context, llt.info() == Eigen::Success,
errors::InvalidArgument("Input matrix was rank deficient or "
"ill-conditioned. Try setting fast=False "
"or provide a larger l2_regularizer > 0."));
outputs->at(0) = llt.solve(matrix.transpose() * rhs);
outputs->at(0).noalias() = matrix.adjoint() * rhs;
llt.solveInPlace(outputs->at(0));
} else {
// Underdetermined case (rows < cols): Solves the minimum-norm problem
// min ||X||_F^2 s.t. A*X = RHS
@ -125,18 +126,18 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
// using Cholesky decomposition.
Matrix gramian(rows, rows);
gramian.template triangularView<Eigen::Lower>() =
matrix * matrix.transpose();
matrix * matrix.adjoint();
if (l2_regularizer > 0) {
gramian +=
(Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal();
}
const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
OP_REQUIRES(
context, llt.info() == Eigen::Success,
errors::InvalidArgument("Input matrix was rank deficient or "
"ill-conditioned. Try setting fast=False "
"or provide an l2_regularizer > 0."));
outputs->at(0) = matrix.transpose() * llt.solve(rhs);
outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs);
}
} else {
// Use complete orthogonal decomposition which is backwards stable and