In SparseCholesky Op, don't permute the output matrix. Because having the Cholesky factor in Lower Triangular form is more useful; even when it satisfies the equation: LLt = PAPt. For e.g. we can then use a simple SparseTriangularSolve to solve the system Ax = b.
PiperOrigin-RevId: 276387322 Change-Id: I42c8a04812c3a8f65145b96f5314113c43f90a2f
This commit is contained in:
parent
23da884d00
commit
f2e7824b78
@ -150,8 +150,7 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
|
||||
// lower triangular part of the output CSRSparseMatrix when
|
||||
// interpreted in row major format.
|
||||
sparse_cholesky_factors[batch_index] =
|
||||
solver.matrixU().twistedBy(permutation);
|
||||
|
||||
std::move(solver.matrixU());
|
||||
// For now, batch_ptr contains the number of nonzeros in each
|
||||
// batch.
|
||||
batch_ptr_vec(batch_index + 1) =
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -59,6 +60,36 @@ def _swap(a, i, j):
|
||||
a[i], a[j] = a[j], a[i]
|
||||
|
||||
|
||||
def twist_matrix(matrix, permutation_indices):
|
||||
"""Permute the rows and columns of a 2D or (batched) 3D Tensor."""
|
||||
# Shuffle the rows and columns with the same permutation.
|
||||
if matrix.shape.ndims == 2:
|
||||
# Invert the permutation since `tf.gather` and `tf.gather_nd` need the
|
||||
# mapping from each index `i` to the index that maps to `i`.
|
||||
permutation_indices_inv = array_ops.invert_permutation(permutation_indices)
|
||||
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=0)
|
||||
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=1)
|
||||
elif matrix.shape.ndims == 3:
|
||||
permutation_indices_inv = map_fn.map_fn(array_ops.invert_permutation,
|
||||
permutation_indices)
|
||||
# For 3D Tensors, it's easy to shuffle the rows but not the columns. We
|
||||
# permute the rows, transpose, permute the rows again, and transpose back.
|
||||
batch_size = matrix.shape[0]
|
||||
batch_indices = array_ops.broadcast_to(
|
||||
math_ops.range(batch_size)[:, None], permutation_indices.shape)
|
||||
for _ in range(2):
|
||||
matrix = array_ops.gather_nd(
|
||||
matrix,
|
||||
array_ops.stack([batch_indices, permutation_indices_inv], axis=-1))
|
||||
# Transpose the matrix, or equivalently, swap dimensions 1 and 2.
|
||||
matrix = array_ops.transpose(matrix, perm=[0, 2, 1])
|
||||
else:
|
||||
raise ValueError("Input matrix must have rank 2 or 3. Got: {}".format(
|
||||
matrix.shape.ndims))
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
|
||||
@classmethod
|
||||
@ -1088,6 +1119,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
# Compute L * Lh where L is the Sparse Cholesky factor.
|
||||
verification = math_ops.matmul(
|
||||
dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
|
||||
verification = twist_matrix(verification, ordering_amd)
|
||||
# Assert that input matrix A satisfies A = L * Lh.
|
||||
verification_values = self.evaluate(verification)
|
||||
full_dense_matrix = (
|
||||
@ -1141,6 +1173,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
verification = math_ops.matmul(
|
||||
dense_cholesky,
|
||||
array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
|
||||
verification = twist_matrix(verification, ordering_amd)
|
||||
|
||||
verification_values = self.evaluate(verification)
|
||||
self.assertAllClose(
|
||||
@ -1180,6 +1213,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
||||
# Compute L * Lh.
|
||||
verification = math_ops.matmul(
|
||||
dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
|
||||
verification = twist_matrix(verification, ordering_amd)
|
||||
verification_values = self.evaluate(verification)
|
||||
self.assertAllClose(dense_matrix, verification_values, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user