In SparseCholesky Op, don't permute the output matrix. Because having the Cholesky factor in Lower Triangular form is more useful; even when it satisfies the equation: LLt = PAPt. For e.g. we can then use a simple SparseTriangularSolve to solve the system Ax = b.
PiperOrigin-RevId: 276387322 Change-Id: I42c8a04812c3a8f65145b96f5314113c43f90a2f
This commit is contained in:
parent
23da884d00
commit
f2e7824b78
@ -150,8 +150,7 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
|
|||||||
// lower triangular part of the output CSRSparseMatrix when
|
// lower triangular part of the output CSRSparseMatrix when
|
||||||
// interpreted in row major format.
|
// interpreted in row major format.
|
||||||
sparse_cholesky_factors[batch_index] =
|
sparse_cholesky_factors[batch_index] =
|
||||||
solver.matrixU().twistedBy(permutation);
|
std::move(solver.matrixU());
|
||||||
|
|
||||||
// For now, batch_ptr contains the number of nonzeros in each
|
// For now, batch_ptr contains the number of nonzeros in each
|
||||||
// batch.
|
// batch.
|
||||||
batch_ptr_vec(batch_index + 1) =
|
batch_ptr_vec(batch_index + 1) =
|
||||||
|
@ -35,6 +35,7 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import linalg_ops
|
from tensorflow.python.ops import linalg_ops
|
||||||
|
from tensorflow.python.ops import map_fn
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
@ -59,6 +60,36 @@ def _swap(a, i, j):
|
|||||||
a[i], a[j] = a[j], a[i]
|
a[i], a[j] = a[j], a[i]
|
||||||
|
|
||||||
|
|
||||||
|
def twist_matrix(matrix, permutation_indices):
|
||||||
|
"""Permute the rows and columns of a 2D or (batched) 3D Tensor."""
|
||||||
|
# Shuffle the rows and columns with the same permutation.
|
||||||
|
if matrix.shape.ndims == 2:
|
||||||
|
# Invert the permutation since `tf.gather` and `tf.gather_nd` need the
|
||||||
|
# mapping from each index `i` to the index that maps to `i`.
|
||||||
|
permutation_indices_inv = array_ops.invert_permutation(permutation_indices)
|
||||||
|
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=0)
|
||||||
|
matrix = array_ops.gather(matrix, permutation_indices_inv, axis=1)
|
||||||
|
elif matrix.shape.ndims == 3:
|
||||||
|
permutation_indices_inv = map_fn.map_fn(array_ops.invert_permutation,
|
||||||
|
permutation_indices)
|
||||||
|
# For 3D Tensors, it's easy to shuffle the rows but not the columns. We
|
||||||
|
# permute the rows, transpose, permute the rows again, and transpose back.
|
||||||
|
batch_size = matrix.shape[0]
|
||||||
|
batch_indices = array_ops.broadcast_to(
|
||||||
|
math_ops.range(batch_size)[:, None], permutation_indices.shape)
|
||||||
|
for _ in range(2):
|
||||||
|
matrix = array_ops.gather_nd(
|
||||||
|
matrix,
|
||||||
|
array_ops.stack([batch_indices, permutation_indices_inv], axis=-1))
|
||||||
|
# Transpose the matrix, or equivalently, swap dimensions 1 and 2.
|
||||||
|
matrix = array_ops.transpose(matrix, perm=[0, 2, 1])
|
||||||
|
else:
|
||||||
|
raise ValueError("Input matrix must have rank 2 or 3. Got: {}".format(
|
||||||
|
matrix.shape.ndims))
|
||||||
|
|
||||||
|
return matrix
|
||||||
|
|
||||||
|
|
||||||
class CSRSparseMatrixOpsTest(test.TestCase):
|
class CSRSparseMatrixOpsTest(test.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1088,6 +1119,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
|||||||
# Compute L * Lh where L is the Sparse Cholesky factor.
|
# Compute L * Lh where L is the Sparse Cholesky factor.
|
||||||
verification = math_ops.matmul(
|
verification = math_ops.matmul(
|
||||||
dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
|
dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True))
|
||||||
|
verification = twist_matrix(verification, ordering_amd)
|
||||||
# Assert that input matrix A satisfies A = L * Lh.
|
# Assert that input matrix A satisfies A = L * Lh.
|
||||||
verification_values = self.evaluate(verification)
|
verification_values = self.evaluate(verification)
|
||||||
full_dense_matrix = (
|
full_dense_matrix = (
|
||||||
@ -1141,6 +1173,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
|||||||
verification = math_ops.matmul(
|
verification = math_ops.matmul(
|
||||||
dense_cholesky,
|
dense_cholesky,
|
||||||
array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
|
array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True))
|
||||||
|
verification = twist_matrix(verification, ordering_amd)
|
||||||
|
|
||||||
verification_values = self.evaluate(verification)
|
verification_values = self.evaluate(verification)
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
@ -1180,6 +1213,7 @@ class CSRSparseMatrixOpsTest(test.TestCase):
|
|||||||
# Compute L * Lh.
|
# Compute L * Lh.
|
||||||
verification = math_ops.matmul(
|
verification = math_ops.matmul(
|
||||||
dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
|
dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1]))
|
||||||
|
verification = twist_matrix(verification, ordering_amd)
|
||||||
verification_values = self.evaluate(verification)
|
verification_values = self.evaluate(verification)
|
||||||
self.assertAllClose(dense_matrix, verification_values, atol=1e-5, rtol=1e-5)
|
self.assertAllClose(dense_matrix, verification_values, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user