diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py index f70d8c4e1cd..f913cffa8b8 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops.linalg import linalg as linalg_lib @@ -113,6 +115,110 @@ class LinearOperatorAdjointTest( self.assertEqual("my_operator_adjoint", operator.name) + def test_matmul_adjoint_operator(self): + matrix1 = np.random.randn(4, 4) + matrix2 = np.random.randn(4, 4) + full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) + full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) + + self.assertAllClose( + np.matmul(matrix1, matrix2.T), + self.evaluate( + full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense())) + + self.assertAllClose( + np.matmul(matrix1.T, matrix2), + self.evaluate( + full_matrix1.matmul(full_matrix2, adjoint=True).to_dense())) + + self.assertAllClose( + np.matmul(matrix1.T, matrix2.T), + self.evaluate( + full_matrix1.matmul( + full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) + + def test_matmul_adjoint_complex_operator(self): + matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) + matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) + full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) + full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) + + self.assertAllClose( + np.matmul(matrix1, matrix2.conj().T), + self.evaluate( + full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense())) + + self.assertAllClose( + np.matmul(matrix1.conj().T, matrix2), + self.evaluate( + full_matrix1.matmul(full_matrix2, adjoint=True).to_dense())) + + self.assertAllClose( + np.matmul(matrix1.conj().T, matrix2.conj().T), + self.evaluate( + full_matrix1.matmul( + full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) + + def test_solve_adjoint_operator(self): + matrix1 = self.evaluate( + linear_operator_test_util.random_tril_matrix( + [4, 4], dtype=dtypes.float64, force_well_conditioned=True)) + matrix2 = np.random.randn(4, 4) + full_matrix1 = linalg.LinearOperatorLowerTriangular( + matrix1, is_non_singular=True) + full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) + + self.assertAllClose( + self.evaluate(linalg.triangular_solve(matrix1, matrix2.T)), + self.evaluate( + full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense())) + + self.assertAllClose( + self.evaluate( + linalg.triangular_solve( + matrix1.T, matrix2, lower=False)), + self.evaluate( + full_matrix1.solve(full_matrix2, adjoint=True).to_dense())) + + self.assertAllClose( + self.evaluate( + linalg.triangular_solve(matrix1.T, matrix2.T, lower=False)), + self.evaluate( + full_matrix1.solve( + full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) + + def test_solve_adjoint_complex_operator(self): + matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( + [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + + 1j * linear_operator_test_util.random_tril_matrix( + [4, 4], dtype=dtypes.complex128, + force_well_conditioned=True)) + matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) + + full_matrix1 = linalg.LinearOperatorLowerTriangular( + matrix1, is_non_singular=True) + full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) + + self.assertAllClose( + self.evaluate(linalg.triangular_solve(matrix1, matrix2.conj().T)), + self.evaluate( + full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense())) + + self.assertAllClose( + self.evaluate( + linalg.triangular_solve( + matrix1.conj().T, matrix2, lower=False)), + self.evaluate( + full_matrix1.solve(full_matrix2, adjoint=True).to_dense())) + + self.assertAllClose( + self.evaluate( + linalg.triangular_solve( + matrix1.conj().T, matrix2.conj().T, lower=False)), + self.evaluate( + full_matrix1.solve( + full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) + class LinearOperatorAdjointNonSquareTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py index 12da8659cac..8057d055783 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py @@ -23,6 +23,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations # pylint: disab from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_algebra from tensorflow.python.ops.linalg import matmul_registrations # pylint: disable=unused-import +from tensorflow.python.ops.linalg import solve_registrations # pylint: disable=unused-import from tensorflow.python.platform import test # pylint: disable=protected-access @@ -34,6 +35,8 @@ _INVERSES = linear_operator_algebra._INVERSES _registered_inverse = linear_operator_algebra._registered_inverse _MATMUL = linear_operator_algebra._MATMUL _registered_matmul = linear_operator_algebra._registered_matmul +_SOLVE = linear_operator_algebra._SOLVE +_registered_solve = linear_operator_algebra._registered_solve # pylint: enable=protected-access @@ -175,6 +178,55 @@ class MatmulTest(test.TestCase): self.assertEqual(v, _registered_matmul(k[0], k[1])) +class SolveTest(test.TestCase): + + def testRegistration(self): + + class CustomLinOp(linear_operator.LinearOperator): + + def _matmul(self, a): + pass + + def _solve(self, a): + pass + + def _shape(self): + return tensor_shape.TensorShape([1, 1]) + + def _shape_tensor(self): + pass + + # Register Solve to a lambda that spits out the name parameter + @linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp) + def _solve(a, b): # pylint: disable=unused-argument,unused-variable + return "OK" + + custom_linop = CustomLinOp( + dtype=None, is_self_adjoint=True, is_positive_definite=True) + self.assertEqual("OK", custom_linop.solve(custom_linop)) + + def testRegistrationFailures(self): + + class CustomLinOp(linear_operator.LinearOperator): + pass + + with self.assertRaisesRegexp(TypeError, "must be callable"): + linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah") + + # First registration is OK + linear_operator_algebra.RegisterSolve( + CustomLinOp, CustomLinOp)(lambda a: None) + + # Second registration fails + with self.assertRaisesRegexp(ValueError, "has already been registered"): + linear_operator_algebra.RegisterSolve( + CustomLinOp, CustomLinOp)(lambda a: None) + + def testExactSolveRegistrationsAllMatch(self): + for (k, v) in _SOLVE.items(): + self.assertEqual(v, _registered_solve(k[0], k[1])) + + class InverseTest(test.TestCase): def testRegistration(self): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 5c3220e60f4..4b2548a406d 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -187,6 +187,35 @@ class LinearOperatorDiagTest( linalg_lib.LinearOperatorDiag)) self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag)) + def test_diag_solve(self): + operator1 = linalg_lib.LinearOperatorDiag([2., 3.], is_non_singular=True) + operator2 = linalg_lib.LinearOperatorDiag([1., 2.], is_non_singular=True) + operator3 = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=3., is_non_singular=True) + operator_solve = operator1.solve(operator2) + self.assertTrue(isinstance( + operator_solve, + linalg_lib.LinearOperatorDiag)) + self.assertAllClose([0.5, 2 / 3.], self.evaluate(operator_solve.diag)) + + operator_solve = operator2.solve(operator1) + self.assertTrue(isinstance( + operator_solve, + linalg_lib.LinearOperatorDiag)) + self.assertAllClose([2., 3 / 2.], self.evaluate(operator_solve.diag)) + + operator_solve = operator1.solve(operator3) + self.assertTrue(isinstance( + operator_solve, + linalg_lib.LinearOperatorDiag)) + self.assertAllClose([3 / 2., 1.], self.evaluate(operator_solve.diag)) + + operator_solve = operator3.solve(operator1) + self.assertTrue(isinstance( + operator_solve, + linalg_lib.LinearOperatorDiag)) + self.assertAllClose([2 / 3., 1.], self.evaluate(operator_solve.diag)) + def test_diag_adjoint_type(self): diag = [1., 3., 5., 8.] operator = linalg.LinearOperatorDiag(diag, is_non_singular=True) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py index 63d7be1d593..23e936f0f59 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py @@ -495,6 +495,20 @@ class LinearOperatorScaledIdentityTest( linalg_lib.LinearOperatorScaledIdentity)) self.assertAllClose(3., self.evaluate(operator_matmul.multiplier)) + def test_identity_solve(self): + operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2) + operator2 = linalg_lib.LinearOperatorScaledIdentity( + num_rows=2, multiplier=3.) + self.assertTrue(isinstance( + operator1.solve(operator1), + linalg_lib.LinearOperatorIdentity)) + + operator_solve = operator1.solve(operator2) + self.assertTrue(isinstance( + operator_solve, + linalg_lib.LinearOperatorScaledIdentity)) + self.assertAllClose(3., self.evaluate(operator_solve.multiplier)) + def test_scaled_identity_cholesky_type(self): operator = linalg_lib.LinearOperatorScaledIdentity( num_rows=2, diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py index 8f8b15e8ed8..c62f3f0fed4 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py @@ -238,7 +238,7 @@ class LinearOperatorTest(test.TestCase): self.assertTrue(operator_matmul.is_square) self.assertTrue(operator_matmul.is_non_singular) - self.assertTrue(operator_matmul.is_self_adjoint) + self.assertEqual(None, operator_matmul.is_self_adjoint) self.assertEqual(None, operator_matmul.is_positive_definite) @test_util.run_deprecated_v1 diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py index b9f8411c934..088e6e45eaf 100644 --- a/tensorflow/python/ops/linalg/linalg.py +++ b/tensorflow/python/ops/linalg/linalg.py @@ -25,6 +25,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations as _cholesky_reg from tensorflow.python.ops.linalg import inverse_registrations as _inverse_registrations from tensorflow.python.ops.linalg import linear_operator_algebra as _linear_operator_algebra from tensorflow.python.ops.linalg import matmul_registrations as _matmul_registrations +from tensorflow.python.ops.linalg import solve_registrations as _solve_registrations from tensorflow.python.ops.linalg.linalg_impl import * from tensorflow.python.ops.linalg.linear_operator import * from tensorflow.python.ops.linalg.linear_operator_block_diag import * diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 7f40c5f7352..80c91695360 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -597,16 +597,18 @@ class LinearOperator(object): as `self`. """ if isinstance(x, LinearOperator): - if adjoint or adjoint_arg: - raise ValueError(".matmul not supported with adjoints.") - if (x.range_dimension is not None and - self.domain_dimension is not None and - x.range_dimension != self.domain_dimension): + left_operator = self.adjoint() if adjoint else self + right_operator = x.adjoint() if adjoint_arg else x + + if (right_operator.range_dimension is not None and + left_operator.domain_dimension is not None and + right_operator.range_dimension != left_operator.domain_dimension): raise ValueError( "Operators are incompatible. Expected `x` to have dimension" - " {} but got {}.".format(self.domain_dimension, x.range_dimension)) + " {} but got {}.".format( + left_operator.domain_dimension, right_operator.range_dimension)) with self._name_scope(name): - return linear_operator_algebra.matmul(self, x) + return linear_operator_algebra.matmul(left_operator, right_operator) with self._name_scope(name, values=[x]): x = ops.convert_to_tensor(x, name="x") @@ -780,6 +782,20 @@ class LinearOperator(object): raise NotImplementedError( "Exact solve not implemented for an operator that is expected to " "not be square.") + if isinstance(rhs, LinearOperator): + left_operator = self.adjoint() if adjoint else self + right_operator = rhs.adjoint() if adjoint_arg else rhs + + if (right_operator.range_dimension is not None and + left_operator.domain_dimension is not None and + right_operator.range_dimension != left_operator.domain_dimension): + raise ValueError( + "Operators are incompatible. Expected `rhs` to have dimension" + " {} but got {}.".format( + left_operator.domain_dimension, right_operator.range_dimension)) + with self._name_scope(name): + return linear_operator_algebra.solve(left_operator, right_operator) + with self._name_scope(name, values=[rhs]): rhs = ops.convert_to_tensor(rhs, name="rhs") self._check_input_dtype(rhs) diff --git a/tensorflow/python/ops/linalg/linear_operator_algebra.py b/tensorflow/python/ops/linalg/linear_operator_algebra.py index 0d1eab4b735..cd4aceab67d 100644 --- a/tensorflow/python/ops/linalg/linear_operator_algebra.py +++ b/tensorflow/python/ops/linalg/linear_operator_algebra.py @@ -28,6 +28,7 @@ from tensorflow.python.util import tf_inspect _ADJOINTS = {} _CHOLESKY_DECOMPS = {} _MATMUL = {} +_SOLVE = {} _INVERSES = {} @@ -62,6 +63,11 @@ def _registered_matmul(type_a, type_b): return _registered_function([type_a, type_b], _MATMUL) +def _registered_solve(type_a, type_b): + """Get the Solve function registered for classes a and b.""" + return _registered_function([type_a, type_b], _SOLVE) + + def _registered_inverse(type_a): """Get the Cholesky function registered for class a.""" return _registered_function([type_a], _INVERSES) @@ -138,6 +144,31 @@ def matmul(lin_op_a, lin_op_b, name=None): return matmul_fn(lin_op_a, lin_op_b) +def solve(lin_op_a, lin_op_b, name=None): + """Compute lin_op_a.solve(lin_op_b). + + Args: + lin_op_a: The LinearOperator on the left. + lin_op_b: The LinearOperator on the right. + name: Name to use for this operation. + + Returns: + A LinearOperator that represents the solve between `lin_op_a` and + `lin_op_b`. + + Raises: + NotImplementedError: If no solve method is defined between types of + `lin_op_a` and `lin_op_b`. + """ + solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b)) + if solve_fn is None: + raise ValueError("No solve registered for {}.solve({})".format( + type(lin_op_a), type(lin_op_b))) + + with ops.name_scope(name, "Solve"): + return solve_fn(lin_op_a, lin_op_b) + + def inverse(lin_op_a, name=None): """Get the Inverse associated to lin_op_a. @@ -291,6 +322,52 @@ class RegisterMatmul(object): return matmul_fn +class RegisterSolve(object): + """Decorator to register a Solve implementation function. + + Usage: + + @linear_operator_algebra.RegisterSolve( + lin_op.LinearOperatorIdentity, + lin_op.LinearOperatorIdentity) + def _solve_identity(a, b): + # Return the identity matrix. + """ + + def __init__(self, lin_op_cls_a, lin_op_cls_b): + """Initialize the LinearOperator registrar. + + Args: + lin_op_cls_a: the class of the LinearOperator that is computing solve. + lin_op_cls_b: the class of the second LinearOperator to solve. + """ + self._key = (lin_op_cls_a, lin_op_cls_b) + + def __call__(self, solve_fn): + """Perform the Solve registration. + + Args: + solve_fn: The function to use for the Solve. + + Returns: + solve_fn + + Raises: + TypeError: if solve_fn is not a callable. + ValueError: if a Solve function has already been registered for + the given argument classes. + """ + if not callable(solve_fn): + raise TypeError( + "solve_fn must be callable, received: {}".format(solve_fn)) + if self._key in _SOLVE: + raise ValueError("Solve({}, {}) has already been registered.".format( + self._key[0].__name__, + self._key[1].__name__)) + _SOLVE[self._key] = solve_fn + return solve_fn + + class RegisterInverse(object): """Decorator to register an Inverse implementation function. diff --git a/tensorflow/python/ops/linalg/matmul_registrations.py b/tensorflow/python/ops/linalg/matmul_registrations.py index e0ac988ba27..f624351cd99 100644 --- a/tensorflow/python/ops/linalg/matmul_registrations.py +++ b/tensorflow/python/ops/linalg/matmul_registrations.py @@ -26,66 +26,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag from tensorflow.python.ops.linalg import linear_operator_identity from tensorflow.python.ops.linalg import linear_operator_lower_triangular from tensorflow.python.ops.linalg import linear_operator_zeros - - -def _combined_self_adjoint_hint(operator_a, operator_b): - """Get combined hint for self-adjoint-ness.""" - # Note: only use this method in the commuting case. - # The property is preserved under composition when the operators commute. - if operator_a.is_self_adjoint and operator_b.is_self_adjoint: - return True - - # The property is not preserved when an operator with the property is composed - # with an operator without the property. - if ((operator_a.is_self_adjoint is True and - operator_b.is_self_adjoint is False) or - (operator_a.is_self_adjoint is False and - operator_b.is_self_adjoint is True)): - return False - - # The property is not known when operators are not known to have the property - # or both operators don't have the property (the property for the complement - # class is not closed under composition). - return None - - -def _is_square(operator_a, operator_b): - """Return a hint to whether the composition is square.""" - if operator_a.is_square and operator_b.is_square: - return True - if operator_a.is_square is False and operator_b.is_square is False: - # Let A have shape [B, M, N], B have shape [B, N, L]. - m = operator_a.range_dimension - l = operator_b.domain_dimension - if m is not None and l is not None: - return m == l - - return None - - -def _combined_positive_definite_hint(operator_a, operator_b): - """Get combined PD hint for compositions.""" - # Note: Positive definiteness is only guaranteed to be preserved - # when the operators commute and are symmetric. Only use this method in - # commuting cases. - - if (operator_a.is_positive_definite is True and - operator_a.is_self_adjoint is True and - operator_b.is_positive_definite is True and - operator_b.is_self_adjoint is True): - return True - - return None - - -def _combined_non_singular_hint(operator_a, operator_b): - """Get combined hint for when .""" - # If either operator is not-invertible the composition isn't. - if (operator_a.is_non_singular is False or - operator_b.is_non_singular is False): - return False - - return operator_a.is_non_singular and operator_b.is_non_singular +from tensorflow.python.ops.linalg import registrations_util # By default, use a LinearOperatorComposition to delay the computation. @@ -93,15 +34,15 @@ def _combined_non_singular_hint(operator_a, operator_b): linear_operator.LinearOperator, linear_operator.LinearOperator) def _matmul_linear_operator(linop_a, linop_b): """Generic matmul of two `LinearOperator`s.""" - is_square = _is_square(linop_a, linop_b) + is_square = registrations_util.is_square(linop_a, linop_b) is_non_singular = None is_self_adjoint = None is_positive_definite = None if is_square: - is_non_singular = _combined_non_singular_hint(linop_a, linop_b) - is_self_adjoint = _combined_self_adjoint_hint(linop_a, linop_b) - elif is_square is False: + is_non_singular = registrations_util.combined_non_singular_hint( + linop_a, linop_b) + elif is_square is False: # pylint:disable=g-bool-id-comparison is_non_singular = False is_self_adjoint = False is_positive_definite = False @@ -165,11 +106,13 @@ def _matmul_linear_operator_zeros_left(zeros, linop): def _matmul_linear_operator_diag(linop_a, linop_b): return linear_operator_diag.LinearOperatorDiag( diag=linop_a.diag * linop_b.diag, - is_non_singular=_combined_non_singular_hint(linop_a, linop_b), - is_self_adjoint=_combined_self_adjoint_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_a, linop_b), - is_positive_definite=_combined_positive_definite_hint( + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( linop_a, linop_b), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_a, linop_b)), is_square=True) @@ -180,12 +123,13 @@ def _matmul_linear_operator_diag_scaled_identity_right( linop_diag, linop_scaled_identity): return linear_operator_diag.LinearOperatorDiag( diag=linop_diag.diag * linop_scaled_identity.multiplier, - is_non_singular=_combined_non_singular_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_diag, linop_scaled_identity), - is_self_adjoint=_combined_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=_combined_positive_definite_hint( + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( linop_diag, linop_scaled_identity), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_diag, linop_scaled_identity)), is_square=True) @@ -196,12 +140,13 @@ def _matmul_linear_operator_diag_scaled_identity_left( linop_scaled_identity, linop_diag): return linear_operator_diag.LinearOperatorDiag( diag=linop_diag.diag * linop_scaled_identity.multiplier, - is_non_singular=_combined_non_singular_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_diag, linop_scaled_identity), - is_self_adjoint=_combined_self_adjoint_hint( - linop_diag, linop_scaled_identity), - is_positive_definite=_combined_positive_definite_hint( + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( linop_diag, linop_scaled_identity), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_diag, linop_scaled_identity)), is_square=True) @@ -211,11 +156,11 @@ def _matmul_linear_operator_diag_scaled_identity_left( def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular): return linear_operator_lower_triangular.LinearOperatorLowerTriangular( tril=linop_diag.diag[..., None] * linop_triangular.to_dense(), - is_non_singular=_combined_non_singular_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_diag, linop_triangular), # This is safe to do since the Triangular matrix is only self-adjoint # when it is a diagonal matrix, and hence commutes. - is_self_adjoint=_combined_self_adjoint_hint( + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( linop_diag, linop_triangular), is_positive_definite=None, is_square=True) @@ -227,11 +172,11 @@ def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular): def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag): return linear_operator_lower_triangular.LinearOperatorLowerTriangular( tril=linop_triangular.to_dense() * linop_diag.diag, - is_non_singular=_combined_non_singular_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_diag, linop_triangular), # This is safe to do since the Triangular matrix is only self-adjoint # when it is a diagonal matrix, and hence commutes. - is_self_adjoint=_combined_self_adjoint_hint( + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( linop_diag, linop_triangular), is_positive_definite=None, is_square=True) @@ -245,8 +190,11 @@ def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag): def _matmul_linear_operator_circulant_circulant(linop_a, linop_b): return linear_operator_circulant.LinearOperatorCirculant( spectrum=linop_a.spectrum * linop_b.spectrum, - is_non_singular=_combined_non_singular_hint(linop_a, linop_b), - is_self_adjoint=_combined_self_adjoint_hint(linop_a, linop_b), - is_positive_definite=_combined_positive_definite_hint( + is_non_singular=registrations_util.combined_non_singular_hint( linop_a, linop_b), + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_a, linop_b), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_a, linop_b)), is_square=True) diff --git a/tensorflow/python/ops/linalg/registrations_util.py b/tensorflow/python/ops/linalg/registrations_util.py new file mode 100644 index 00000000000..c707a67d43c --- /dev/null +++ b/tensorflow/python/ops/linalg/registrations_util.py @@ -0,0 +1,91 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common utilities for registering LinearOperator methods.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# Note: only use this method in the commuting case. +def combined_commuting_self_adjoint_hint(operator_a, operator_b): + """Get combined hint for self-adjoint-ness.""" + + # The property is preserved under composition when the operators commute. + if operator_a.is_self_adjoint and operator_b.is_self_adjoint: + return True + + # The property is not preserved when an operator with the property is composed + # with an operator without the property. + + # pylint:disable=g-bool-id-comparison + if ((operator_a.is_self_adjoint is True and + operator_b.is_self_adjoint is False) or + (operator_a.is_self_adjoint is False and + operator_b.is_self_adjoint is True)): + return False + # pylint:enable=g-bool-id-comparison + + # The property is not known when operators are not known to have the property + # or both operators don't have the property (the property for the complement + # class is not closed under composition). + return None + + +def is_square(operator_a, operator_b): + """Return a hint to whether the composition is square.""" + if operator_a.is_square and operator_b.is_square: + return True + if operator_a.is_square is False and operator_b.is_square is False: # pylint:disable=g-bool-id-comparison + # Let A have shape [B, M, N], B have shape [B, N, L]. + m = operator_a.range_dimension + l = operator_b.domain_dimension + if m is not None and l is not None: + return m == l + + if (operator_a.is_square != operator_b.is_square) and ( + operator_a.is_square is not None and operator_a.is_square is not None): + return False + + return None + + +# Note: Positive definiteness is only guaranteed to be preserved +# when the operators commute and are symmetric. Only use this method in +# commuting cases. +def combined_commuting_positive_definite_hint(operator_a, operator_b): + """Get combined PD hint for compositions.""" + # pylint:disable=g-bool-id-comparison + if (operator_a.is_positive_definite is True and + operator_a.is_self_adjoint is True and + operator_b.is_positive_definite is True and + operator_b.is_self_adjoint is True): + return True + # pylint:enable=g-bool-id-comparison + + return None + + +def combined_non_singular_hint(operator_a, operator_b): + """Get combined hint for when .""" + # If either operator is not-invertible the composition isn't. + + # pylint:disable=g-bool-id-comparison + if (operator_a.is_non_singular is False or + operator_b.is_non_singular is False): + return False + # pylint:enable=g-bool-id-comparison + + return operator_a.is_non_singular and operator_b.is_non_singular diff --git a/tensorflow/python/ops/linalg/solve_registrations.py b/tensorflow/python/ops/linalg/solve_registrations.py new file mode 100644 index 00000000000..cfdce44f0d7 --- /dev/null +++ b/tensorflow/python/ops/linalg/solve_registrations.py @@ -0,0 +1,164 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registrations for LinearOperator.solve.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_algebra +from tensorflow.python.ops.linalg import linear_operator_circulant +from tensorflow.python.ops.linalg import linear_operator_composition +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_inversion +from tensorflow.python.ops.linalg import linear_operator_lower_triangular +from tensorflow.python.ops.linalg import registrations_util + + +# By default, use a LinearOperatorComposition to delay the computation. +@linear_operator_algebra.RegisterSolve( + linear_operator.LinearOperator, linear_operator.LinearOperator) +def _solve_linear_operator(linop_a, linop_b): + """Generic solve of two `LinearOperator`s.""" + is_square = registrations_util.is_square(linop_a, linop_b) + is_non_singular = None + is_self_adjoint = None + is_positive_definite = None + + if is_square: + is_non_singular = registrations_util.combined_non_singular_hint( + linop_a, linop_b) + elif is_square is False: # pylint:disable=g-bool-id-comparison + is_non_singular = False + is_self_adjoint = False + is_positive_definite = False + + return linear_operator_composition.LinearOperatorComposition( + operators=[ + linear_operator_inversion.LinearOperatorInversion(linop_a), + linop_b + ], + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + ) + + +@linear_operator_algebra.RegisterSolve( + linear_operator_inversion.LinearOperatorInversion, + linear_operator.LinearOperator) +def _solve_inverse_linear_operator(linop_a, linop_b): + """Solve inverse of generic `LinearOperator`s.""" + return linop_a.operator.matmul(linop_b) + + +# Identity +@linear_operator_algebra.RegisterSolve( + linear_operator_identity.LinearOperatorIdentity, + linear_operator.LinearOperator) +def _solve_linear_operator_identity_left(identity, linop): + del identity + return linop + + +# Diag. + + +@linear_operator_algebra.RegisterSolve( + linear_operator_diag.LinearOperatorDiag, + linear_operator_diag.LinearOperatorDiag) +def _solve_linear_operator_diag(linop_a, linop_b): + return linear_operator_diag.LinearOperatorDiag( + diag=linop_b.diag / linop_a.diag, + is_non_singular=registrations_util.combined_non_singular_hint( + linop_a, linop_b), + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_a, linop_b), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_a, linop_b)), + is_square=True) + + +@linear_operator_algebra.RegisterSolve( + linear_operator_diag.LinearOperatorDiag, + linear_operator_identity.LinearOperatorScaledIdentity) +def _solve_linear_operator_diag_scaled_identity_right( + linop_diag, linop_scaled_identity): + return linear_operator_diag.LinearOperatorDiag( + diag=linop_scaled_identity.multiplier / linop_diag.diag, + is_non_singular=registrations_util.combined_non_singular_hint( + linop_diag, linop_scaled_identity), + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_diag, linop_scaled_identity), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_diag, linop_scaled_identity)), + is_square=True) + + +@linear_operator_algebra.RegisterSolve( + linear_operator_identity.LinearOperatorScaledIdentity, + linear_operator_diag.LinearOperatorDiag) +def _solve_linear_operator_diag_scaled_identity_left( + linop_scaled_identity, linop_diag): + return linear_operator_diag.LinearOperatorDiag( + diag=linop_diag.diag / linop_scaled_identity.multiplier, + is_non_singular=registrations_util.combined_non_singular_hint( + linop_diag, linop_scaled_identity), + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_diag, linop_scaled_identity), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_diag, linop_scaled_identity)), + is_square=True) + + +@linear_operator_algebra.RegisterSolve( + linear_operator_diag.LinearOperatorDiag, + linear_operator_lower_triangular.LinearOperatorLowerTriangular) +def _solve_linear_operator_diag_tril(linop_diag, linop_triangular): + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=linop_triangular.to_dense() / linop_diag.diag[..., None], + is_non_singular=registrations_util.combined_non_singular_hint( + linop_diag, linop_triangular), + # This is safe to do since the Triangular matrix is only self-adjoint + # when it is a diagonal matrix, and hence commutes. + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_diag, linop_triangular), + is_positive_definite=None, + is_square=True) + + +# Circulant. + + +@linear_operator_algebra.RegisterSolve( + linear_operator_circulant.LinearOperatorCirculant, + linear_operator_circulant.LinearOperatorCirculant) +def _solve_linear_operator_circulant_circulant(linop_a, linop_b): + return linear_operator_circulant.LinearOperatorCirculant( + spectrum=linop_b.spectrum / linop_a.spectrum, + is_non_singular=registrations_util.combined_non_singular_hint( + linop_a, linop_b), + is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( + linop_a, linop_b), + is_positive_definite=( + registrations_util.combined_commuting_positive_definite_hint( + linop_a, linop_b)), + is_square=True)