From f2e7824b787ce5e117b8f33761708af641366d19 Mon Sep 17 00:00:00 2001 From: Anudhyan Boral Date: Wed, 23 Oct 2019 17:39:38 -0700 Subject: [PATCH] In SparseCholesky Op, don't permute the output matrix. Having the Cholesky factor in lower triangular form is more useful, even though it then satisfies the equation LLt = PAPt. For example, we can then use a simple SparseTriangularSolve to solve the system Ax = b. PiperOrigin-RevId: 276387322 Change-Id: I42c8a04812c3a8f65145b96f5314113c43f90a2f --- .../core/kernels/sparse/sparse_cholesky_op.cc | 3 +- .../sparse_csr_matrix_ops_test.py | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc index bd62fa2a296..3786033c98c 100644 --- a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc @@ -150,8 +150,7 @@ class CSRSparseCholeskyCPUOp : public OpKernel { // lower triangular part of the output CSRSparseMatrix when // interpreted in row major format. sparse_cholesky_factors[batch_index] = - solver.matrixU().twistedBy(permutation); - + std::move(solver.matrixU()); // For now, batch_ptr contains the number of nonzeros in each // batch. 
batch_ptr_vec(batch_index + 1) = diff --git a/tensorflow/python/kernel_tests/sparse_csr_matrix_ops_test.py b/tensorflow/python/kernel_tests/sparse_csr_matrix_ops_test.py index 6caa8e23b44..c05e50664b2 100644 --- a/tensorflow/python/kernel_tests/sparse_csr_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_csr_matrix_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops @@ -59,6 +60,36 @@ def _swap(a, i, j): a[i], a[j] = a[j], a[i] +def twist_matrix(matrix, permutation_indices): + """Permute the rows and columns of a 2D or (batched) 3D Tensor.""" + # Shuffle the rows and columns with the same permutation. + if matrix.shape.ndims == 2: + # Invert the permutation since `tf.gather` and `tf.gather_nd` need the + # mapping from each index `i` to the index that maps to `i`. + permutation_indices_inv = array_ops.invert_permutation(permutation_indices) + matrix = array_ops.gather(matrix, permutation_indices_inv, axis=0) + matrix = array_ops.gather(matrix, permutation_indices_inv, axis=1) + elif matrix.shape.ndims == 3: + permutation_indices_inv = map_fn.map_fn(array_ops.invert_permutation, + permutation_indices) + # For 3D Tensors, it's easy to shuffle the rows but not the columns. We + # permute the rows, transpose, permute the rows again, and transpose back. + batch_size = matrix.shape[0] + batch_indices = array_ops.broadcast_to( + math_ops.range(batch_size)[:, None], permutation_indices.shape) + for _ in range(2): + matrix = array_ops.gather_nd( + matrix, + array_ops.stack([batch_indices, permutation_indices_inv], axis=-1)) + # Transpose the matrix, or equivalently, swap dimensions 1 and 2. 
+ matrix = array_ops.transpose(matrix, perm=[0, 2, 1]) + else: + raise ValueError("Input matrix must have rank 2 or 3. Got: {}".format( + matrix.shape.ndims)) + + return matrix + + class CSRSparseMatrixOpsTest(test.TestCase): @classmethod @@ -1088,6 +1119,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): # Compute L * Lh where L is the Sparse Cholesky factor. verification = math_ops.matmul( dense_cholesky, array_ops.transpose(dense_cholesky, conjugate=True)) + verification = twist_matrix(verification, ordering_amd) # Assert that input matrix A satisfies A = L * Lh. verification_values = self.evaluate(verification) full_dense_matrix = ( @@ -1141,6 +1173,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): verification = math_ops.matmul( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1], conjugate=True)) + verification = twist_matrix(verification, ordering_amd) verification_values = self.evaluate(verification) self.assertAllClose( @@ -1180,6 +1213,7 @@ class CSRSparseMatrixOpsTest(test.TestCase): # Compute L * Lh. verification = math_ops.matmul( dense_cholesky, array_ops.transpose(dense_cholesky, perm=[0, 2, 1])) + verification = twist_matrix(verification, ordering_amd) verification_values = self.evaluate(verification) self.assertAllClose(dense_matrix, verification_values, atol=1e-5, rtol=1e-5)