In SparseCholesky Op, don't permute the output matrix. Because having the Cholesky factor in Lower Triangular form is more useful; even when it satisfies the equation: LLt = PAPt. For e.g. we can then use a simple SparseTriangularSolve to solve the system Ax = b.

PiperOrigin-RevId: 276387322
Change-Id: I42c8a04812c3a8f65145b96f5314113c43f90a2f
This commit is contained in:
Anudhyan Boral 2019-10-23 17:39:38 -07:00 committed by TensorFlower Gardener
parent 23da884d00
commit f2e7824b78
2 changed files with 35 additions and 2 deletions

View File

@ -150,8 +150,7 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
// lower triangular part of the output CSRSparseMatrix when
// interpreted in row major format.
sparse_cholesky_factors[batch_index] =
solver.matrixU().twistedBy(permutation);
std::move(solver.matrixU());
// For now, batch_ptr contains the number of nonzeros in each
// batch.
batch_ptr_vec(batch_index + 1) =

View File

@ -35,6 +35,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
@ -59,6 +60,36 @@ def _swap(a, i, j):
a[i], a[j] = a[j], a[i]
def twist_matrix(matrix, permutation_indices):
"""Permute the rows and columns of a 2D or (batched) 3D Tensor."""
# Shuffle the rows and columns with the same permutation.
if matrix.shape.ndims == 2:
# Invert the permutation since `tf.gather` and `tf.gather_nd` need the
# mapping from each index `i` to the index that maps to `i`.
permutation_indices_inv = array_ops.invert_permutation(permutation_indices)
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=0)
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=1)
elif matrix.shape.ndims == 3:
permutation_indices_inv = map_fn.map_fn(array_ops.invert_permutation,
permutation_indices)
# For 3D Tensors, it's easy to shuffle the rows but not the columns. We
# permute the rows, transpose, permute the rows again, and transpose back.
batch_size = matrix.shape[0]
batch_indices = array_ops.broadcast_to(
math_ops.range(batch_size)[:, None], permutation_indices.shape)
for _ in range(2):
matrix = array_ops.gather_nd(
matrix,
array_ops.stack([batch_indices, permutation_indices_inv], axis=-1))
# Transpose the matrix, or equivalently, swap dimensions 1 and 2.
matrix = array_ops.transpose(matrix, perm=[0, 2, 1])
else:
raise ValueError("Input matrix must have rank 2 or 3. Got: {}".format(
matrix.shape.ndims))
return matrix
class CSRSparseMatrixOpsTest(test.TestCase):
@classmethod
@ -1088,6 +1119,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
# Compute L * Lh where L is the Sparse Cholesky factor.
verification = math_ops.matmul(
dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
verification = twist_matrix(verification, ordering_amd)
# Assert that input matrix A satisfies A = L * Lh.
verification_values = self.evaluate(verification)
full_dense_matrix = (
@ -1141,6 +1173,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
verification = math_ops.matmul(
dense_cholesky,
array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
verification = twist_matrix(verification, ordering_amd)
verification_values = self.evaluate(verification)
self.assertAllClose(
@ -1180,6 +1213,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
# Compute L * Lh.
verification = math_ops.matmul(
dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
verification = twist_matrix(verification, ordering_amd)
verification_values = self.evaluate(verification)
self.assertAllClose(dense_matrix, verification_values, atol=1e-5, rtol=1e-5)