Allow LinearOperator.solve to take a LinearOperator.
PiperOrigin-RevId: 244388120
This commit is contained in:
parent 58f67785f6
commit 0aa8055f1a
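
In brief, for readers scanning the diff: LinearOperator.solve now accepts another LinearOperator as its right-hand side and returns a LinearOperator, dispatching through the new solve registrations. A rough usage sketch (illustrative only; values chosen to match the new diag test below, not text from the commit):

    # Hypothetical example, not part of the diff itself.
    import tensorflow as tf

    operator1 = tf.linalg.LinearOperatorDiag([2., 3.], is_non_singular=True)
    operator2 = tf.linalg.LinearOperatorDiag([1., 2.], is_non_singular=True)

    result = operator1.solve(operator2)
    # With the diag registration added below, `result` is a LinearOperatorDiag
    # representing diag([0.5, 2/3]), i.e. operator1^{-1} @ operator2,
    # computed without forming dense matrices.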
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.linalg import linalg as linalg_lib
@@ -113,6 +115,110 @@ class LinearOperatorAdjointTest(
 
     self.assertEqual("my_operator_adjoint", operator.name)
 
+  def test_matmul_adjoint_operator(self):
+    matrix1 = np.random.randn(4, 4)
+    matrix2 = np.random.randn(4, 4)
+    full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
+    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
+
+    self.assertAllClose(
+        np.matmul(matrix1, matrix2.T),
+        self.evaluate(
+            full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))
+
+    self.assertAllClose(
+        np.matmul(matrix1.T, matrix2),
+        self.evaluate(
+            full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))
+
+    self.assertAllClose(
+        np.matmul(matrix1.T, matrix2.T),
+        self.evaluate(
+            full_matrix1.matmul(
+                full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
+
+  def test_matmul_adjoint_complex_operator(self):
+    matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
+    matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
+    full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
+    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
+
+    self.assertAllClose(
+        np.matmul(matrix1, matrix2.conj().T),
+        self.evaluate(
+            full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))
+
+    self.assertAllClose(
+        np.matmul(matrix1.conj().T, matrix2),
+        self.evaluate(
+            full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))
+
+    self.assertAllClose(
+        np.matmul(matrix1.conj().T, matrix2.conj().T),
+        self.evaluate(
+            full_matrix1.matmul(
+                full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
+
+  def test_solve_adjoint_operator(self):
+    matrix1 = self.evaluate(
+        linear_operator_test_util.random_tril_matrix(
+            [4, 4], dtype=dtypes.float64, force_well_conditioned=True))
+    matrix2 = np.random.randn(4, 4)
+    full_matrix1 = linalg.LinearOperatorLowerTriangular(
+        matrix1, is_non_singular=True)
+    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
+
+    self.assertAllClose(
+        self.evaluate(linalg.triangular_solve(matrix1, matrix2.T)),
+        self.evaluate(
+            full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))
+
+    self.assertAllClose(
+        self.evaluate(
+            linalg.triangular_solve(
+                matrix1.T, matrix2, lower=False)),
+        self.evaluate(
+            full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))
+
+    self.assertAllClose(
+        self.evaluate(
+            linalg.triangular_solve(matrix1.T, matrix2.T, lower=False)),
+        self.evaluate(
+            full_matrix1.solve(
+                full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
+
+  def test_solve_adjoint_complex_operator(self):
+    matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
+        [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
+        1j * linear_operator_test_util.random_tril_matrix(
+            [4, 4], dtype=dtypes.complex128,
+            force_well_conditioned=True))
+    matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
+
+    full_matrix1 = linalg.LinearOperatorLowerTriangular(
+        matrix1, is_non_singular=True)
+    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
+
+    self.assertAllClose(
+        self.evaluate(linalg.triangular_solve(matrix1, matrix2.conj().T)),
+        self.evaluate(
+            full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))
+
+    self.assertAllClose(
+        self.evaluate(
+            linalg.triangular_solve(
+                matrix1.conj().T, matrix2, lower=False)),
+        self.evaluate(
+            full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))
+
+    self.assertAllClose(
+        self.evaluate(
+            linalg.triangular_solve(
+                matrix1.conj().T, matrix2.conj().T, lower=False)),
+        self.evaluate(
+            full_matrix1.solve(
+                full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
+
+
 class LinearOperatorAdjointNonSquareTest(
     linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
@@ -23,6 +23,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations  # pylint: disab
 from tensorflow.python.ops.linalg import linear_operator
 from tensorflow.python.ops.linalg import linear_operator_algebra
 from tensorflow.python.ops.linalg import matmul_registrations  # pylint: disable=unused-import
+from tensorflow.python.ops.linalg import solve_registrations  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
 
 # pylint: disable=protected-access
@@ -34,6 +35,8 @@ _INVERSES = linear_operator_algebra._INVERSES
 _registered_inverse = linear_operator_algebra._registered_inverse
 _MATMUL = linear_operator_algebra._MATMUL
 _registered_matmul = linear_operator_algebra._registered_matmul
+_SOLVE = linear_operator_algebra._SOLVE
+_registered_solve = linear_operator_algebra._registered_solve
 # pylint: enable=protected-access
 
 
@@ -175,6 +178,55 @@ class MatmulTest(test.TestCase):
       self.assertEqual(v, _registered_matmul(k[0], k[1]))
 
 
+class SolveTest(test.TestCase):
+
+  def testRegistration(self):
+
+    class CustomLinOp(linear_operator.LinearOperator):
+
+      def _matmul(self, a):
+        pass
+
+      def _solve(self, a):
+        pass
+
+      def _shape(self):
+        return tensor_shape.TensorShape([1, 1])
+
+      def _shape_tensor(self):
+        pass
+
+    # Register Solve to a lambda that spits out the name parameter
+    @linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)
+    def _solve(a, b):  # pylint: disable=unused-argument,unused-variable
+      return "OK"
+
+    custom_linop = CustomLinOp(
+        dtype=None, is_self_adjoint=True, is_positive_definite=True)
+    self.assertEqual("OK", custom_linop.solve(custom_linop))
+
+  def testRegistrationFailures(self):
+
+    class CustomLinOp(linear_operator.LinearOperator):
+      pass
+
+    with self.assertRaisesRegexp(TypeError, "must be callable"):
+      linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah")
+
+    # First registration is OK
+    linear_operator_algebra.RegisterSolve(
+        CustomLinOp, CustomLinOp)(lambda a: None)
+
+    # Second registration fails
+    with self.assertRaisesRegexp(ValueError, "has already been registered"):
+      linear_operator_algebra.RegisterSolve(
+          CustomLinOp, CustomLinOp)(lambda a: None)
+
+  def testExactSolveRegistrationsAllMatch(self):
+    for (k, v) in _SOLVE.items():
+      self.assertEqual(v, _registered_solve(k[0], k[1]))
+
+
 class InverseTest(test.TestCase):
 
   def testRegistration(self):
@@ -187,6 +187,35 @@ class LinearOperatorDiagTest(
         linalg_lib.LinearOperatorDiag))
     self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag))
 
+  def test_diag_solve(self):
+    operator1 = linalg_lib.LinearOperatorDiag([2., 3.], is_non_singular=True)
+    operator2 = linalg_lib.LinearOperatorDiag([1., 2.], is_non_singular=True)
+    operator3 = linalg_lib.LinearOperatorScaledIdentity(
+        num_rows=2, multiplier=3., is_non_singular=True)
+    operator_solve = operator1.solve(operator2)
+    self.assertTrue(isinstance(
+        operator_solve,
+        linalg_lib.LinearOperatorDiag))
+    self.assertAllClose([0.5, 2 / 3.], self.evaluate(operator_solve.diag))
+
+    operator_solve = operator2.solve(operator1)
+    self.assertTrue(isinstance(
+        operator_solve,
+        linalg_lib.LinearOperatorDiag))
+    self.assertAllClose([2., 3 / 2.], self.evaluate(operator_solve.diag))
+
+    operator_solve = operator1.solve(operator3)
+    self.assertTrue(isinstance(
+        operator_solve,
+        linalg_lib.LinearOperatorDiag))
+    self.assertAllClose([3 / 2., 1.], self.evaluate(operator_solve.diag))
+
+    operator_solve = operator3.solve(operator1)
+    self.assertTrue(isinstance(
+        operator_solve,
+        linalg_lib.LinearOperatorDiag))
+    self.assertAllClose([2 / 3., 1.], self.evaluate(operator_solve.diag))
+
   def test_diag_adjoint_type(self):
     diag = [1., 3., 5., 8.]
     operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
@@ -495,6 +495,20 @@ class LinearOperatorScaledIdentityTest(
         linalg_lib.LinearOperatorScaledIdentity))
     self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))
 
+  def test_identity_solve(self):
+    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
+    operator2 = linalg_lib.LinearOperatorScaledIdentity(
+        num_rows=2, multiplier=3.)
+    self.assertTrue(isinstance(
+        operator1.solve(operator1),
+        linalg_lib.LinearOperatorIdentity))
+
+    operator_solve = operator1.solve(operator2)
+    self.assertTrue(isinstance(
+        operator_solve,
+        linalg_lib.LinearOperatorScaledIdentity))
+    self.assertAllClose(3., self.evaluate(operator_solve.multiplier))
+
   def test_scaled_identity_cholesky_type(self):
     operator = linalg_lib.LinearOperatorScaledIdentity(
         num_rows=2,
@@ -238,7 +238,7 @@ class LinearOperatorTest(test.TestCase):
 
     self.assertTrue(operator_matmul.is_square)
     self.assertTrue(operator_matmul.is_non_singular)
-    self.assertTrue(operator_matmul.is_self_adjoint)
+    self.assertEqual(None, operator_matmul.is_self_adjoint)
     self.assertEqual(None, operator_matmul.is_positive_definite)
 
   @test_util.run_deprecated_v1
@@ -25,6 +25,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations as _cholesky_reg
 from tensorflow.python.ops.linalg import inverse_registrations as _inverse_registrations
 from tensorflow.python.ops.linalg import linear_operator_algebra as _linear_operator_algebra
 from tensorflow.python.ops.linalg import matmul_registrations as _matmul_registrations
+from tensorflow.python.ops.linalg import solve_registrations as _solve_registrations
 from tensorflow.python.ops.linalg.linalg_impl import *
 from tensorflow.python.ops.linalg.linear_operator import *
 from tensorflow.python.ops.linalg.linear_operator_block_diag import *
@@ -597,16 +597,18 @@ class LinearOperator(object):
         as `self`.
     """
     if isinstance(x, LinearOperator):
-      if adjoint or adjoint_arg:
-        raise ValueError(".matmul not supported with adjoints.")
-      if (x.range_dimension is not None and
-          self.domain_dimension is not None and
-          x.range_dimension != self.domain_dimension):
+      left_operator = self.adjoint() if adjoint else self
+      right_operator = x.adjoint() if adjoint_arg else x
+
+      if (right_operator.range_dimension is not None and
+          left_operator.domain_dimension is not None and
+          right_operator.range_dimension != left_operator.domain_dimension):
         raise ValueError(
             "Operators are incompatible. Expected `x` to have dimension"
-            " {} but got {}.".format(self.domain_dimension, x.range_dimension))
+            " {} but got {}.".format(
+                left_operator.domain_dimension, right_operator.range_dimension))
       with self._name_scope(name):
-        return linear_operator_algebra.matmul(self, x)
+        return linear_operator_algebra.matmul(left_operator, right_operator)
 
     with self._name_scope(name, values=[x]):
       x = ops.convert_to_tensor(x, name="x")
@@ -780,6 +782,20 @@ class LinearOperator(object):
       raise NotImplementedError(
           "Exact solve not implemented for an operator that is expected to "
          "not be square.")
+    if isinstance(rhs, LinearOperator):
+      left_operator = self.adjoint() if adjoint else self
+      right_operator = rhs.adjoint() if adjoint_arg else rhs
+
+      if (right_operator.range_dimension is not None and
+          left_operator.domain_dimension is not None and
+          right_operator.range_dimension != left_operator.domain_dimension):
+        raise ValueError(
+            "Operators are incompatible. Expected `rhs` to have dimension"
+            " {} but got {}.".format(
+                left_operator.domain_dimension, right_operator.range_dimension))
+      with self._name_scope(name):
+        return linear_operator_algebra.solve(left_operator, right_operator)
+
     with self._name_scope(name, values=[rhs]):
       rhs = ops.convert_to_tensor(rhs, name="rhs")
       self._check_input_dtype(rhs)
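
Editorial note (not part of the commit): in the new branch above, the adjoint flags are folded into the operands before dispatch, so calling solve with `adjoint=True` on a LinearOperator right-hand side behaves like solving with the adjointed operator. A minimal sketch under that assumption, with hypothetical operators:

    # Illustrative only; `a` and `b` are hypothetical square, non-singular operators.
    import tensorflow as tf

    a = tf.linalg.LinearOperatorLowerTriangular([[2., 0.], [1., 3.]],
                                                is_non_singular=True)
    b = tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])

    x1 = a.solve(b, adjoint=True)   # internally dispatches solve(a.adjoint(), b)
    x2 = a.adjoint().solve(b)       # same operands, same dispatch
    # Both represent adjoint(a)^{-1} @ b as a LinearOperator.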
@@ -28,6 +28,7 @@ from tensorflow.python.util import tf_inspect
 _ADJOINTS = {}
 _CHOLESKY_DECOMPS = {}
 _MATMUL = {}
+_SOLVE = {}
 _INVERSES = {}
 
 
@@ -62,6 +63,11 @@ def _registered_matmul(type_a, type_b):
   return _registered_function([type_a, type_b], _MATMUL)
 
 
+def _registered_solve(type_a, type_b):
+  """Get the Solve function registered for classes a and b."""
+  return _registered_function([type_a, type_b], _SOLVE)
+
+
 def _registered_inverse(type_a):
   """Get the Cholesky function registered for class a."""
   return _registered_function([type_a], _INVERSES)
@@ -138,6 +144,31 @@ def matmul(lin_op_a, lin_op_b, name=None):
   return matmul_fn(lin_op_a, lin_op_b)
 
 
+def solve(lin_op_a, lin_op_b, name=None):
+  """Compute lin_op_a.solve(lin_op_b).
+
+  Args:
+    lin_op_a: The LinearOperator on the left.
+    lin_op_b: The LinearOperator on the right.
+    name: Name to use for this operation.
+
+  Returns:
+    A LinearOperator that represents the solve between `lin_op_a` and
+      `lin_op_b`.
+
+  Raises:
+    NotImplementedError: If no solve method is defined between types of
+      `lin_op_a` and `lin_op_b`.
+  """
+  solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b))
+  if solve_fn is None:
+    raise ValueError("No solve registered for {}.solve({})".format(
+        type(lin_op_a), type(lin_op_b)))
+
+  with ops.name_scope(name, "Solve"):
+    return solve_fn(lin_op_a, lin_op_b)
+
+
 def inverse(lin_op_a, name=None):
   """Get the Inverse associated to lin_op_a.
 
@@ -291,6 +322,52 @@ class RegisterMatmul(object):
     return matmul_fn
 
 
+class RegisterSolve(object):
+  """Decorator to register a Solve implementation function.
+
+  Usage:
+
+  @linear_operator_algebra.RegisterSolve(
+    lin_op.LinearOperatorIdentity,
+    lin_op.LinearOperatorIdentity)
+  def _solve_identity(a, b):
+    # Return the identity matrix.
+  """
+
+  def __init__(self, lin_op_cls_a, lin_op_cls_b):
+    """Initialize the LinearOperator registrar.
+
+    Args:
+      lin_op_cls_a: the class of the LinearOperator that is computing solve.
+      lin_op_cls_b: the class of the second LinearOperator to solve.
+    """
+    self._key = (lin_op_cls_a, lin_op_cls_b)
+
+  def __call__(self, solve_fn):
+    """Perform the Solve registration.
+
+    Args:
+      solve_fn: The function to use for the Solve.
+
+    Returns:
+      solve_fn
+
+    Raises:
+      TypeError: if solve_fn is not a callable.
+      ValueError: if a Solve function has already been registered for
+        the given argument classes.
+    """
+    if not callable(solve_fn):
+      raise TypeError(
+          "solve_fn must be callable, received: {}".format(solve_fn))
+    if self._key in _SOLVE:
+      raise ValueError("Solve({}, {}) has already been registered.".format(
+          self._key[0].__name__,
+          self._key[1].__name__))
+    _SOLVE[self._key] = solve_fn
+    return solve_fn
+
+
 class RegisterInverse(object):
   """Decorator to register an Inverse implementation function.
 
@@ -26,66 +26,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag
 from tensorflow.python.ops.linalg import linear_operator_identity
 from tensorflow.python.ops.linalg import linear_operator_lower_triangular
 from tensorflow.python.ops.linalg import linear_operator_zeros
+from tensorflow.python.ops.linalg import registrations_util
 
 
-def _combined_self_adjoint_hint(operator_a, operator_b):
-  """Get combined hint for self-adjoint-ness."""
-  # Note: only use this method in the commuting case.
-  # The property is preserved under composition when the operators commute.
-  if operator_a.is_self_adjoint and operator_b.is_self_adjoint:
-    return True
-
-  # The property is not preserved when an operator with the property is composed
-  # with an operator without the property.
-  if ((operator_a.is_self_adjoint is True and
-       operator_b.is_self_adjoint is False) or
-      (operator_a.is_self_adjoint is False and
-       operator_b.is_self_adjoint is True)):
-    return False
-
-  # The property is not known when operators are not known to have the property
-  # or both operators don't have the property (the property for the complement
-  # class is not closed under composition).
-  return None
-
-
-def _is_square(operator_a, operator_b):
-  """Return a hint to whether the composition is square."""
-  if operator_a.is_square and operator_b.is_square:
-    return True
-  if operator_a.is_square is False and operator_b.is_square is False:
-    # Let A have shape [B, M, N], B have shape [B, N, L].
-    m = operator_a.range_dimension
-    l = operator_b.domain_dimension
-    if m is not None and l is not None:
-      return m == l
-
-  return None
-
-
-def _combined_positive_definite_hint(operator_a, operator_b):
-  """Get combined PD hint for compositions."""
-  # Note: Positive definiteness is only guaranteed to be preserved
-  # when the operators commute and are symmetric. Only use this method in
-  # commuting cases.
-
-  if (operator_a.is_positive_definite is True and
-      operator_a.is_self_adjoint is True and
-      operator_b.is_positive_definite is True and
-      operator_b.is_self_adjoint is True):
-    return True
-
-  return None
-
-
-def _combined_non_singular_hint(operator_a, operator_b):
-  """Get combined hint for when ."""
-  # If either operator is not-invertible the composition isn't.
-  if (operator_a.is_non_singular is False or
-      operator_b.is_non_singular is False):
-    return False
-
-  return operator_a.is_non_singular and operator_b.is_non_singular
-
-
 # By default, use a LinearOperatorComposition to delay the computation.
@@ -93,15 +34,15 @@ def _combined_non_singular_hint(operator_a, operator_b):
     linear_operator.LinearOperator, linear_operator.LinearOperator)
 def _matmul_linear_operator(linop_a, linop_b):
   """Generic matmul of two `LinearOperator`s."""
-  is_square = _is_square(linop_a, linop_b)
+  is_square = registrations_util.is_square(linop_a, linop_b)
   is_non_singular = None
   is_self_adjoint = None
   is_positive_definite = None
 
   if is_square:
-    is_non_singular = _combined_non_singular_hint(linop_a, linop_b)
-    is_self_adjoint = _combined_self_adjoint_hint(linop_a, linop_b)
-  elif is_square is False:
+    is_non_singular = registrations_util.combined_non_singular_hint(
+        linop_a, linop_b)
+  elif is_square is False:  # pylint:disable=g-bool-id-comparison
     is_non_singular = False
     is_self_adjoint = False
     is_positive_definite = False
@@ -165,11 +106,13 @@ def _matmul_linear_operator_zeros_left(zeros, linop):
 def _matmul_linear_operator_diag(linop_a, linop_b):
   return linear_operator_diag.LinearOperatorDiag(
       diag=linop_a.diag * linop_b.diag,
-      is_non_singular=_combined_non_singular_hint(linop_a, linop_b),
-      is_self_adjoint=_combined_self_adjoint_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
           linop_a, linop_b),
-      is_positive_definite=_combined_positive_definite_hint(
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
           linop_a, linop_b),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_a, linop_b)),
       is_square=True)
 
 
@@ -180,12 +123,13 @@ def _matmul_linear_operator_diag_scaled_identity_right(
     linop_diag, linop_scaled_identity):
   return linear_operator_diag.LinearOperatorDiag(
       diag=linop_diag.diag * linop_scaled_identity.multiplier,
-      is_non_singular=_combined_non_singular_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
           linop_diag, linop_scaled_identity),
-      is_self_adjoint=_combined_self_adjoint_hint(
-          linop_diag, linop_scaled_identity),
-      is_positive_definite=_combined_positive_definite_hint(
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
           linop_diag, linop_scaled_identity),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_diag, linop_scaled_identity)),
       is_square=True)
 
 
@@ -196,12 +140,13 @@ def _matmul_linear_operator_diag_scaled_identity_left(
     linop_scaled_identity, linop_diag):
   return linear_operator_diag.LinearOperatorDiag(
      diag=linop_diag.diag * linop_scaled_identity.multiplier,
-      is_non_singular=_combined_non_singular_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
           linop_diag, linop_scaled_identity),
-      is_self_adjoint=_combined_self_adjoint_hint(
-          linop_diag, linop_scaled_identity),
-      is_positive_definite=_combined_positive_definite_hint(
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
           linop_diag, linop_scaled_identity),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_diag, linop_scaled_identity)),
       is_square=True)
 
 
@@ -211,11 +156,11 @@ def _matmul_linear_operator_diag_scaled_identity_left(
 def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
   return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
       tril=linop_diag.diag[..., None] * linop_triangular.to_dense(),
-      is_non_singular=_combined_non_singular_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
       # This is safe to do since the Triangular matrix is only self-adjoint
       # when it is a diagonal matrix, and hence commutes.
-      is_self_adjoint=_combined_self_adjoint_hint(
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
       is_positive_definite=None,
       is_square=True)
@@ -227,11 +172,11 @@ def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
 def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
   return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
       tril=linop_triangular.to_dense() * linop_diag.diag,
-      is_non_singular=_combined_non_singular_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
       # This is safe to do since the Triangular matrix is only self-adjoint
       # when it is a diagonal matrix, and hence commutes.
-      is_self_adjoint=_combined_self_adjoint_hint(
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
       is_positive_definite=None,
       is_square=True)
@@ -245,8 +190,11 @@ def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
 def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
   return linear_operator_circulant.LinearOperatorCirculant(
       spectrum=linop_a.spectrum * linop_b.spectrum,
-      is_non_singular=_combined_non_singular_hint(linop_a, linop_b),
-      is_self_adjoint=_combined_self_adjoint_hint(linop_a, linop_b),
-      is_positive_definite=_combined_positive_definite_hint(
+      is_non_singular=registrations_util.combined_non_singular_hint(
           linop_a, linop_b),
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_a, linop_b),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_a, linop_b)),
       is_square=True)
tensorflow/python/ops/linalg/registrations_util.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Common utilities for registering LinearOperator methods."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# Note: only use this method in the commuting case.
+def combined_commuting_self_adjoint_hint(operator_a, operator_b):
+  """Get combined hint for self-adjoint-ness."""
+
+  # The property is preserved under composition when the operators commute.
+  if operator_a.is_self_adjoint and operator_b.is_self_adjoint:
+    return True
+
+  # The property is not preserved when an operator with the property is composed
+  # with an operator without the property.
+
+  # pylint:disable=g-bool-id-comparison
+  if ((operator_a.is_self_adjoint is True and
+       operator_b.is_self_adjoint is False) or
+      (operator_a.is_self_adjoint is False and
+       operator_b.is_self_adjoint is True)):
+    return False
+  # pylint:enable=g-bool-id-comparison
+
+  # The property is not known when operators are not known to have the property
+  # or both operators don't have the property (the property for the complement
+  # class is not closed under composition).
+  return None
+
+
+def is_square(operator_a, operator_b):
+  """Return a hint to whether the composition is square."""
+  if operator_a.is_square and operator_b.is_square:
+    return True
+  if operator_a.is_square is False and operator_b.is_square is False:  # pylint:disable=g-bool-id-comparison
+    # Let A have shape [B, M, N], B have shape [B, N, L].
+    m = operator_a.range_dimension
+    l = operator_b.domain_dimension
+    if m is not None and l is not None:
+      return m == l
+
+  if (operator_a.is_square != operator_b.is_square) and (
+      operator_a.is_square is not None and operator_a.is_square is not None):
+    return False
+
+  return None
+
+
+# Note: Positive definiteness is only guaranteed to be preserved
+# when the operators commute and are symmetric. Only use this method in
+# commuting cases.
+def combined_commuting_positive_definite_hint(operator_a, operator_b):
+  """Get combined PD hint for compositions."""
+  # pylint:disable=g-bool-id-comparison
+  if (operator_a.is_positive_definite is True and
+      operator_a.is_self_adjoint is True and
+      operator_b.is_positive_definite is True and
+      operator_b.is_self_adjoint is True):
+    return True
+  # pylint:enable=g-bool-id-comparison
+
+  return None
+
+
+def combined_non_singular_hint(operator_a, operator_b):
+  """Get combined hint for when ."""
+  # If either operator is not-invertible the composition isn't.
+
+  # pylint:disable=g-bool-id-comparison
+  if (operator_a.is_non_singular is False or
+      operator_b.is_non_singular is False):
+    return False
+  # pylint:enable=g-bool-id-comparison
+
+  return operator_a.is_non_singular and operator_b.is_non_singular
tensorflow/python/ops/linalg/solve_registrations.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registrations for LinearOperator.solve."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_algebra
+from tensorflow.python.ops.linalg import linear_operator_circulant
+from tensorflow.python.ops.linalg import linear_operator_composition
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_inversion
+from tensorflow.python.ops.linalg import linear_operator_lower_triangular
+from tensorflow.python.ops.linalg import registrations_util
+
+
+# By default, use a LinearOperatorComposition to delay the computation.
+@linear_operator_algebra.RegisterSolve(
+    linear_operator.LinearOperator, linear_operator.LinearOperator)
+def _solve_linear_operator(linop_a, linop_b):
+  """Generic solve of two `LinearOperator`s."""
+  is_square = registrations_util.is_square(linop_a, linop_b)
+  is_non_singular = None
+  is_self_adjoint = None
+  is_positive_definite = None
+
+  if is_square:
+    is_non_singular = registrations_util.combined_non_singular_hint(
+        linop_a, linop_b)
+  elif is_square is False:  # pylint:disable=g-bool-id-comparison
+    is_non_singular = False
+    is_self_adjoint = False
+    is_positive_definite = False
+
+  return linear_operator_composition.LinearOperatorComposition(
+      operators=[
+          linear_operator_inversion.LinearOperatorInversion(linop_a),
+          linop_b
+      ],
+      is_non_singular=is_non_singular,
+      is_self_adjoint=is_self_adjoint,
+      is_positive_definite=is_positive_definite,
+      is_square=is_square,
+  )
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_inversion.LinearOperatorInversion,
+    linear_operator.LinearOperator)
+def _solve_inverse_linear_operator(linop_a, linop_b):
+  """Solve inverse of generic `LinearOperator`s."""
+  return linop_a.operator.matmul(linop_b)
+
+
+# Identity
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_identity.LinearOperatorIdentity,
+    linear_operator.LinearOperator)
+def _solve_linear_operator_identity_left(identity, linop):
+  del identity
+  return linop
+
+
+# Diag.
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_diag.LinearOperatorDiag,
+    linear_operator_diag.LinearOperatorDiag)
+def _solve_linear_operator_diag(linop_a, linop_b):
+  return linear_operator_diag.LinearOperatorDiag(
+      diag=linop_b.diag / linop_a.diag,
+      is_non_singular=registrations_util.combined_non_singular_hint(
+          linop_a, linop_b),
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_a, linop_b),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_a, linop_b)),
+      is_square=True)
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_diag.LinearOperatorDiag,
+    linear_operator_identity.LinearOperatorScaledIdentity)
+def _solve_linear_operator_diag_scaled_identity_right(
+    linop_diag, linop_scaled_identity):
+  return linear_operator_diag.LinearOperatorDiag(
+      diag=linop_scaled_identity.multiplier / linop_diag.diag,
+      is_non_singular=registrations_util.combined_non_singular_hint(
+          linop_diag, linop_scaled_identity),
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_diag, linop_scaled_identity),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_diag, linop_scaled_identity)),
+      is_square=True)
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_identity.LinearOperatorScaledIdentity,
+    linear_operator_diag.LinearOperatorDiag)
+def _solve_linear_operator_diag_scaled_identity_left(
+    linop_scaled_identity, linop_diag):
+  return linear_operator_diag.LinearOperatorDiag(
+      diag=linop_diag.diag / linop_scaled_identity.multiplier,
+      is_non_singular=registrations_util.combined_non_singular_hint(
+          linop_diag, linop_scaled_identity),
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_diag, linop_scaled_identity),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_diag, linop_scaled_identity)),
+      is_square=True)
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_diag.LinearOperatorDiag,
+    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
+def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
+  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+      tril=linop_triangular.to_dense() / linop_diag.diag[..., None],
+      is_non_singular=registrations_util.combined_non_singular_hint(
+          linop_diag, linop_triangular),
+      # This is safe to do since the Triangular matrix is only self-adjoint
+      # when it is a diagonal matrix, and hence commutes.
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_diag, linop_triangular),
+      is_positive_definite=None,
+      is_square=True)
+
+
+# Circulant.
+
+
+@linear_operator_algebra.RegisterSolve(
+    linear_operator_circulant.LinearOperatorCirculant,
+    linear_operator_circulant.LinearOperatorCirculant)
+def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
+  return linear_operator_circulant.LinearOperatorCirculant(
+      spectrum=linop_b.spectrum / linop_a.spectrum,
+      is_non_singular=registrations_util.combined_non_singular_hint(
+          linop_a, linop_b),
+      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
+          linop_a, linop_b),
+      is_positive_definite=(
+          registrations_util.combined_commuting_positive_definite_hint(
+              linop_a, linop_b)),
+      is_square=True)
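
Closing sketch (editorial, not part of the commit): the generic fallback registered at the top of this file mirrors the delayed-computation pattern already used for matmul. Written out by hand against the modules imported here, under the assumption that linop_a is square and invertible, it is roughly:

    # Approximately what the default registration builds.
    from tensorflow.python.ops.linalg import linear_operator_composition
    from tensorflow.python.ops.linalg import linear_operator_inversion

    def manual_generic_solve(linop_a, linop_b):
      # Represents linop_a^{-1} @ linop_b without materializing either matrix.
      return linear_operator_composition.LinearOperatorComposition(
          operators=[
              linear_operator_inversion.LinearOperatorInversion(linop_a),
              linop_b])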