STT-tensorflow/tensorflow/python/ops/linalg/adjoint_registrations.py
A. Unique TensorFlower d5e5d996ad Add LinearOperatorHouseholder for Householder transformations.
PiperOrigin-RevId: 244757987
2019-04-22 17:25:30 -07:00

135 lines
5.2 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registrations for LinearOperator.adjoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_adjoint
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_block_diag
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_householder
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_kronecker
# By default, return a LinearOperatorAdjoint, which switches the .matmul
# and .solve methods.
@linear_operator_algebra.RegisterAdjoint(linear_operator.LinearOperator)
def _adjoint_linear_operator(linop):
  """Adjoint registered for the base `LinearOperator` class.

  Wraps `linop` in a `LinearOperatorAdjoint` and forwards the four hints
  (`is_non_singular`, `is_self_adjoint`, `is_positive_definite`,
  `is_square`) exactly as they are set on `linop`.

  Args:
    linop: The operator whose adjoint is requested.

  Returns:
    A `LinearOperatorAdjoint` built around `linop`.
  """
  # Collect the hints first; they are copied verbatim from `linop`.
  hints = dict(
      is_non_singular=linop.is_non_singular,
      is_self_adjoint=linop.is_self_adjoint,
      is_positive_definite=linop.is_positive_definite,
      is_square=linop.is_square)
  return linear_operator_adjoint.LinearOperatorAdjoint(linop, **hints)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_adjoint.LinearOperatorAdjoint)
def _adjoint_adjoint_linear_operator(linop):
  """Adjoint of a `LinearOperatorAdjoint`: hand back the wrapped operator.

  Args:
    linop: A `LinearOperatorAdjoint` instance.

  Returns:
    The value of `linop.operator`; no new operator is constructed.
  """
  wrapped = linop.operator
  return wrapped
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_identity.LinearOperatorIdentity)
def _adjoint_identity(identity_operator):
  """Adjoint of a `LinearOperatorIdentity`.

  The identity is its own adjoint, so the very same object is returned.

  Args:
    identity_operator: A `LinearOperatorIdentity` instance.

  Returns:
    `identity_operator`, unchanged.
  """
  return identity_operator
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _adjoint_scaled_identity(identity_operator):
  """Adjoint of a `LinearOperatorScaledIdentity`.

  Builds a new scaled identity with the same `_num_rows` and hints, whose
  multiplier is the complex conjugate of the original one. A multiplier
  with a non-complex dtype is reused as is.

  Args:
    identity_operator: A `LinearOperatorScaledIdentity` instance.

  Returns:
    A new `LinearOperatorScaledIdentity` constructed with `is_square=True`.
  """
  scale = identity_operator.multiplier
  # Only complex multipliers are conjugated.
  adjoint_scale = math_ops.conj(scale) if scale.dtype.is_complex else scale
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      multiplier=adjoint_scale,
      is_non_singular=identity_operator.is_non_singular,
      is_self_adjoint=identity_operator.is_self_adjoint,
      is_positive_definite=identity_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_diag.LinearOperatorDiag)
def _adjoint_diag(diag_operator):
  """Adjoint of a `LinearOperatorDiag`.

  Builds a new diagonal operator from the complex conjugate of the original
  diagonal. A diagonal with a non-complex dtype is reused as is. The
  non-singular, self-adjoint and positive-definite hints are copied over.

  Args:
    diag_operator: A `LinearOperatorDiag` instance.

  Returns:
    A new `LinearOperatorDiag` constructed with `is_square=True`.
  """
  entries = diag_operator.diag
  # Only complex diagonals are conjugated.
  adjoint_entries = (
      math_ops.conj(entries) if entries.dtype.is_complex else entries)
  return linear_operator_diag.LinearOperatorDiag(
      diag=adjoint_entries,
      is_non_singular=diag_operator.is_non_singular,
      is_self_adjoint=diag_operator.is_self_adjoint,
      is_positive_definite=diag_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _adjoint_block_diag(block_diag_operator):
  """Adjoint of a `LinearOperatorBlockDiag`.

  Calls `.adjoint()` on every operator in `block_diag_operator.operators`,
  keeping their order, and assembles the results into a new block-diagonal
  operator. The non-singular, self-adjoint and positive-definite hints are
  copied from the original.

  Args:
    block_diag_operator: A `LinearOperatorBlockDiag` instance.

  Returns:
    A new `LinearOperatorBlockDiag` constructed with `is_square=True`.
  """
  # Each block on the diagonal is replaced by its own adjoint.
  adjoint_blocks = [
      block.adjoint() for block in block_diag_operator.operators]
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=adjoint_blocks,
      is_non_singular=block_diag_operator.is_non_singular,
      is_self_adjoint=block_diag_operator.is_self_adjoint,
      is_positive_definite=block_diag_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_kronecker.LinearOperatorKronecker)
def _adjoint_kronecker(kronecker_operator):
  """Adjoint of a `LinearOperatorKronecker`.

  The adjoint of a Kronecker product is the Kronecker product of the
  adjoints, so `.adjoint()` is called on every factor in
  `kronecker_operator.operators` (order preserved) and the results are
  combined into a new Kronecker operator. The non-singular, self-adjoint
  and positive-definite hints are copied from the original.

  Args:
    kronecker_operator: A `LinearOperatorKronecker` instance.

  Returns:
    A new `LinearOperatorKronecker` constructed with `is_square=True`.
  """
  adjoint_factors = [
      factor.adjoint() for factor in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=adjoint_factors,
      is_non_singular=kronecker_operator.is_non_singular,
      is_self_adjoint=kronecker_operator.is_self_adjoint,
      is_positive_definite=kronecker_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_circulant.LinearOperatorCirculant)
def _adjoint_circulant(circulant_operator):
  """Adjoint of a `LinearOperatorCirculant`.

  Builds a new circulant operator from the complex conjugate of the
  original spectrum. A spectrum with a non-complex dtype is reused as is.
  The non-singular, self-adjoint and positive-definite hints are copied
  from the original.

  Args:
    circulant_operator: A `LinearOperatorCirculant` instance.

  Returns:
    A new `LinearOperatorCirculant` constructed with `is_square=True` and
    with the same `dtype` as `circulant_operator`.
  """
  spectrum = circulant_operator.spectrum
  if spectrum.dtype.is_complex:
    spectrum = math_ops.conj(spectrum)

  # Conjugating the spectrum is sufficient to get the adjoint.
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=spectrum,
      # Forward the operator dtype explicitly. Without it the constructor
      # falls back to its default `input_output_dtype`, so the adjoint of an
      # operator built with a non-default dtype would silently change dtype.
      input_output_dtype=circulant_operator.dtype,
      is_non_singular=circulant_operator.is_non_singular,
      is_self_adjoint=circulant_operator.is_self_adjoint,
      is_positive_definite=circulant_operator.is_positive_definite,
      is_square=True)
@linear_operator_algebra.RegisterAdjoint(
    linear_operator_householder.LinearOperatorHouseholder)
def _adjoint_householder(householder_operator):
  """Adjoint of a `LinearOperatorHouseholder`.

  A Householder reflection is self-adjoint, so the operator is returned
  unchanged rather than rebuilt.

  Args:
    householder_operator: A `LinearOperatorHouseholder` instance.

  Returns:
    `householder_operator`, unchanged.
  """
  return householder_operator