Add inverse registration for LinearOperatorBlockLowerTriangular.

PiperOrigin-RevId: 292365539
Change-Id: I1922eee7119e97af022662a56c7933bab6b4bf36
This commit is contained in:
Emily Fertig 2020-01-30 09:39:23 -08:00 committed by TensorFlower Gardener
parent 3021e0d5d7
commit 8cc57e5815
2 changed files with 118 additions and 16 deletions

View File

@ -156,22 +156,21 @@ class SquareLinearOperatorBlockLowerTriangularTest(
self.assertTrue(operator.is_non_singular)
self.assertFalse(operator.is_self_adjoint)
# TODO(emilyaf): Enable this test when the inverse registration is submitted.
# def test_block_lower_triangular_inverse_type(self):
# matrix = [[1., 0.], [0., 1.]]
# operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
# [[linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)],
# [linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True),
# linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)]],
# is_non_singular=True,
# )
# inverse = operator.inverse()
# self.assertIsInstance(
# inverse,
# block_lower_triangular.LinearOperatorBlockLowerTriangular)
# self.assertEqual(2, len(inverse.operators))
# self.assertEqual(1, len(inverse.operators[0]))
# self.assertEqual(2, len(inverse.operators[1]))
def test_block_lower_triangular_inverse_type(self):
  """Checks the type and block layout of a 2x2 blockwise inverse.

  A lower-triangular operator with rows of one and two non-singular
  full-matrix blocks is inverted; the result must again be a
  `LinearOperatorBlockLowerTriangular` whose rows hold one and two blocks.
  """
  identity_2x2 = [[1., 0.], [0., 1.]]

  def make_block():
    # Every block wraps the same 2x2 matrix and is flagged non-singular.
    return linalg.LinearOperatorFullMatrix(identity_2x2, is_non_singular=True)

  operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
      [[make_block()],
       [make_block(), make_block()]],
      is_non_singular=True,
  )

  inverse = operator.inverse()

  self.assertIsInstance(
      inverse,
      block_lower_triangular.LinearOperatorBlockLowerTriangular)
  inverse_rows = inverse.operators
  self.assertEqual(2, len(inverse_rows))
  # Row `k` of a block lower-triangular operator holds `k + 1` blocks.
  self.assertEqual(1, len(inverse_rows[0]))
  self.assertEqual(2, len(inverse_rows[1]))
def test_tape_safe(self):
operator_1 = linalg.LinearOperatorFullMatrix(

View File

@ -18,11 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_addition
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_block_diag
from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_householder
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_inversion
@ -89,6 +93,105 @@ def _inverse_block_diag(block_diag_operator):
is_square=True)
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Builds the inverse of a `LinearOperatorBlockLowerTriangular`.

  The inverse is assembled one block row at a time by recursively applying

  ```none
  |A 0|'  =  |    A'  0|
  |B C|      |-C'BA' C'|
  ```

  where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse.
  Multiplying the two sides out confirms the identity:

  ```none
  |A 0||    A'  0|  =  |        AA'   0|  =  |I 0|
  |B C||-C'BA' C'|     |BA'-CC'BA'  CC'|     |0 I|
  ```

  Here `A` stands for every block row except the last, `C` for the
  bottom-right block, and `B` for the rest of the last row.

  Args:
    block_lower_triangular_operator: Instance of
      `LinearOperatorBlockLowerTriangular`.

  Returns:
    block_lower_triangular_operator_inverse: Instance of
      `LinearOperatorBlockLowerTriangular`, the inverse of
      `block_lower_triangular_operator`. It is constructed with the
      `is_non_singular`, `is_self_adjoint` and `is_positive_definite` hints of
      the input and with `is_square=True`.
  """
  block_lt_cls = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)

  def _with_input_hints(inverse_rows):
    # Wrap the inverted block rows, forwarding the hints of the input operator.
    return block_lt_cls(
        inverse_rows,
        is_non_singular=block_lower_triangular_operator.is_non_singular,
        is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
        is_positive_definite=(
            block_lower_triangular_operator.is_positive_definite),
        is_square=True)

  num_block_rows = len(block_lower_triangular_operator.operators)

  # Base case of the recursion: a single block is inverted directly.
  if num_block_rows == 1:
    only_block = block_lower_triangular_operator.operators[0][0]
    return _with_input_hints([[only_block.inverse()]])

  # `A'`: invert the operator made of every block row but the last one. This
  # `.inverse()` call is what recurses.
  top_left_inv = block_lt_cls(
      block_lower_triangular_operator.operators[:-1]).inverse()

  # `[B, C]` is the last block row; `C'` is the inverse of its final block.
  last_row = block_lower_triangular_operator.operators[-1]
  corner_inv = last_row[-1].inverse()

  addable_types = tuple(linear_operator_addition.SUPPORTED_OPERATORS)

  def _as_addable(op):
    # Products whose type `linear_operator_addition` does not list as supported
    # are densified into a `LinearOperatorFullMatrix` before being summed.
    if isinstance(op, addable_types):
      return op
    return linear_operator_full_matrix.LinearOperatorFullMatrix(op.to_dense())

  # Assemble `[-C'BA', C']`, walking over the column partitions of `A'`.
  num_upper_rows = num_block_rows - 1
  new_last_row = []
  for col in range(num_upper_rows):
    # Block `col` of `BA'`: `A'` is block lower triangular, so only its rows
    # `col` and below have a block in this column.
    terms = [
        _as_addable(last_row[k].matmul(top_left_inv.operators[k][col]))
        for k in range(col, num_upper_rows)
    ]
    summed = linear_operator_addition.add_operators(terms)
    assert len(summed) == 1

    # Block `col` of `-C'BA'`: left-multiply by `C'`, then negate by
    # multiplying with a scaled identity whose multiplier is -1.
    product = corner_inv.matmul(summed[0])
    size = corner_inv.domain_dimension_tensor()
    minus_one = math_ops.cast(-1, dtype=product.dtype)
    negation = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=size, multiplier=minus_one)
    new_last_row.append(negation.matmul(product))

  # `C'` closes the last block row of the inverse.
  new_last_row.append(corner_inv)

  return _with_input_hints(top_left_inv.operators + [new_last_row])
@linear_operator_algebra.RegisterInverse(
linear_operator_kronecker.LinearOperatorKronecker)
def _inverse_kronecker(kronecker_operator):