226 lines
9.0 KiB
Python
226 lines
9.0 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Registrations for LinearOperator.inverse."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops.linalg import linear_operator
|
|
from tensorflow.python.ops.linalg import linear_operator_addition
|
|
from tensorflow.python.ops.linalg import linear_operator_algebra
|
|
from tensorflow.python.ops.linalg import linear_operator_block_diag
|
|
from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular
|
|
from tensorflow.python.ops.linalg import linear_operator_circulant
|
|
from tensorflow.python.ops.linalg import linear_operator_diag
|
|
from tensorflow.python.ops.linalg import linear_operator_full_matrix
|
|
from tensorflow.python.ops.linalg import linear_operator_householder
|
|
from tensorflow.python.ops.linalg import linear_operator_identity
|
|
from tensorflow.python.ops.linalg import linear_operator_inversion
|
|
from tensorflow.python.ops.linalg import linear_operator_kronecker
|
|
|
|
|
|
# By default, return LinearOperatorInversion which switches the .matmul
# and .solve methods.
|
|
@linear_operator_algebra.RegisterInverse(linear_operator.LinearOperator)
def _inverse_linear_operator(linop):
  """Wrap `linop` in a `LinearOperatorInversion`.

  Registered for the base `LinearOperator` class.

  Args:
    linop: The operator to invert.

  Returns:
    A `LinearOperatorInversion` built around `linop`, carrying the same
    `is_non_singular`, `is_self_adjoint`, `is_positive_definite` and
    `is_square` hints as `linop`.
  """
  # The hints are forwarded unchanged from the operator being wrapped.
  forwarded_hints = dict(
      is_non_singular=linop.is_non_singular,
      is_self_adjoint=linop.is_self_adjoint,
      is_positive_definite=linop.is_positive_definite,
      is_square=linop.is_square)
  return linear_operator_inversion.LinearOperatorInversion(
      linop, **forwarded_hints)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_inversion.LinearOperatorInversion)
def _inverse_inverse_linear_operator(linop_inversion):
  """Return the operator that `linop_inversion` wraps.

  Args:
    linop_inversion: Instance of `LinearOperatorInversion`.

  Returns:
    The `operator` attribute of `linop_inversion`.
  """
  # Inverting an inversion wrapper just unwraps it; no new operator is built.
  wrapped_operator = linop_inversion.operator
  return wrapped_operator
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_diag.LinearOperatorDiag)
def _inverse_diag(diag_operator):
  """Build a `LinearOperatorDiag` from the reciprocal of the diagonal.

  Args:
    diag_operator: Instance of `LinearOperatorDiag`.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `1 / diag_operator.diag`. The
    non-singular, self-adjoint and positive-definite hints are copied from
    `diag_operator`; `is_square` is set to `True`.
  """
  reciprocal_diag = 1. / diag_operator.diag
  return linear_operator_diag.LinearOperatorDiag(
      reciprocal_diag,
      is_non_singular=diag_operator.is_non_singular,
      is_self_adjoint=diag_operator.is_self_adjoint,
      is_positive_definite=diag_operator.is_positive_definite,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorIdentity)
def _inverse_identity(identity_operator):
  """Return `identity_operator` unchanged, since identity is its own inverse.

  Args:
    identity_operator: Instance of `LinearOperatorIdentity`.

  Returns:
    The very same `identity_operator` object that was passed in.
  """
  return identity_operator
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _inverse_scaled_identity(identity_operator):
  """Inverse of a `LinearOperatorScaledIdentity`.

  The inverse of `c * I` is `(1 / c) * I`, so the result is another scaled
  identity with the reciprocal multiplier and the same number of rows.

  Args:
    identity_operator: Instance of `LinearOperatorScaledIdentity`.

  Returns:
    A `LinearOperatorScaledIdentity` with multiplier
    `1 / identity_operator.multiplier`. The non-singular, self-adjoint and
    positive-definite hints are copied from `identity_operator`; `is_square`
    is set to `True`.
  """
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=1. / identity_operator.multiplier,
      is_non_singular=identity_operator.is_non_singular,
      # Forward the hint instead of hard-coding `True`: `(1 / c) * I` is
      # self-adjoint exactly when `c * I` is, which does not hold for a
      # complex multiplier with a nonzero imaginary part.
      # NOTE(review): for real dtypes the source operator's hint is expected
      # to already be `True`, making this a no-op there -- confirm against
      # the `LinearOperatorScaledIdentity` constructor.
      is_self_adjoint=identity_operator.is_self_adjoint,
      is_positive_definite=identity_operator.is_positive_definite,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _inverse_block_diag(block_diag_operator):
  """Invert a `LinearOperatorBlockDiag` one diagonal block at a time.

  Args:
    block_diag_operator: Instance of `LinearOperatorBlockDiag`.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are the inverses of the blocks of
    `block_diag_operator`, in the same order. The non-singular, self-adjoint
    and positive-definite hints are copied from `block_diag_operator`;
    `is_square` is set to `True`.
  """
  inverted_blocks = []
  for diagonal_block in block_diag_operator.operators:
    inverted_blocks.append(diagonal_block.inverse())

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=inverted_blocks,
      is_non_singular=block_diag_operator.is_non_singular,
      is_self_adjoint=block_diag_operator.is_self_adjoint,
      is_positive_definite=block_diag_operator.is_positive_definite,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular)
def _inverse_block_lower_triangular(block_lower_triangular_operator):
  """Inverse of LinearOperatorBlockLowerTriangular.

  We recursively apply the identity:

  ```none
  |A 0|'  =  |    A'  0|
  |B C|      |-C'BA' C'|
  ```

  where `A` is n-by-n, `B` is m-by-n, `C` is m-by-m, and `'` denotes inverse.

  This identity can be verified through multiplication:

  ```none
  |A 0||    A'  0|
  |B C||-C'BA' C'|

    = |       AA'   0|
      |BA'-CC'BA' CC'|

    = |I 0|
      |0 I|
  ```

  Args:
    block_lower_triangular_operator: Instance of
      `LinearOperatorBlockLowerTriangular`.

  Returns:
    block_lower_triangular_operator_inverse: Instance of
      `LinearOperatorBlockLowerTriangular`, the inverse of
      `block_lower_triangular_operator`.
  """
  # Base case of the recursion: a single block is inverted directly.
  if len(block_lower_triangular_operator.operators) == 1:
    return (linear_operator_block_lower_triangular.
            LinearOperatorBlockLowerTriangular(
                [[block_lower_triangular_operator.operators[0][0].inverse()]],
                is_non_singular=block_lower_triangular_operator.is_non_singular,
                is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
                is_positive_definite=(block_lower_triangular_operator.
                                      is_positive_definite),
                is_square=True))

  # Number of block rows in the operator.
  blockwise_dim = len(block_lower_triangular_operator.operators)

  # Calculate the inverse of the `LinearOperatorBlockLowerTriangular`
  # representing all but the last row of `block_lower_triangular_operator` with
  # a recursive call (the matrix `A'` in the docstring definition).
  upper_left_inverse = (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
          block_lower_triangular_operator.operators[:-1]).inverse())

  # The last block row is `[B, C]`; its final entry is the bottom-right block.
  bottom_row = block_lower_triangular_operator.operators[-1]
  bottom_right_inverse = bottom_row[-1].inverse()

  # Find the bottom row of the inverse (equal to `[-C'BA', C']` in the docstring
  # definition, where `C` is the bottom-right operator of
  # `block_lower_triangular_operator` and `B` is the set of operators in the
  # bottom row excluding `C`). To find `-C'BA'`, we first iterate over the
  # column partitions of `A'`.
  inverse_bottom_row = []
  for i in range(blockwise_dim - 1):
    # Find the `i`-th block of `BA'`.
    blocks = []
    # `j` starts at `i`: only rows `j >= i` of `A'` are indexed at column `i`
    # (presumably because each row of a block lower-triangular operator holds
    # only the blocks on or below the diagonal -- confirm).
    for j in range(i, blockwise_dim - 1):
      result = bottom_row[j].matmul(upper_left_inverse.operators[j][i])
      # Products whose type is not listed in
      # `linear_operator_addition.SUPPORTED_OPERATORS` are densified so that
      # they can be passed to `add_operators` below.
      if not any(isinstance(result, op_type)
                 for op_type in linear_operator_addition.SUPPORTED_OPERATORS):
        result = linear_operator_full_matrix.LinearOperatorFullMatrix(
            result.to_dense())
      blocks.append(result)

    summed_blocks = linear_operator_addition.add_operators(blocks)
    # The sum is expected to reduce to exactly one operator.
    assert len(summed_blocks) == 1
    block = summed_blocks[0]

    # Find the `i`-th block of `-C'BA'`.
    block = bottom_right_inverse.matmul(block)
    # Negate by left-multiplying with a scaled identity of multiplier -1.
    block = linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=bottom_right_inverse.domain_dimension_tensor(),
        multiplier=math_ops.cast(-1, dtype=block.dtype)).matmul(block)
    inverse_bottom_row.append(block)

  # `C'` is the last block of the inverted linear operator.
  inverse_bottom_row.append(bottom_right_inverse)

  # Stack the rows of `A'` on top of the newly computed bottom row.
  return (
      linear_operator_block_lower_triangular.LinearOperatorBlockLowerTriangular(
          upper_left_inverse.operators + [inverse_bottom_row],
          is_non_singular=block_lower_triangular_operator.is_non_singular,
          is_self_adjoint=block_lower_triangular_operator.is_self_adjoint,
          is_positive_definite=(block_lower_triangular_operator.
                                is_positive_definite),
          is_square=True))
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_kronecker.LinearOperatorKronecker)
def _inverse_kronecker(kronecker_operator):
  """Invert a `LinearOperatorKronecker` factor by factor.

  Args:
    kronecker_operator: Instance of `LinearOperatorKronecker`.

  Returns:
    A `LinearOperatorKronecker` whose factors are the inverses of the factors
    of `kronecker_operator`, in the same order. The non-singular, self-adjoint
    and positive-definite hints are copied from `kronecker_operator`;
    `is_square` is set to `True`.
  """
  # The inverse of a Kronecker product is the Kronecker product of the
  # inverses of its factors.
  inverted_factors = [
      factor.inverse() for factor in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=inverted_factors,
      is_non_singular=kronecker_operator.is_non_singular,
      is_self_adjoint=kronecker_operator.is_self_adjoint,
      is_positive_definite=kronecker_operator.is_positive_definite,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_circulant.LinearOperatorCirculant)
def _inverse_circulant(circulant_operator):
  """Build a `LinearOperatorCirculant` from the reciprocal spectrum.

  Args:
    circulant_operator: Instance of `LinearOperatorCirculant`.

  Returns:
    A `LinearOperatorCirculant` with spectrum
    `1 / circulant_operator.spectrum` and the same `input_output_dtype` as
    `circulant_operator.dtype`. The non-singular, self-adjoint and
    positive-definite hints are copied from `circulant_operator`; `is_square`
    is set to `True`.
  """
  # Taking the elementwise reciprocal of the spectrum is all that is needed.
  reciprocal_spectrum = 1. / circulant_operator.spectrum
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=reciprocal_spectrum,
      is_non_singular=circulant_operator.is_non_singular,
      is_self_adjoint=circulant_operator.is_self_adjoint,
      is_positive_definite=circulant_operator.is_positive_definite,
      is_square=True,
      input_output_dtype=circulant_operator.dtype)
|
|
|
|
|
|
@linear_operator_algebra.RegisterInverse(
    linear_operator_householder.LinearOperatorHouseholder)
def _inverse_householder(householder_operator):
  """Return `householder_operator` unchanged as its own inverse.

  A Householder reflection applied twice yields the identity, so no new
  operator needs to be constructed.

  Args:
    householder_operator: Instance of `LinearOperatorHouseholder`.

  Returns:
    The very same `householder_operator` object that was passed in.
  """
  return householder_operator
|