102 lines
4.0 KiB
Python
102 lines
4.0 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Registrations for LinearOperator.cholesky."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.ops import linalg_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops.linalg import linear_operator
|
|
from tensorflow.python.ops.linalg import linear_operator_algebra
|
|
from tensorflow.python.ops.linalg import linear_operator_block_diag
|
|
from tensorflow.python.ops.linalg import linear_operator_diag
|
|
from tensorflow.python.ops.linalg import linear_operator_identity
|
|
from tensorflow.python.ops.linalg import linear_operator_kronecker
|
|
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
|
|
|
|
|
|
# By default, compute the Cholesky of the dense matrix, and return a
|
|
# LowerTriangular operator. Methods below specialize this registration.
|
|
@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Generic fallback: Cholesky-factor the dense matrix of `linop`.

  This registration is used for any `LinearOperator` subclass that has no
  more specific Cholesky registration. It materializes the operator with
  `to_dense()`, factors it with `linalg_ops.cholesky`, and wraps the result.

  Args:
    linop: A `LinearOperator`. It is presumably Hermitian positive definite,
      which is the precondition for a Cholesky factor to exist.

  Returns:
    A `LinearOperatorLowerTriangular` wrapping the dense Cholesky factor. It
    is hinted as square and non-singular, and explicitly as not self-adjoint.
  """
  lower_factor = linalg_ops.cholesky(linop.to_dense())
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      lower_factor,
      is_square=True,
      is_self_adjoint=False,
      is_non_singular=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Cholesky factor of a diagonal operator.

  For a diagonal matrix D, the Cholesky factor is itself diagonal. Its
  entries are the element-wise square roots of D's diagonal, so no dense
  factorization is needed.

  Args:
    diag_operator: A `LinearOperatorDiag`. Its diagonal is presumably
      positive, so that the square roots are real and non-zero.

  Returns:
    A `LinearOperatorDiag` built from `sqrt(diag_operator.diag)`. It is hinted
    as square, non-singular, self-adjoint and positive definite.
  """
  root_diag = math_ops.sqrt(diag_operator.diag)
  return linear_operator_diag.LinearOperatorDiag(
      root_diag,
      is_square=True,
      is_positive_definite=True,
      is_self_adjoint=True,
      is_non_singular=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Cholesky factor of an identity operator, which is the identity itself.

  A new `LinearOperatorIdentity` is constructed with the same number of
  rows, batch shape and dtype as the input, with all structural hints set.

  Args:
    identity_operator: A `LinearOperatorIdentity`.

  Returns:
    A `LinearOperatorIdentity` hinted as square, non-singular, self-adjoint
    and positive definite.
  """
  hints = dict(
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=identity_operator._num_rows,  # pylint: disable=protected-access
      batch_shape=identity_operator.batch_shape,
      dtype=identity_operator.dtype,
      **hints)
|
|
|
|
|
|
@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Cholesky factor of a scaled identity `c * I`, which is `sqrt(c) * I`.

  Args:
    identity_operator: A `LinearOperatorScaledIdentity`. Its `multiplier` is
      presumably positive, so that the square root is real and non-zero.

  Returns:
    A `LinearOperatorScaledIdentity` with the same number of rows and
    multiplier `sqrt(identity_operator.multiplier)`. It is hinted as square,
    non-singular, self-adjoint and positive definite.
  """
  num_rows = identity_operator._num_rows  # pylint: disable=protected-access
  root_multiplier = math_ops.sqrt(identity_operator.multiplier)
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=root_multiplier,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Cholesky factor of a block-diagonal operator.

  The Cholesky factor of blockdiag(A_1, ..., A_k) is blockdiag(L_1, ..., L_k),
  where L_i is the Cholesky factor of A_i. Each diagonal block is therefore
  factored independently via its own `cholesky()` method.

  Args:
    block_diag_operator: A `LinearOperatorBlockDiag`. Each block is
      presumably Hermitian positive definite, which is required for a
      Cholesky factor to exist.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are `op.cholesky()` for each block
    `op` in `block_diag_operator.operators`. It is hinted as square and
    non-singular.
  """
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          operator.cholesky() for operator in block_diag_operator.operators],
      is_non_singular=True,
      # Do not force `False` here. When every block's factor is self-adjoint
      # (e.g. diagonal or identity blocks), LinearOperatorBlockDiag infers
      # self-adjointness itself and raises ValueError on a contradictory
      # `is_self_adjoint=False` hint. Leaving it unset lets the blocks decide.
      is_self_adjoint=None,
      is_square=True)
|
|
|
|
|
|
@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Cholesky factor of a Kronecker product of operators.

  If A = L_A L_A^H and B = L_B L_B^H, then
  A (x) B = (L_A (x) L_B)(L_A (x) L_B)^H. L_A (x) L_B is also lower triangular,
  so the Cholesky factor of a Kronecker product is the Kronecker product of
  the factors' Cholesky factors.

  Args:
    kronecker_operator: A `LinearOperatorKronecker`. Each factor is
      presumably Hermitian positive definite, which is required for a
      Cholesky factor to exist.

  Returns:
    A `LinearOperatorKronecker` of `op.cholesky()` for each factor `op` in
    `kronecker_operator.operators`. It is hinted as square and non-singular.
  """
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=[
          operator.cholesky() for operator in kronecker_operator.operators],
      is_non_singular=True,
      # Do not force `False` here. When every factor's Cholesky operator is
      # self-adjoint (e.g. diagonal or identity factors),
      # LinearOperatorKronecker infers self-adjointness itself and raises
      # ValueError on a contradictory `is_self_adjoint=False` hint.
      is_self_adjoint=None,
      is_square=True)
|