Allow LinearOperator.solve to take a LinearOperator.
PiperOrigin-RevId: 244388120
This commit is contained in:
parent
58f67785f6
commit
0aa8055f1a
tensorflow/python
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
@ -113,6 +115,110 @@ class LinearOperatorAdjointTest(
|
||||
|
||||
self.assertEqual("my_operator_adjoint", operator.name)
|
||||
|
||||
def test_matmul_adjoint_operator(self):
|
||||
matrix1 = np.random.randn(4, 4)
|
||||
matrix2 = np.random.randn(4, 4)
|
||||
full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
|
||||
full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1, matrix2.T),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1.T, matrix2),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1.T, matrix2.T),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(
|
||||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||
|
||||
def test_matmul_adjoint_complex_operator(self):
|
||||
matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||
matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||
full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
|
||||
full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1, matrix2.conj().T),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1.conj().T, matrix2),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
np.matmul(matrix1.conj().T, matrix2.conj().T),
|
||||
self.evaluate(
|
||||
full_matrix1.matmul(
|
||||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||
|
||||
def test_solve_adjoint_operator(self):
|
||||
matrix1 = self.evaluate(
|
||||
linear_operator_test_util.random_tril_matrix(
|
||||
[4, 4], dtype=dtypes.float64, force_well_conditioned=True))
|
||||
matrix2 = np.random.randn(4, 4)
|
||||
full_matrix1 = linalg.LinearOperatorLowerTriangular(
|
||||
matrix1, is_non_singular=True)
|
||||
full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(linalg.triangular_solve(matrix1, matrix2.T)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(
|
||||
linalg.triangular_solve(
|
||||
matrix1.T, matrix2, lower=False)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(
|
||||
linalg.triangular_solve(matrix1.T, matrix2.T, lower=False)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(
|
||||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||
|
||||
def test_solve_adjoint_complex_operator(self):
|
||||
matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
|
||||
[4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
|
||||
1j * linear_operator_test_util.random_tril_matrix(
|
||||
[4, 4], dtype=dtypes.complex128,
|
||||
force_well_conditioned=True))
|
||||
matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||
|
||||
full_matrix1 = linalg.LinearOperatorLowerTriangular(
|
||||
matrix1, is_non_singular=True)
|
||||
full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(linalg.triangular_solve(matrix1, matrix2.conj().T)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(
|
||||
linalg.triangular_solve(
|
||||
matrix1.conj().T, matrix2, lower=False)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(
|
||||
linalg.triangular_solve(
|
||||
matrix1.conj().T, matrix2.conj().T, lower=False)),
|
||||
self.evaluate(
|
||||
full_matrix1.solve(
|
||||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||
|
||||
|
||||
class LinearOperatorAdjointNonSquareTest(
|
||||
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations # pylint: disab
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import matmul_registrations # pylint: disable=unused-import
|
||||
from tensorflow.python.ops.linalg import solve_registrations # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@ -34,6 +35,8 @@ _INVERSES = linear_operator_algebra._INVERSES
|
||||
_registered_inverse = linear_operator_algebra._registered_inverse
|
||||
_MATMUL = linear_operator_algebra._MATMUL
|
||||
_registered_matmul = linear_operator_algebra._registered_matmul
|
||||
_SOLVE = linear_operator_algebra._SOLVE
|
||||
_registered_solve = linear_operator_algebra._registered_solve
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@ -175,6 +178,55 @@ class MatmulTest(test.TestCase):
|
||||
self.assertEqual(v, _registered_matmul(k[0], k[1]))
|
||||
|
||||
|
||||
class SolveTest(test.TestCase):
|
||||
|
||||
def testRegistration(self):
|
||||
|
||||
class CustomLinOp(linear_operator.LinearOperator):
|
||||
|
||||
def _matmul(self, a):
|
||||
pass
|
||||
|
||||
def _solve(self, a):
|
||||
pass
|
||||
|
||||
def _shape(self):
|
||||
return tensor_shape.TensorShape([1, 1])
|
||||
|
||||
def _shape_tensor(self):
|
||||
pass
|
||||
|
||||
# Register Solve to a lambda that spits out the name parameter
|
||||
@linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)
|
||||
def _solve(a, b): # pylint: disable=unused-argument,unused-variable
|
||||
return "OK"
|
||||
|
||||
custom_linop = CustomLinOp(
|
||||
dtype=None, is_self_adjoint=True, is_positive_definite=True)
|
||||
self.assertEqual("OK", custom_linop.solve(custom_linop))
|
||||
|
||||
def testRegistrationFailures(self):
|
||||
|
||||
class CustomLinOp(linear_operator.LinearOperator):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, "must be callable"):
|
||||
linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah")
|
||||
|
||||
# First registration is OK
|
||||
linear_operator_algebra.RegisterSolve(
|
||||
CustomLinOp, CustomLinOp)(lambda a: None)
|
||||
|
||||
# Second registration fails
|
||||
with self.assertRaisesRegexp(ValueError, "has already been registered"):
|
||||
linear_operator_algebra.RegisterSolve(
|
||||
CustomLinOp, CustomLinOp)(lambda a: None)
|
||||
|
||||
def testExactSolveRegistrationsAllMatch(self):
|
||||
for (k, v) in _SOLVE.items():
|
||||
self.assertEqual(v, _registered_solve(k[0], k[1]))
|
||||
|
||||
|
||||
class InverseTest(test.TestCase):
|
||||
|
||||
def testRegistration(self):
|
||||
|
@ -187,6 +187,35 @@ class LinearOperatorDiagTest(
|
||||
linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag))
|
||||
|
||||
def test_diag_solve(self):
|
||||
operator1 = linalg_lib.LinearOperatorDiag([2., 3.], is_non_singular=True)
|
||||
operator2 = linalg_lib.LinearOperatorDiag([1., 2.], is_non_singular=True)
|
||||
operator3 = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=3., is_non_singular=True)
|
||||
operator_solve = operator1.solve(operator2)
|
||||
self.assertTrue(isinstance(
|
||||
operator_solve,
|
||||
linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([0.5, 2 / 3.], self.evaluate(operator_solve.diag))
|
||||
|
||||
operator_solve = operator2.solve(operator1)
|
||||
self.assertTrue(isinstance(
|
||||
operator_solve,
|
||||
linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([2., 3 / 2.], self.evaluate(operator_solve.diag))
|
||||
|
||||
operator_solve = operator1.solve(operator3)
|
||||
self.assertTrue(isinstance(
|
||||
operator_solve,
|
||||
linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([3 / 2., 1.], self.evaluate(operator_solve.diag))
|
||||
|
||||
operator_solve = operator3.solve(operator1)
|
||||
self.assertTrue(isinstance(
|
||||
operator_solve,
|
||||
linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([2 / 3., 1.], self.evaluate(operator_solve.diag))
|
||||
|
||||
def test_diag_adjoint_type(self):
|
||||
diag = [1., 3., 5., 8.]
|
||||
operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
|
||||
|
@ -495,6 +495,20 @@ class LinearOperatorScaledIdentityTest(
|
||||
linalg_lib.LinearOperatorScaledIdentity))
|
||||
self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))
|
||||
|
||||
def test_identity_solve(self):
|
||||
operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
|
||||
operator2 = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=3.)
|
||||
self.assertTrue(isinstance(
|
||||
operator1.solve(operator1),
|
||||
linalg_lib.LinearOperatorIdentity))
|
||||
|
||||
operator_solve = operator1.solve(operator2)
|
||||
self.assertTrue(isinstance(
|
||||
operator_solve,
|
||||
linalg_lib.LinearOperatorScaledIdentity))
|
||||
self.assertAllClose(3., self.evaluate(operator_solve.multiplier))
|
||||
|
||||
def test_scaled_identity_cholesky_type(self):
|
||||
operator = linalg_lib.LinearOperatorScaledIdentity(
|
||||
num_rows=2,
|
||||
|
@ -238,7 +238,7 @@ class LinearOperatorTest(test.TestCase):
|
||||
|
||||
self.assertTrue(operator_matmul.is_square)
|
||||
self.assertTrue(operator_matmul.is_non_singular)
|
||||
self.assertTrue(operator_matmul.is_self_adjoint)
|
||||
self.assertEqual(None, operator_matmul.is_self_adjoint)
|
||||
self.assertEqual(None, operator_matmul.is_positive_definite)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations as _cholesky_reg
|
||||
from tensorflow.python.ops.linalg import inverse_registrations as _inverse_registrations
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra as _linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import matmul_registrations as _matmul_registrations
|
||||
from tensorflow.python.ops.linalg import solve_registrations as _solve_registrations
|
||||
from tensorflow.python.ops.linalg.linalg_impl import *
|
||||
from tensorflow.python.ops.linalg.linear_operator import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_block_diag import *
|
||||
|
@ -597,16 +597,18 @@ class LinearOperator(object):
|
||||
as `self`.
|
||||
"""
|
||||
if isinstance(x, LinearOperator):
|
||||
if adjoint or adjoint_arg:
|
||||
raise ValueError(".matmul not supported with adjoints.")
|
||||
if (x.range_dimension is not None and
|
||||
self.domain_dimension is not None and
|
||||
x.range_dimension != self.domain_dimension):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = x.adjoint() if adjoint_arg else x
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `x` to have dimension"
|
||||
" {} but got {}.".format(self.domain_dimension, x.range_dimension))
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.matmul(self, x)
|
||||
return linear_operator_algebra.matmul(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name, values=[x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
@ -780,6 +782,20 @@ class LinearOperator(object):
|
||||
raise NotImplementedError(
|
||||
"Exact solve not implemented for an operator that is expected to "
|
||||
"not be square.")
|
||||
if isinstance(rhs, LinearOperator):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = rhs.adjoint() if adjoint_arg else rhs
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `rhs` to have dimension"
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.solve(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name, values=[rhs]):
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.util import tf_inspect
|
||||
_ADJOINTS = {}
|
||||
_CHOLESKY_DECOMPS = {}
|
||||
_MATMUL = {}
|
||||
_SOLVE = {}
|
||||
_INVERSES = {}
|
||||
|
||||
|
||||
@ -62,6 +63,11 @@ def _registered_matmul(type_a, type_b):
|
||||
return _registered_function([type_a, type_b], _MATMUL)
|
||||
|
||||
|
||||
def _registered_solve(type_a, type_b):
|
||||
"""Get the Solve function registered for classes a and b."""
|
||||
return _registered_function([type_a, type_b], _SOLVE)
|
||||
|
||||
|
||||
def _registered_inverse(type_a):
|
||||
"""Get the Cholesky function registered for class a."""
|
||||
return _registered_function([type_a], _INVERSES)
|
||||
@ -138,6 +144,31 @@ def matmul(lin_op_a, lin_op_b, name=None):
|
||||
return matmul_fn(lin_op_a, lin_op_b)
|
||||
|
||||
|
||||
def solve(lin_op_a, lin_op_b, name=None):
|
||||
"""Compute lin_op_a.solve(lin_op_b).
|
||||
|
||||
Args:
|
||||
lin_op_a: The LinearOperator on the left.
|
||||
lin_op_b: The LinearOperator on the right.
|
||||
name: Name to use for this operation.
|
||||
|
||||
Returns:
|
||||
A LinearOperator that represents the solve between `lin_op_a` and
|
||||
`lin_op_b`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If no solve method is defined between types of
|
||||
`lin_op_a` and `lin_op_b`.
|
||||
"""
|
||||
solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b))
|
||||
if solve_fn is None:
|
||||
raise ValueError("No solve registered for {}.solve({})".format(
|
||||
type(lin_op_a), type(lin_op_b)))
|
||||
|
||||
with ops.name_scope(name, "Solve"):
|
||||
return solve_fn(lin_op_a, lin_op_b)
|
||||
|
||||
|
||||
def inverse(lin_op_a, name=None):
|
||||
"""Get the Inverse associated to lin_op_a.
|
||||
|
||||
@ -291,6 +322,52 @@ class RegisterMatmul(object):
|
||||
return matmul_fn
|
||||
|
||||
|
||||
class RegisterSolve(object):
|
||||
"""Decorator to register a Solve implementation function.
|
||||
|
||||
Usage:
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
lin_op.LinearOperatorIdentity,
|
||||
lin_op.LinearOperatorIdentity)
|
||||
def _solve_identity(a, b):
|
||||
# Return the identity matrix.
|
||||
"""
|
||||
|
||||
def __init__(self, lin_op_cls_a, lin_op_cls_b):
|
||||
"""Initialize the LinearOperator registrar.
|
||||
|
||||
Args:
|
||||
lin_op_cls_a: the class of the LinearOperator that is computing solve.
|
||||
lin_op_cls_b: the class of the second LinearOperator to solve.
|
||||
"""
|
||||
self._key = (lin_op_cls_a, lin_op_cls_b)
|
||||
|
||||
def __call__(self, solve_fn):
|
||||
"""Perform the Solve registration.
|
||||
|
||||
Args:
|
||||
solve_fn: The function to use for the Solve.
|
||||
|
||||
Returns:
|
||||
solve_fn
|
||||
|
||||
Raises:
|
||||
TypeError: if solve_fn is not a callable.
|
||||
ValueError: if a Solve function has already been registered for
|
||||
the given argument classes.
|
||||
"""
|
||||
if not callable(solve_fn):
|
||||
raise TypeError(
|
||||
"solve_fn must be callable, received: {}".format(solve_fn))
|
||||
if self._key in _SOLVE:
|
||||
raise ValueError("Solve({}, {}) has already been registered.".format(
|
||||
self._key[0].__name__,
|
||||
self._key[1].__name__))
|
||||
_SOLVE[self._key] = solve_fn
|
||||
return solve_fn
|
||||
|
||||
|
||||
class RegisterInverse(object):
|
||||
"""Decorator to register an Inverse implementation function.
|
||||
|
||||
|
@ -26,66 +26,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag
|
||||
from tensorflow.python.ops.linalg import linear_operator_identity
|
||||
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
|
||||
from tensorflow.python.ops.linalg import linear_operator_zeros
|
||||
|
||||
|
||||
def _combined_self_adjoint_hint(operator_a, operator_b):
|
||||
"""Get combined hint for self-adjoint-ness."""
|
||||
# Note: only use this method in the commuting case.
|
||||
# The property is preserved under composition when the operators commute.
|
||||
if operator_a.is_self_adjoint and operator_b.is_self_adjoint:
|
||||
return True
|
||||
|
||||
# The property is not preserved when an operator with the property is composed
|
||||
# with an operator without the property.
|
||||
if ((operator_a.is_self_adjoint is True and
|
||||
operator_b.is_self_adjoint is False) or
|
||||
(operator_a.is_self_adjoint is False and
|
||||
operator_b.is_self_adjoint is True)):
|
||||
return False
|
||||
|
||||
# The property is not known when operators are not known to have the property
|
||||
# or both operators don't have the property (the property for the complement
|
||||
# class is not closed under composition).
|
||||
return None
|
||||
|
||||
|
||||
def _is_square(operator_a, operator_b):
|
||||
"""Return a hint to whether the composition is square."""
|
||||
if operator_a.is_square and operator_b.is_square:
|
||||
return True
|
||||
if operator_a.is_square is False and operator_b.is_square is False:
|
||||
# Let A have shape [B, M, N], B have shape [B, N, L].
|
||||
m = operator_a.range_dimension
|
||||
l = operator_b.domain_dimension
|
||||
if m is not None and l is not None:
|
||||
return m == l
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _combined_positive_definite_hint(operator_a, operator_b):
|
||||
"""Get combined PD hint for compositions."""
|
||||
# Note: Positive definiteness is only guaranteed to be preserved
|
||||
# when the operators commute and are symmetric. Only use this method in
|
||||
# commuting cases.
|
||||
|
||||
if (operator_a.is_positive_definite is True and
|
||||
operator_a.is_self_adjoint is True and
|
||||
operator_b.is_positive_definite is True and
|
||||
operator_b.is_self_adjoint is True):
|
||||
return True
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _combined_non_singular_hint(operator_a, operator_b):
|
||||
"""Get combined hint for when ."""
|
||||
# If either operator is not-invertible the composition isn't.
|
||||
if (operator_a.is_non_singular is False or
|
||||
operator_b.is_non_singular is False):
|
||||
return False
|
||||
|
||||
return operator_a.is_non_singular and operator_b.is_non_singular
|
||||
from tensorflow.python.ops.linalg import registrations_util
|
||||
|
||||
|
||||
# By default, use a LinearOperatorComposition to delay the computation.
|
||||
@ -93,15 +34,15 @@ def _combined_non_singular_hint(operator_a, operator_b):
|
||||
linear_operator.LinearOperator, linear_operator.LinearOperator)
|
||||
def _matmul_linear_operator(linop_a, linop_b):
|
||||
"""Generic matmul of two `LinearOperator`s."""
|
||||
is_square = _is_square(linop_a, linop_b)
|
||||
is_square = registrations_util.is_square(linop_a, linop_b)
|
||||
is_non_singular = None
|
||||
is_self_adjoint = None
|
||||
is_positive_definite = None
|
||||
|
||||
if is_square:
|
||||
is_non_singular = _combined_non_singular_hint(linop_a, linop_b)
|
||||
is_self_adjoint = _combined_self_adjoint_hint(linop_a, linop_b)
|
||||
elif is_square is False:
|
||||
is_non_singular = registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b)
|
||||
elif is_square is False: # pylint:disable=g-bool-id-comparison
|
||||
is_non_singular = False
|
||||
is_self_adjoint = False
|
||||
is_positive_definite = False
|
||||
@ -165,11 +106,13 @@ def _matmul_linear_operator_zeros_left(zeros, linop):
|
||||
def _matmul_linear_operator_diag(linop_a, linop_b):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_a.diag * linop_b.diag,
|
||||
is_non_singular=_combined_non_singular_hint(linop_a, linop_b),
|
||||
is_self_adjoint=_combined_self_adjoint_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b),
|
||||
is_positive_definite=_combined_positive_definite_hint(
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_a, linop_b),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_a, linop_b)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@ -180,12 +123,13 @@ def _matmul_linear_operator_diag_scaled_identity_right(
|
||||
linop_diag, linop_scaled_identity):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_diag.diag * linop_scaled_identity.multiplier,
|
||||
is_non_singular=_combined_non_singular_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_self_adjoint=_combined_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=_combined_positive_definite_hint(
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_diag, linop_scaled_identity)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@ -196,12 +140,13 @@ def _matmul_linear_operator_diag_scaled_identity_left(
|
||||
linop_scaled_identity, linop_diag):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_diag.diag * linop_scaled_identity.multiplier,
|
||||
is_non_singular=_combined_non_singular_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_self_adjoint=_combined_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=_combined_positive_definite_hint(
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_diag, linop_scaled_identity)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@ -211,11 +156,11 @@ def _matmul_linear_operator_diag_scaled_identity_left(
|
||||
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
|
||||
return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
|
||||
tril=linop_diag.diag[..., None] * linop_triangular.to_dense(),
|
||||
is_non_singular=_combined_non_singular_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_triangular),
|
||||
# This is safe to do since the Triangular matrix is only self-adjoint
|
||||
# when it is a diagonal matrix, and hence commutes.
|
||||
is_self_adjoint=_combined_self_adjoint_hint(
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_triangular),
|
||||
is_positive_definite=None,
|
||||
is_square=True)
|
||||
@ -227,11 +172,11 @@ def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
|
||||
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
|
||||
return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
|
||||
tril=linop_triangular.to_dense() * linop_diag.diag,
|
||||
is_non_singular=_combined_non_singular_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_triangular),
|
||||
# This is safe to do since the Triangular matrix is only self-adjoint
|
||||
# when it is a diagonal matrix, and hence commutes.
|
||||
is_self_adjoint=_combined_self_adjoint_hint(
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_triangular),
|
||||
is_positive_definite=None,
|
||||
is_square=True)
|
||||
@ -245,8 +190,11 @@ def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
|
||||
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
|
||||
return linear_operator_circulant.LinearOperatorCirculant(
|
||||
spectrum=linop_a.spectrum * linop_b.spectrum,
|
||||
is_non_singular=_combined_non_singular_hint(linop_a, linop_b),
|
||||
is_self_adjoint=_combined_self_adjoint_hint(linop_a, linop_b),
|
||||
is_positive_definite=_combined_positive_definite_hint(
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b),
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_a, linop_b),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_a, linop_b)),
|
||||
is_square=True)
|
||||
|
91
tensorflow/python/ops/linalg/registrations_util.py
Normal file
91
tensorflow/python/ops/linalg/registrations_util.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common utilities for registering LinearOperator methods."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
# Note: only use this method in the commuting case.
|
||||
def combined_commuting_self_adjoint_hint(operator_a, operator_b):
|
||||
"""Get combined hint for self-adjoint-ness."""
|
||||
|
||||
# The property is preserved under composition when the operators commute.
|
||||
if operator_a.is_self_adjoint and operator_b.is_self_adjoint:
|
||||
return True
|
||||
|
||||
# The property is not preserved when an operator with the property is composed
|
||||
# with an operator without the property.
|
||||
|
||||
# pylint:disable=g-bool-id-comparison
|
||||
if ((operator_a.is_self_adjoint is True and
|
||||
operator_b.is_self_adjoint is False) or
|
||||
(operator_a.is_self_adjoint is False and
|
||||
operator_b.is_self_adjoint is True)):
|
||||
return False
|
||||
# pylint:enable=g-bool-id-comparison
|
||||
|
||||
# The property is not known when operators are not known to have the property
|
||||
# or both operators don't have the property (the property for the complement
|
||||
# class is not closed under composition).
|
||||
return None
|
||||
|
||||
|
||||
def is_square(operator_a, operator_b):
|
||||
"""Return a hint to whether the composition is square."""
|
||||
if operator_a.is_square and operator_b.is_square:
|
||||
return True
|
||||
if operator_a.is_square is False and operator_b.is_square is False: # pylint:disable=g-bool-id-comparison
|
||||
# Let A have shape [B, M, N], B have shape [B, N, L].
|
||||
m = operator_a.range_dimension
|
||||
l = operator_b.domain_dimension
|
||||
if m is not None and l is not None:
|
||||
return m == l
|
||||
|
||||
if (operator_a.is_square != operator_b.is_square) and (
|
||||
operator_a.is_square is not None and operator_a.is_square is not None):
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Note: Positive definiteness is only guaranteed to be preserved
|
||||
# when the operators commute and are symmetric. Only use this method in
|
||||
# commuting cases.
|
||||
def combined_commuting_positive_definite_hint(operator_a, operator_b):
|
||||
"""Get combined PD hint for compositions."""
|
||||
# pylint:disable=g-bool-id-comparison
|
||||
if (operator_a.is_positive_definite is True and
|
||||
operator_a.is_self_adjoint is True and
|
||||
operator_b.is_positive_definite is True and
|
||||
operator_b.is_self_adjoint is True):
|
||||
return True
|
||||
# pylint:enable=g-bool-id-comparison
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def combined_non_singular_hint(operator_a, operator_b):
|
||||
"""Get combined hint for when ."""
|
||||
# If either operator is not-invertible the composition isn't.
|
||||
|
||||
# pylint:disable=g-bool-id-comparison
|
||||
if (operator_a.is_non_singular is False or
|
||||
operator_b.is_non_singular is False):
|
||||
return False
|
||||
# pylint:enable=g-bool-id-comparison
|
||||
|
||||
return operator_a.is_non_singular and operator_b.is_non_singular
|
164
tensorflow/python/ops/linalg/solve_registrations.py
Normal file
164
tensorflow/python/ops/linalg/solve_registrations.py
Normal file
@ -0,0 +1,164 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Registrations for LinearOperator.solve."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import linear_operator_circulant
|
||||
from tensorflow.python.ops.linalg import linear_operator_composition
|
||||
from tensorflow.python.ops.linalg import linear_operator_diag
|
||||
from tensorflow.python.ops.linalg import linear_operator_identity
|
||||
from tensorflow.python.ops.linalg import linear_operator_inversion
|
||||
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
|
||||
from tensorflow.python.ops.linalg import registrations_util
|
||||
|
||||
|
||||
# By default, use a LinearOperatorComposition to delay the computation.
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator.LinearOperator, linear_operator.LinearOperator)
|
||||
def _solve_linear_operator(linop_a, linop_b):
|
||||
"""Generic solve of two `LinearOperator`s."""
|
||||
is_square = registrations_util.is_square(linop_a, linop_b)
|
||||
is_non_singular = None
|
||||
is_self_adjoint = None
|
||||
is_positive_definite = None
|
||||
|
||||
if is_square:
|
||||
is_non_singular = registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b)
|
||||
elif is_square is False: # pylint:disable=g-bool-id-comparison
|
||||
is_non_singular = False
|
||||
is_self_adjoint = False
|
||||
is_positive_definite = False
|
||||
|
||||
return linear_operator_composition.LinearOperatorComposition(
|
||||
operators=[
|
||||
linear_operator_inversion.LinearOperatorInversion(linop_a),
|
||||
linop_b
|
||||
],
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_inversion.LinearOperatorInversion,
|
||||
linear_operator.LinearOperator)
|
||||
def _solve_inverse_linear_operator(linop_a, linop_b):
|
||||
"""Solve inverse of generic `LinearOperator`s."""
|
||||
return linop_a.operator.matmul(linop_b)
|
||||
|
||||
|
||||
# Identity
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_identity.LinearOperatorIdentity,
|
||||
linear_operator.LinearOperator)
|
||||
def _solve_linear_operator_identity_left(identity, linop):
|
||||
del identity
|
||||
return linop
|
||||
|
||||
|
||||
# Diag.
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_diag.LinearOperatorDiag,
|
||||
linear_operator_diag.LinearOperatorDiag)
|
||||
def _solve_linear_operator_diag(linop_a, linop_b):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_b.diag / linop_a.diag,
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b),
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_a, linop_b),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_a, linop_b)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_diag.LinearOperatorDiag,
|
||||
linear_operator_identity.LinearOperatorScaledIdentity)
|
||||
def _solve_linear_operator_diag_scaled_identity_right(
|
||||
linop_diag, linop_scaled_identity):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_scaled_identity.multiplier / linop_diag.diag,
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_diag, linop_scaled_identity)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_identity.LinearOperatorScaledIdentity,
|
||||
linear_operator_diag.LinearOperatorDiag)
|
||||
def _solve_linear_operator_diag_scaled_identity_left(
|
||||
linop_scaled_identity, linop_diag):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=linop_diag.diag / linop_scaled_identity.multiplier,
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_scaled_identity),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_diag, linop_scaled_identity)),
|
||||
is_square=True)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_diag.LinearOperatorDiag,
|
||||
linear_operator_lower_triangular.LinearOperatorLowerTriangular)
|
||||
def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
|
||||
return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
|
||||
tril=linop_triangular.to_dense() / linop_diag.diag[..., None],
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_diag, linop_triangular),
|
||||
# This is safe to do since the Triangular matrix is only self-adjoint
|
||||
# when it is a diagonal matrix, and hence commutes.
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_diag, linop_triangular),
|
||||
is_positive_definite=None,
|
||||
is_square=True)
|
||||
|
||||
|
||||
# Circulant.
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterSolve(
|
||||
linear_operator_circulant.LinearOperatorCirculant,
|
||||
linear_operator_circulant.LinearOperatorCirculant)
|
||||
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
|
||||
return linear_operator_circulant.LinearOperatorCirculant(
|
||||
spectrum=linop_b.spectrum / linop_a.spectrum,
|
||||
is_non_singular=registrations_util.combined_non_singular_hint(
|
||||
linop_a, linop_b),
|
||||
is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
|
||||
linop_a, linop_b),
|
||||
is_positive_definite=(
|
||||
registrations_util.combined_commuting_positive_definite_hint(
|
||||
linop_a, linop_b)),
|
||||
is_square=True)
|
Loading…
Reference in New Issue
Block a user