Add inverse registration for LinearOperatorBlockLowerTriangular.
PiperOrigin-RevId: 292365539 Change-Id: I1922eee7119e97af022662a56c7933bab6b4bf36
This commit is contained in:
parent
3021e0d5d7
commit
8cc57e5815
@ -156,22 +156,21 @@ class SquareLinearOperatorBlockLowerTriangularTest(
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertFalse(operator.is_self_adjoint)
|
||||
|
||||
# TODO(emilyaf): Enable this test when the inverse registration is submitted.
|
||||
# def test_block_lower_triangular_inverse_type(self):
|
||||
# matrix = [[1., 0.], [0., 1.]]
|
||||
# operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
|
||||
# [[linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)],
|
||||
# [linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True),
|
||||
# linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)]],
|
||||
# is_non_singular=True,
|
||||
# )
|
||||
# inverse = operator.inverse()
|
||||
# self.assertIsInstance(
|
||||
# inverse,
|
||||
# block_lower_triangular.LinearOperatorBlockLowerTriangular)
|
||||
# self.assertEqual(2, len(inverse.operators))
|
||||
# self.assertEqual(1, len(inverse.operators[0]))
|
||||
# self.assertEqual(2, len(inverse.operators[1]))
|
||||
def test_block_lower_triangular_inverse_type(self):
  """Checks that `inverse()` preserves the operator type and block layout."""
  identity_2x2 = [[1., 0.], [0., 1.]]

  def make_block():
    # Every block wraps the same 2x2 identity and is flagged non-singular.
    return linalg.LinearOperatorFullMatrix(identity_2x2, is_non_singular=True)

  # Two block-rows: one block in the first row, two in the second.
  operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
      [[make_block()],
       [make_block(), make_block()]],
      is_non_singular=True,
  )
  inverse = operator.inverse()

  # The inverse must still be a block lower-triangular operator...
  expected_type = block_lower_triangular.LinearOperatorBlockLowerTriangular
  self.assertIsInstance(inverse, expected_type)

  # ...with the same number of block-rows and blocks per row as the input.
  expected_row_lengths = (1, 2)
  self.assertEqual(2, len(inverse.operators))
  for row_index, row_length in enumerate(expected_row_lengths):
    self.assertEqual(row_length, len(inverse.operators[row_index]))
|
||||
|
||||
def test_tape_safe(self):
|
||||
operator_1 = linalg.LinearOperatorFullMatrix(
|
||||
|
@ -18,11 +18,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_addition
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import linear_operator_block_diag
|
||||
from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular
|
||||
from tensorflow.python.ops.linalg import linear_operator_circulant
|
||||
from tensorflow.python.ops.linalg import linear_operator_diag
|
||||
from tensorflow.python.ops.linalg import linear_operator_full_matrix
|
||||
from tensorflow.python.ops.linalg import linear_operator_householder
|
||||
from tensorflow.python.ops.linalg import linear_operator_identity
|
||||
from tensorflow.python.ops.linalg import linear_operator_inversion
|
||||
@ -89,6 +93,105 @@ def _inverse_block_diag(block_diag_operator):
|
||||
is_square=True)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Builds the inverse of a `LinearOperatorBlockLowerTriangular`.

  The inverse is assembled one block-row at a time: the last block-row is
  peeled off, everything above it is inverted, and the two pieces are combined
  with the identity

  ```none
  |A 0|'  =  |    A'   0 |
  |B C|      | -C'BA'  C'|
  ```

  where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse.
  Multiplying the operator by the claimed inverse confirms the identity:

  ```none
  |A 0| |    A'   0 |  =  |     AA'       0  |  =  |I 0|
  |B C| | -C'BA'  C'|     | BA' - CC'BA'  CC'|     |0 I|
  ```

  Args:
    block_lower_triangular_operator: The `LinearOperatorBlockLowerTriangular`
      instance to invert.

  Returns:
    A `LinearOperatorBlockLowerTriangular` holding the inverse of
    `block_lower_triangular_operator`. It is constructed with the input's
    `is_non_singular`, `is_self_adjoint` and `is_positive_definite` hints and
    with `is_square=True`.
  """
  block_lower_triangular_cls = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
  block_rows = block_lower_triangular_operator.operators
  num_block_rows = len(block_rows)

  def _wrap_with_input_hints(inverse_rows):
    """Packs `inverse_rows` into an operator carrying the input's hints."""
    return block_lower_triangular_cls(
        inverse_rows,
        is_non_singular=block_lower_triangular_operator.is_non_singular,
        is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
        is_positive_definite=(
            block_lower_triangular_operator.is_positive_definite),
        is_square=True)

  def _make_addable(op):
    """Densifies `op` unless its type is in `SUPPORTED_OPERATORS`."""
    for supported_type in linear_operator_addition.SUPPORTED_OPERATORS:
      if isinstance(op, supported_type):
        return op
    return linear_operator_full_matrix.LinearOperatorFullMatrix(op.to_dense())

  # Base case: a single block, whose inverse is just that block's inverse.
  if num_block_rows == 1:
    return _wrap_with_input_hints([[block_rows[0][0].inverse()]])

  # `A'` from the docstring: the inverse of every block-row except the last.
  # NOTE(review): `.inverse()` on this sub-operator is presumably dispatched
  # back to this registered function, which is what makes the scheme recursive.
  leading_inverse = block_lower_triangular_cls(block_rows[:-1]).inverse()

  # `last_row` is `[B, C]` from the docstring (with `B` possibly spanning
  # several blocks); `corner_inverse` is `C'`.
  last_row = block_rows[-1]
  corner_inverse = last_row[-1].inverse()

  # Assemble `[-C'BA', C']`, one column partition of `A'` at a time.
  new_last_row = []
  for col in range(num_block_rows - 1):
    # Block `col` of `BA'` is a sum of products. The sum starts at `row == col`
    # because `A'` only stores blocks on or below its diagonal.
    partial_products = [
        _make_addable(
            last_row[row].matmul(leading_inverse.operators[row][col]))
        for row in range(col, num_block_rows - 1)
    ]
    summed = linear_operator_addition.add_operators(partial_products)
    assert len(summed) == 1

    # Left-multiply by `C'`, then negate via a `-1`-scaled identity so the
    # result remains a linear operator.
    product = corner_inverse.matmul(summed[0])
    negate = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=corner_inverse.domain_dimension_tensor(),
        multiplier=math_ops.cast(-1, dtype=product.dtype))
    new_last_row.append(negate.matmul(product))

  # The final block of the new bottom row is `C'` itself.
  new_last_row.append(corner_inverse)

  return _wrap_with_input_hints(leading_inverse.operators + [new_last_row])
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterInverse(
|
||||
linear_operator_kronecker.LinearOperatorKronecker)
|
||||
def _inverse_kronecker(kronecker_operator):
|
||||
|
Loading…
Reference in New Issue
Block a user