STT-tensorflow/tensorflow/python/ops/linalg/inverse_registrations.py
Emily Fertig 8cc57e5815 Add inverse registration for LinearOperatorBlockLowerTriangular.
PiperOrigin-RevId: 292365539
Change-Id: I1922eee7119e97af022662a56c7933bab6b4bf36
2020-01-30 09:44:38 -08:00

226 lines
9.0 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registrations for LinearOperator.inverse."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_addition
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_block_diag
from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_householder
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_inversion
from tensorflow.python.ops.linalg import linear_operator_kronecker
# By default, return LinearOperatorInversion which switched the .matmul
# and .solve methods.
@linear_operator_algebra.RegisterInverse(linear_operator.LinearOperator)
def _inverse_linear_operator(linop):
return linear_operator_inversion.LinearOperatorInversion(
linop,
is_non_singular=linop.is_non_singular,
is_self_adjoint=linop.is_self_adjoint,
is_positive_definite=linop.is_positive_definite,
is_square=linop.is_square)
@linear_operator_algebra.RegisterInverse(
linear_operator_inversion.LinearOperatorInversion)
def _inverse_inverse_linear_operator(linop_inversion):
return linop_inversion.operator
@linear_operator_algebra.RegisterInverse(
linear_operator_diag.LinearOperatorDiag)
def _inverse_diag(diag_operator):
return linear_operator_diag.LinearOperatorDiag(
1. / diag_operator.diag,
is_non_singular=diag_operator.is_non_singular,
is_self_adjoint=diag_operator.is_self_adjoint,
is_positive_definite=diag_operator.is_positive_definite,
is_square=True)
@linear_operator_algebra.RegisterInverse(
linear_operator_identity.LinearOperatorIdentity)
def _inverse_identity(identity_operator):
return identity_operator
@linear_operator_algebra.RegisterInverse(
linear_operator_identity.LinearOperatorScaledIdentity)
def _inverse_scaled_identity(identity_operator):
return linear_operator_identity.LinearOperatorScaledIdentity(
num_rows=identity_operator._num_rows, # pylint: disable=protected-access
multiplier=1. / identity_operator.multiplier,
is_non_singular=identity_operator.is_non_singular,
is_self_adjoint=True,
is_positive_definite=identity_operator.is_positive_definite,
is_square=True)
@linear_operator_algebra.RegisterInverse(
linear_operator_block_diag.LinearOperatorBlockDiag)
def _inverse_block_diag(block_diag_operator):
# We take the inverse of each block on the diagonal.
return linear_operator_block_diag.LinearOperatorBlockDiag(
operators=[
operator.inverse() for operator in block_diag_operator.operators],
is_non_singular=block_diag_operator.is_non_singular,
is_self_adjoint=block_diag_operator.is_self_adjoint,
is_positive_definite=block_diag_operator.is_positive_definite,
is_square=True)
@linear_operator_algebra.RegisterInverse(
linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
"""Inverse of LinearOperatorBlockLowerTriangular.
We recursively apply the identity:
```none
|A 0|' = | A' 0|
|B C| |-C'BA' C'|
```
where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse.
This identity can be verified through multiplication:
```none
|A 0|| A' 0|
|B C||-C'BA' C'|
= | AA' 0|
|BA'-CC'BA' CC'|
= |I 0|
|0 I|
```
Args:
block_lower_triangular_operator: Instance of
`LinearOperatorBlockLowerTriangular`.
Returns:
block_lower_triangular_operator_inverse: Instance of
`LinearOperatorBlockLowerTriangular`, the inverse of
`block_lower_triangular_operator`.
"""
if len(block_lower_triangular_operator.operators) == 1:
return (linear_operator_block_lower_triangular.
LinearOperatorBlockLowerTriangular(
[[block_lower_triangular_operator.operators[0][0].inverse()]],
is_non_singular=block_lower_triangular_operator.is_non_singular,
is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
is_positive_definite=(block_lower_triangular_operator.
is_positive_definite),
is_square=True))
blockwise_dim = len(block_lower_triangular_operator.operators)
# Calculate the inverse of the `LinearOperatorBlockLowerTriangular`
# representing all but the last row of `block_lower_triangular_operator` with
# a recursive call (the matrix `A'` in the docstring definition).
upper_left_inverse = (
linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
block_lower_triangular_operator.operators[:-1]).inverse())
bottom_row = block_lower_triangular_operator.operators[-1]
bottom_right_inverse = bottom_row[-1].inverse()
# Find the bottom row of the inverse (equal to `[-C'BA', C']` in the docstring
# definition, where `C` is the bottom-right operator of
# `block_lower_triangular_operator` and `B` is the set of operators in the
# bottom row excluding `C`). To find `-C'BA'`, we first iterate over the
# column partitions of `A'`.
inverse_bottom_row = []
for i in range(blockwise_dim - 1):
# Find the `i`-th block of `BA'`.
blocks = []
for j in range(i, blockwise_dim - 1):
result = bottom_row[j].matmul(upper_left_inverse.operators[j][i])
if not any(isinstance(result, op_type)
for op_type in linear_operator_addition.SUPPORTED_OPERATORS):
result = linear_operator_full_matrix.LinearOperatorFullMatrix(
result.to_dense())
blocks.append(result)
summed_blocks = linear_operator_addition.add_operators(blocks)
assert len(summed_blocks) == 1
block = summed_blocks[0]
# Find the `i`-th block of `-C'BA'`.
block = bottom_right_inverse.matmul(block)
block = linear_operator_identity.LinearOperatorScaledIdentity(
num_rows=bottom_right_inverse.domain_dimension_tensor(),
multiplier=math_ops.cast(-1, dtype=block.dtype)).matmul(block)
inverse_bottom_row.append(block)
# `C'` is the last block of the inverted linear operator.
inverse_bottom_row.append(bottom_right_inverse)
return (
linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
upper_left_inverse.operators + [inverse_bottom_row],
is_non_singular=block_lower_triangular_operator.is_non_singular,
is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
is_positive_definite=(block_lower_triangular_operator.
is_positive_definite),
is_square=True))
@linear_operator_algebra.RegisterInverse(
linear_operator_kronecker.LinearOperatorKronecker)
def _inverse_kronecker(kronecker_operator):
# Inverse decomposition of a Kronecker product is the Kronecker product
# of inverse decompositions.
return linear_operator_kronecker.LinearOperatorKronecker(
operators=[
operator.inverse() for operator in kronecker_operator.operators],
is_non_singular=kronecker_operator.is_non_singular,
is_self_adjoint=kronecker_operator.is_self_adjoint,
is_positive_definite=kronecker_operator.is_positive_definite,
is_square=True)
@linear_operator_algebra.RegisterInverse(
linear_operator_circulant.LinearOperatorCirculant)
def _inverse_circulant(circulant_operator):
# Inverting the spectrum is sufficient to get the inverse.
return linear_operator_circulant.LinearOperatorCirculant(
spectrum=1. / circulant_operator.spectrum,
is_non_singular=circulant_operator.is_non_singular,
is_self_adjoint=circulant_operator.is_self_adjoint,
is_positive_definite=circulant_operator.is_positive_definite,
is_square=True,
input_output_dtype=circulant_operator.dtype)
@linear_operator_algebra.RegisterInverse(
linear_operator_householder.LinearOperatorHouseholder)
def _inverse_householder(householder_operator):
return householder_operator