Also register `linop.solve(identity_linop) = linop.inverse()`. This is useful for families like ScaledIdentity that are closed under inversion. PiperOrigin-RevId: 272530839
219 lines
8.7 KiB
Python
219 lines
8.7 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Registrations for LinearOperator.matmul."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.ops.linalg import linear_operator
|
|
from tensorflow.python.ops.linalg import linear_operator_algebra
|
|
from tensorflow.python.ops.linalg import linear_operator_circulant
|
|
from tensorflow.python.ops.linalg import linear_operator_composition
|
|
from tensorflow.python.ops.linalg import linear_operator_diag
|
|
from tensorflow.python.ops.linalg import linear_operator_identity
|
|
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
|
|
from tensorflow.python.ops.linalg import linear_operator_zeros
|
|
from tensorflow.python.ops.linalg import registrations_util
|
|
|
|
|
|
# By default, use a LinearOperatorComposition to delay the computation.
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _matmul_linear_operator(linop_a, linop_b):
  """Generic matmul of two `LinearOperator`s.

  Registered for the base `LinearOperator` pair, this does not evaluate the
  product; it wraps both factors in a `LinearOperatorComposition`.

  Args:
    linop_a: Left `LinearOperator` factor.
    linop_b: Right `LinearOperator` factor.

  Returns:
    A `LinearOperatorComposition` over `[linop_a, linop_b]` whose hints are
    derived from whether the pair is square.
  """
  square = registrations_util.is_square(linop_a, linop_b)

  if square is False:  # pylint:disable=g-bool-id-comparison
    # A product known to be non-square cannot be non-singular, self-adjoint
    # or positive definite.
    hints = dict(
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False)
  else:
    hints = dict(is_self_adjoint=None, is_positive_definite=None)
    # Non-singularity is only combined from the factors when squareness is
    # affirmatively known; otherwise no claim is made.
    hints["is_non_singular"] = (
        registrations_util.combined_non_singular_hint(linop_a, linop_b)
        if square else None)

  return linear_operator_composition.LinearOperatorComposition(
      operators=[linop_a, linop_b],
      is_square=square,
      **hints)
|
|
|
|
# Identity
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _matmul_linear_operator_identity_left(identity, linop):
  """Matmul `identity @ linop`, which is returned as `linop` unchanged.

  The identity operand is never inspected, so nothing about it (e.g. any
  batch shape it carries) is reflected in the result.
  """
  _ = identity  # Unused: left-multiplying by I is a no-op.
  return linop
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _matmul_linear_operator_identity_right(linop, identity):
  """Matmul `linop @ identity`, which is returned as `linop` unchanged.

  The identity operand is never inspected, so nothing about it (e.g. any
  batch shape it carries) is reflected in the result.
  """
  _ = identity  # Unused: right-multiplying by I is a no-op.
  return linop
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_scaled_identity(linop_a, linop_b):
  """Matmul of two ScaledIdentity `LinearOperators`.

  `(a * I) @ (b * I) == (a * b) * I`, so the product is again a
  `LinearOperatorScaledIdentity` whose multiplier is the product of the two
  multipliers. The size (`num_rows`) is taken from `linop_a`.

  Args:
    linop_a: Left `LinearOperatorScaledIdentity`.
    linop_b: Right `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorScaledIdentity`.
  """
  size = linop_a.domain_dimension_tensor()
  scale = linop_a.multiplier * linop_b.multiplier
  non_singular = registrations_util.combined_non_singular_hint(linop_a, linop_b)
  # Scaled identities always commute, so the commuting combiners apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))

  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=size,
      multiplier=scale,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
|
|
|
|
|
|
# Zeros
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_zeros.LinearOperatorZeros)
def _matmul_linear_operator_zeros_right(linop, zeros):
  """Matmul `linop @ zeros`, returned as the `zeros` operator itself.

  Args:
    linop: Left `LinearOperator` factor.
    zeros: Right `LinearOperatorZeros` factor.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If `zeros.is_square` or `linop.is_square` is falsy
      (e.g. `False` or `None`).
  """
  both_square = zeros.is_square and linop.is_square
  if not both_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_zeros.LinearOperatorZeros,
    linear_operator.LinearOperator)
def _matmul_linear_operator_zeros_left(zeros, linop):
  """Matmul `zeros @ linop`, returned as the `zeros` operator itself.

  Args:
    zeros: Left `LinearOperatorZeros` factor.
    linop: Right `LinearOperator` factor.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If `zeros.is_square` or `linop.is_square` is falsy
      (e.g. `False` or `None`).
  """
  both_square = zeros.is_square and linop.is_square
  if not both_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros
|
|
|
|
|
|
# Diag.
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorDiag`s.

  `diag(a) @ diag(b) == diag(a * b)`, so the result is another
  `LinearOperatorDiag` built from the elementwise product of the diagonals.
  Diagonal operators commute, which is why the commuting hint combiners are
  used for self-adjointness and positive definiteness.
  """
  entries = linop_a.diag * linop_b.diag
  non_singular = registrations_util.combined_non_singular_hint(linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))

  return linear_operator_diag.LinearOperatorDiag(
      diag=entries,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Matmul `linop_diag @ linop_scaled_identity`.

  `diag(d) @ (c * I) == diag(c * d)`, so the result is a `LinearOperatorDiag`.

  Per the `LinearOperatorScaledIdentity` API, `multiplier` carries only the
  batch shape `[B1, ..., Bb]`, while `diag` has shape `[B1, ..., Bb, N]`.
  A trailing axis is appended to the multiplier so that each batch member's
  scale multiplies its own diagonal. Without it, the batch axis would be
  broadcast against the `N` axis, which errors or, when the sizes happen to
  match, silently mis-scales. For a scalar multiplier the result is unchanged.

  Args:
    linop_diag: Left `LinearOperatorDiag` factor.
    linop_scaled_identity: Right `LinearOperatorScaledIdentity` factor.

  Returns:
    A square `LinearOperatorDiag`.
  """
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # Diagonal and scaled-identity operators commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Matmul `linop_scaled_identity @ linop_diag`.

  `(c * I) @ diag(d) == diag(c * d)`, so the result is a `LinearOperatorDiag`.

  Per the `LinearOperatorScaledIdentity` API, `multiplier` carries only the
  batch shape `[B1, ..., Bb]`, while `diag` has shape `[B1, ..., Bb, N]`.
  A trailing axis is appended to the multiplier so that each batch member's
  scale multiplies its own diagonal. Without it, the batch axis would be
  broadcast against the `N` axis, which errors or, when the sizes happen to
  match, silently mis-scales. For a scalar multiplier the result is unchanged.

  Args:
    linop_scaled_identity: Left `LinearOperatorScaledIdentity` factor.
    linop_diag: Right `LinearOperatorDiag` factor.

  Returns:
    A square `LinearOperatorDiag`.
  """
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # Diagonal and scaled-identity operators commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Matmul `linop_diag @ linop_triangular`.

  Left-multiplying by `diag(d)` scales row `i` of the triangular matrix by
  `d[i]`, which keeps it lower triangular. No positive-definiteness hint is
  inferred for the product.
  """
  row_scales = linop_diag.diag[..., None]
  scaled_tril = row_scales * linop_triangular.to_dense()
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # This is safe to do since the Triangular matrix is only self-adjoint
  # when it is a diagonal matrix, and hence commutes.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)

  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_tril,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
  """Matmul `linop_triangular @ linop_diag`.

  Right-multiplying by `diag(d)` scales column `j` of the triangular matrix by
  `d[j]`, which keeps it lower triangular.

  The diagonal is reshaped to `[..., 1, N]` before multiplying so that
  batched inputs broadcast correctly. A bare `[..., N]` diagonal against a
  `[..., N, N]` matrix would align the diagonal's batch axis with the row
  axis. For an unbatched diagonal the result is identical either way.

  Args:
    linop_triangular: Left `LinearOperatorLowerTriangular` factor.
    linop_diag: Right `LinearOperatorDiag` factor.

  Returns:
    A square `LinearOperatorLowerTriangular`; no positive-definiteness hint
    is inferred.
  """
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=linop_triangular.to_dense() * linop_diag.diag[..., None, :],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
      # This is safe to do since the Triangular matrix is only self-adjoint
      # when it is a diagonal matrix, and hence commutes.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
      is_positive_definite=None,
      is_square=True)
|
|
|
|
# Circulant.
|
|
|
|
|
|
@linear_operator_algebra.RegisterMatmul(
    linear_operator_circulant.LinearOperatorCirculant,
    linear_operator_circulant.LinearOperatorCirculant)
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
  """Matmul of two `LinearOperatorCirculant`s.

  The product is represented as a new `LinearOperatorCirculant` whose
  spectrum is the elementwise product of the two input spectra. Circulant
  operators commute, which is why the commuting hint combiners are used.
  """
  combined_spectrum = linop_a.spectrum * linop_b.spectrum
  non_singular = registrations_util.combined_non_singular_hint(linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))

  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=combined_spectrum,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
|