Allow LinearOperator.solve to take a LinearOperator.
PiperOrigin-RevId: 244388120
This commit is contained in:
		
							parent
							
								
									58f67785f6
								
							
						
					
					
						commit
						0aa8055f1a
					
				| @ -17,6 +17,8 @@ from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import numpy as np | ||||
| 
 | ||||
| from tensorflow.python.framework import dtypes | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops.linalg import linalg as linalg_lib | ||||
| @ -113,6 +115,110 @@ class LinearOperatorAdjointTest( | ||||
| 
 | ||||
|     self.assertEqual("my_operator_adjoint", operator.name) | ||||
| 
 | ||||
|   def test_matmul_adjoint_operator(self): | ||||
|     matrix1 = np.random.randn(4, 4) | ||||
|     matrix2 = np.random.randn(4, 4) | ||||
|     full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) | ||||
|     full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1, matrix2.T), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1.T, matrix2), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul(full_matrix2, adjoint=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1.T, matrix2.T), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul( | ||||
|                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|   def test_matmul_adjoint_complex_operator(self): | ||||
|     matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) | ||||
|     matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) | ||||
|     full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) | ||||
|     full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1, matrix2.conj().T), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1.conj().T, matrix2), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul(full_matrix2, adjoint=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         np.matmul(matrix1.conj().T, matrix2.conj().T), | ||||
|         self.evaluate( | ||||
|             full_matrix1.matmul( | ||||
|                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|   def test_solve_adjoint_operator(self): | ||||
|     matrix1 = self.evaluate( | ||||
|         linear_operator_test_util.random_tril_matrix( | ||||
|             [4, 4], dtype=dtypes.float64, force_well_conditioned=True)) | ||||
|     matrix2 = np.random.randn(4, 4) | ||||
|     full_matrix1 = linalg.LinearOperatorLowerTriangular( | ||||
|         matrix1, is_non_singular=True) | ||||
|     full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate(linalg.triangular_solve(matrix1, matrix2.T)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate( | ||||
|             linalg.triangular_solve( | ||||
|                 matrix1.T, matrix2, lower=False)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve(full_matrix2, adjoint=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate( | ||||
|             linalg.triangular_solve(matrix1.T, matrix2.T, lower=False)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve( | ||||
|                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|   def test_solve_adjoint_complex_operator(self): | ||||
|     matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( | ||||
|         [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + | ||||
|                             1j * linear_operator_test_util.random_tril_matrix( | ||||
|                                 [4, 4], dtype=dtypes.complex128, | ||||
|                                 force_well_conditioned=True)) | ||||
|     matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) | ||||
| 
 | ||||
|     full_matrix1 = linalg.LinearOperatorLowerTriangular( | ||||
|         matrix1, is_non_singular=True) | ||||
|     full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate(linalg.triangular_solve(matrix1, matrix2.conj().T)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate( | ||||
|             linalg.triangular_solve( | ||||
|                 matrix1.conj().T, matrix2, lower=False)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve(full_matrix2, adjoint=True).to_dense())) | ||||
| 
 | ||||
|     self.assertAllClose( | ||||
|         self.evaluate( | ||||
|             linalg.triangular_solve( | ||||
|                 matrix1.conj().T, matrix2.conj().T, lower=False)), | ||||
|         self.evaluate( | ||||
|             full_matrix1.solve( | ||||
|                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
| 
 | ||||
| class LinearOperatorAdjointNonSquareTest( | ||||
|     linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): | ||||
|  | ||||
| @ -23,6 +23,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations  # pylint: disab | ||||
| from tensorflow.python.ops.linalg import linear_operator | ||||
| from tensorflow.python.ops.linalg import linear_operator_algebra | ||||
| from tensorflow.python.ops.linalg import matmul_registrations  # pylint: disable=unused-import | ||||
| from tensorflow.python.ops.linalg import solve_registrations  # pylint: disable=unused-import | ||||
| from tensorflow.python.platform import test | ||||
| 
 | ||||
| # pylint: disable=protected-access | ||||
| @ -34,6 +35,8 @@ _INVERSES = linear_operator_algebra._INVERSES | ||||
| _registered_inverse = linear_operator_algebra._registered_inverse | ||||
| _MATMUL = linear_operator_algebra._MATMUL | ||||
| _registered_matmul = linear_operator_algebra._registered_matmul | ||||
| _SOLVE = linear_operator_algebra._SOLVE | ||||
| _registered_solve = linear_operator_algebra._registered_solve | ||||
| # pylint: enable=protected-access | ||||
| 
 | ||||
| 
 | ||||
| @ -175,6 +178,55 @@ class MatmulTest(test.TestCase): | ||||
|       self.assertEqual(v, _registered_matmul(k[0], k[1])) | ||||
| 
 | ||||
| 
 | ||||
| class SolveTest(test.TestCase): | ||||
| 
 | ||||
|   def testRegistration(self): | ||||
| 
 | ||||
|     class CustomLinOp(linear_operator.LinearOperator): | ||||
| 
 | ||||
|       def _matmul(self, a): | ||||
|         pass | ||||
| 
 | ||||
|       def _solve(self, a): | ||||
|         pass | ||||
| 
 | ||||
|       def _shape(self): | ||||
|         return tensor_shape.TensorShape([1, 1]) | ||||
| 
 | ||||
|       def _shape_tensor(self): | ||||
|         pass | ||||
| 
 | ||||
|     # Register Solve to a lambda that spits out the name parameter | ||||
|     @linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp) | ||||
|     def _solve(a, b):  # pylint: disable=unused-argument,unused-variable | ||||
|       return "OK" | ||||
| 
 | ||||
|     custom_linop = CustomLinOp( | ||||
|         dtype=None, is_self_adjoint=True, is_positive_definite=True) | ||||
|     self.assertEqual("OK", custom_linop.solve(custom_linop)) | ||||
| 
 | ||||
|   def testRegistrationFailures(self): | ||||
| 
 | ||||
|     class CustomLinOp(linear_operator.LinearOperator): | ||||
|       pass | ||||
| 
 | ||||
|     with self.assertRaisesRegexp(TypeError, "must be callable"): | ||||
|       linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah") | ||||
| 
 | ||||
|     # First registration is OK | ||||
|     linear_operator_algebra.RegisterSolve( | ||||
|         CustomLinOp, CustomLinOp)(lambda a: None) | ||||
| 
 | ||||
|     # Second registration fails | ||||
|     with self.assertRaisesRegexp(ValueError, "has already been registered"): | ||||
|       linear_operator_algebra.RegisterSolve( | ||||
|           CustomLinOp, CustomLinOp)(lambda a: None) | ||||
| 
 | ||||
|   def testExactSolveRegistrationsAllMatch(self): | ||||
|     for (k, v) in _SOLVE.items(): | ||||
|       self.assertEqual(v, _registered_solve(k[0], k[1])) | ||||
| 
 | ||||
| 
 | ||||
| class InverseTest(test.TestCase): | ||||
| 
 | ||||
|   def testRegistration(self): | ||||
|  | ||||
| @ -187,6 +187,35 @@ class LinearOperatorDiagTest( | ||||
|         linalg_lib.LinearOperatorDiag)) | ||||
|     self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag)) | ||||
| 
 | ||||
|   def test_diag_solve(self): | ||||
|     operator1 = linalg_lib.LinearOperatorDiag([2., 3.], is_non_singular=True) | ||||
|     operator2 = linalg_lib.LinearOperatorDiag([1., 2.], is_non_singular=True) | ||||
|     operator3 = linalg_lib.LinearOperatorScaledIdentity( | ||||
|         num_rows=2, multiplier=3., is_non_singular=True) | ||||
|     operator_solve = operator1.solve(operator2) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator_solve, | ||||
|         linalg_lib.LinearOperatorDiag)) | ||||
|     self.assertAllClose([0.5, 2 / 3.], self.evaluate(operator_solve.diag)) | ||||
| 
 | ||||
|     operator_solve = operator2.solve(operator1) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator_solve, | ||||
|         linalg_lib.LinearOperatorDiag)) | ||||
|     self.assertAllClose([2., 3 / 2.], self.evaluate(operator_solve.diag)) | ||||
| 
 | ||||
|     operator_solve = operator1.solve(operator3) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator_solve, | ||||
|         linalg_lib.LinearOperatorDiag)) | ||||
|     self.assertAllClose([3 / 2., 1.], self.evaluate(operator_solve.diag)) | ||||
| 
 | ||||
|     operator_solve = operator3.solve(operator1) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator_solve, | ||||
|         linalg_lib.LinearOperatorDiag)) | ||||
|     self.assertAllClose([2 / 3., 1.], self.evaluate(operator_solve.diag)) | ||||
| 
 | ||||
|   def test_diag_adjoint_type(self): | ||||
|     diag = [1., 3., 5., 8.] | ||||
|     operator = linalg.LinearOperatorDiag(diag, is_non_singular=True) | ||||
|  | ||||
| @ -495,6 +495,20 @@ class LinearOperatorScaledIdentityTest( | ||||
|         linalg_lib.LinearOperatorScaledIdentity)) | ||||
|     self.assertAllClose(3., self.evaluate(operator_matmul.multiplier)) | ||||
| 
 | ||||
|   def test_identity_solve(self): | ||||
|     operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2) | ||||
|     operator2 = linalg_lib.LinearOperatorScaledIdentity( | ||||
|         num_rows=2, multiplier=3.) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator1.solve(operator1), | ||||
|         linalg_lib.LinearOperatorIdentity)) | ||||
| 
 | ||||
|     operator_solve = operator1.solve(operator2) | ||||
|     self.assertTrue(isinstance( | ||||
|         operator_solve, | ||||
|         linalg_lib.LinearOperatorScaledIdentity)) | ||||
|     self.assertAllClose(3., self.evaluate(operator_solve.multiplier)) | ||||
| 
 | ||||
|   def test_scaled_identity_cholesky_type(self): | ||||
|     operator = linalg_lib.LinearOperatorScaledIdentity( | ||||
|         num_rows=2, | ||||
|  | ||||
| @ -238,7 +238,7 @@ class LinearOperatorTest(test.TestCase): | ||||
| 
 | ||||
|     self.assertTrue(operator_matmul.is_square) | ||||
|     self.assertTrue(operator_matmul.is_non_singular) | ||||
|     self.assertTrue(operator_matmul.is_self_adjoint) | ||||
|     self.assertEqual(None, operator_matmul.is_self_adjoint) | ||||
|     self.assertEqual(None, operator_matmul.is_positive_definite) | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|  | ||||
| @ -25,6 +25,7 @@ from tensorflow.python.ops.linalg import cholesky_registrations as _cholesky_reg | ||||
| from tensorflow.python.ops.linalg import inverse_registrations as _inverse_registrations | ||||
| from tensorflow.python.ops.linalg import linear_operator_algebra as _linear_operator_algebra | ||||
| from tensorflow.python.ops.linalg import matmul_registrations as _matmul_registrations | ||||
| from tensorflow.python.ops.linalg import solve_registrations as _solve_registrations | ||||
| from tensorflow.python.ops.linalg.linalg_impl import * | ||||
| from tensorflow.python.ops.linalg.linear_operator import * | ||||
| from tensorflow.python.ops.linalg.linear_operator_block_diag import * | ||||
|  | ||||
| @ -597,16 +597,18 @@ class LinearOperator(object): | ||||
|         as `self`. | ||||
|     """ | ||||
|     if isinstance(x, LinearOperator): | ||||
|       if adjoint or adjoint_arg: | ||||
|         raise ValueError(".matmul not supported with adjoints.") | ||||
|       if (x.range_dimension is not None and | ||||
|           self.domain_dimension is not None and | ||||
|           x.range_dimension != self.domain_dimension): | ||||
|       left_operator = self.adjoint() if adjoint else self | ||||
|       right_operator = x.adjoint() if adjoint_arg else x | ||||
| 
 | ||||
|       if (right_operator.range_dimension is not None and | ||||
|           left_operator.domain_dimension is not None and | ||||
|           right_operator.range_dimension != left_operator.domain_dimension): | ||||
|         raise ValueError( | ||||
|             "Operators are incompatible. Expected `x` to have dimension" | ||||
|             " {} but got {}.".format(self.domain_dimension, x.range_dimension)) | ||||
|             " {} but got {}.".format( | ||||
|                 left_operator.domain_dimension, right_operator.range_dimension)) | ||||
|       with self._name_scope(name): | ||||
|         return linear_operator_algebra.matmul(self, x) | ||||
|         return linear_operator_algebra.matmul(left_operator, right_operator) | ||||
| 
 | ||||
|     with self._name_scope(name, values=[x]): | ||||
|       x = ops.convert_to_tensor(x, name="x") | ||||
| @ -780,6 +782,20 @@ class LinearOperator(object): | ||||
|       raise NotImplementedError( | ||||
|           "Exact solve not implemented for an operator that is expected to " | ||||
|           "not be square.") | ||||
|     if isinstance(rhs, LinearOperator): | ||||
|       left_operator = self.adjoint() if adjoint else self | ||||
|       right_operator = rhs.adjoint() if adjoint_arg else rhs | ||||
| 
 | ||||
|       if (right_operator.range_dimension is not None and | ||||
|           left_operator.domain_dimension is not None and | ||||
|           right_operator.range_dimension != left_operator.domain_dimension): | ||||
|         raise ValueError( | ||||
|             "Operators are incompatible. Expected `rhs` to have dimension" | ||||
|             " {} but got {}.".format( | ||||
|                 left_operator.domain_dimension, right_operator.range_dimension)) | ||||
|       with self._name_scope(name): | ||||
|         return linear_operator_algebra.solve(left_operator, right_operator) | ||||
| 
 | ||||
|     with self._name_scope(name, values=[rhs]): | ||||
|       rhs = ops.convert_to_tensor(rhs, name="rhs") | ||||
|       self._check_input_dtype(rhs) | ||||
|  | ||||
| @ -28,6 +28,7 @@ from tensorflow.python.util import tf_inspect | ||||
| _ADJOINTS = {} | ||||
| _CHOLESKY_DECOMPS = {} | ||||
| _MATMUL = {} | ||||
| _SOLVE = {} | ||||
| _INVERSES = {} | ||||
| 
 | ||||
| 
 | ||||
| @ -62,6 +63,11 @@ def _registered_matmul(type_a, type_b): | ||||
|   return _registered_function([type_a, type_b], _MATMUL) | ||||
| 
 | ||||
| 
 | ||||
| def _registered_solve(type_a, type_b): | ||||
|   """Get the Solve function registered for classes a and b.""" | ||||
|   return _registered_function([type_a, type_b], _SOLVE) | ||||
| 
 | ||||
| 
 | ||||
| def _registered_inverse(type_a): | ||||
|   """Get the Cholesky function registered for class a.""" | ||||
|   return _registered_function([type_a], _INVERSES) | ||||
| @ -138,6 +144,31 @@ def matmul(lin_op_a, lin_op_b, name=None): | ||||
|     return matmul_fn(lin_op_a, lin_op_b) | ||||
| 
 | ||||
| 
 | ||||
| def solve(lin_op_a, lin_op_b, name=None): | ||||
|   """Compute lin_op_a.solve(lin_op_b). | ||||
| 
 | ||||
|   Args: | ||||
|     lin_op_a: The LinearOperator on the left. | ||||
|     lin_op_b: The LinearOperator on the right. | ||||
|     name: Name to use for this operation. | ||||
| 
 | ||||
|   Returns: | ||||
|     A LinearOperator that represents the solve between `lin_op_a` and | ||||
|       `lin_op_b`. | ||||
| 
 | ||||
|   Raises: | ||||
|     NotImplementedError: If no solve method is defined between types of | ||||
|       `lin_op_a` and `lin_op_b`. | ||||
|   """ | ||||
|   solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b)) | ||||
|   if solve_fn is None: | ||||
|     raise ValueError("No solve registered for {}.solve({})".format( | ||||
|         type(lin_op_a), type(lin_op_b))) | ||||
| 
 | ||||
|   with ops.name_scope(name, "Solve"): | ||||
|     return solve_fn(lin_op_a, lin_op_b) | ||||
| 
 | ||||
| 
 | ||||
| def inverse(lin_op_a, name=None): | ||||
|   """Get the Inverse associated to lin_op_a. | ||||
| 
 | ||||
| @ -291,6 +322,52 @@ class RegisterMatmul(object): | ||||
|     return matmul_fn | ||||
| 
 | ||||
| 
 | ||||
| class RegisterSolve(object): | ||||
|   """Decorator to register a Solve implementation function. | ||||
| 
 | ||||
|   Usage: | ||||
| 
 | ||||
|   @linear_operator_algebra.RegisterSolve( | ||||
|     lin_op.LinearOperatorIdentity, | ||||
|     lin_op.LinearOperatorIdentity) | ||||
|   def _solve_identity(a, b): | ||||
|     # Return the identity matrix. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, lin_op_cls_a, lin_op_cls_b): | ||||
|     """Initialize the LinearOperator registrar. | ||||
| 
 | ||||
|     Args: | ||||
|       lin_op_cls_a: the class of the LinearOperator that is computing solve. | ||||
|       lin_op_cls_b: the class of the second LinearOperator to solve. | ||||
|     """ | ||||
|     self._key = (lin_op_cls_a, lin_op_cls_b) | ||||
| 
 | ||||
|   def __call__(self, solve_fn): | ||||
|     """Perform the Solve registration. | ||||
| 
 | ||||
|     Args: | ||||
|       solve_fn: The function to use for the Solve. | ||||
| 
 | ||||
|     Returns: | ||||
|       solve_fn | ||||
| 
 | ||||
|     Raises: | ||||
|       TypeError: if solve_fn is not a callable. | ||||
|       ValueError: if a Solve function has already been registered for | ||||
|         the given argument classes. | ||||
|     """ | ||||
|     if not callable(solve_fn): | ||||
|       raise TypeError( | ||||
|           "solve_fn must be callable, received: {}".format(solve_fn)) | ||||
|     if self._key in _SOLVE: | ||||
|       raise ValueError("Solve({}, {}) has already been registered.".format( | ||||
|           self._key[0].__name__, | ||||
|           self._key[1].__name__)) | ||||
|     _SOLVE[self._key] = solve_fn | ||||
|     return solve_fn | ||||
| 
 | ||||
| 
 | ||||
| class RegisterInverse(object): | ||||
|   """Decorator to register an Inverse implementation function. | ||||
| 
 | ||||
|  | ||||
| @ -26,66 +26,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag | ||||
| from tensorflow.python.ops.linalg import linear_operator_identity | ||||
| from tensorflow.python.ops.linalg import linear_operator_lower_triangular | ||||
| from tensorflow.python.ops.linalg import linear_operator_zeros | ||||
| 
 | ||||
| 
 | ||||
| def _combined_self_adjoint_hint(operator_a, operator_b): | ||||
|   """Get combined hint for self-adjoint-ness.""" | ||||
|   # Note: only use this method in the commuting case. | ||||
|   # The property is preserved under composition when the operators commute. | ||||
|   if operator_a.is_self_adjoint and operator_b.is_self_adjoint: | ||||
|     return True | ||||
| 
 | ||||
|   # The property is not preserved when an operator with the property is composed | ||||
|   # with an operator without the property. | ||||
|   if ((operator_a.is_self_adjoint is True and | ||||
|        operator_b.is_self_adjoint is False) or | ||||
|       (operator_a.is_self_adjoint is False and | ||||
|        operator_b.is_self_adjoint is True)): | ||||
|     return False | ||||
| 
 | ||||
|   # The property is not known when operators are not known to have the property | ||||
|   # or both operators don't have the property (the property for the complement | ||||
|   # class is not closed under composition). | ||||
|   return None | ||||
| 
 | ||||
| 
 | ||||
| def _is_square(operator_a, operator_b): | ||||
|   """Return a hint to whether the composition is square.""" | ||||
|   if operator_a.is_square and operator_b.is_square: | ||||
|     return True | ||||
|   if operator_a.is_square is False and operator_b.is_square is False: | ||||
|     # Let A have shape [B, M, N], B have shape [B, N, L]. | ||||
|     m = operator_a.range_dimension | ||||
|     l = operator_b.domain_dimension | ||||
|     if m is not None and l is not None: | ||||
|       return m == l | ||||
| 
 | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
| def _combined_positive_definite_hint(operator_a, operator_b): | ||||
|   """Get combined PD hint for compositions.""" | ||||
|   # Note: Positive definiteness is only guaranteed to be preserved | ||||
|   # when the operators commute and are symmetric. Only use this method in | ||||
|   # commuting cases. | ||||
| 
 | ||||
|   if (operator_a.is_positive_definite is True and | ||||
|       operator_a.is_self_adjoint is True and | ||||
|       operator_b.is_positive_definite is True and | ||||
|       operator_b.is_self_adjoint is True): | ||||
|     return True | ||||
| 
 | ||||
|   return None | ||||
| 
 | ||||
| 
 | ||||
| def _combined_non_singular_hint(operator_a, operator_b): | ||||
|   """Get combined hint for when .""" | ||||
|   # If either operator is not-invertible the composition isn't. | ||||
|   if (operator_a.is_non_singular is False or | ||||
|       operator_b.is_non_singular is False): | ||||
|     return False | ||||
| 
 | ||||
|   return operator_a.is_non_singular and operator_b.is_non_singular | ||||
| from tensorflow.python.ops.linalg import registrations_util | ||||
| 
 | ||||
| 
 | ||||
| # By default, use a LinearOperatorComposition to delay the computation. | ||||
| @ -93,15 +34,15 @@ def _combined_non_singular_hint(operator_a, operator_b): | ||||
|     linear_operator.LinearOperator, linear_operator.LinearOperator) | ||||
| def _matmul_linear_operator(linop_a, linop_b): | ||||
|   """Generic matmul of two `LinearOperator`s.""" | ||||
|   is_square = _is_square(linop_a, linop_b) | ||||
|   is_square = registrations_util.is_square(linop_a, linop_b) | ||||
|   is_non_singular = None | ||||
|   is_self_adjoint = None | ||||
|   is_positive_definite = None | ||||
| 
 | ||||
|   if is_square: | ||||
|     is_non_singular = _combined_non_singular_hint(linop_a, linop_b) | ||||
|     is_self_adjoint = _combined_self_adjoint_hint(linop_a, linop_b) | ||||
|   elif is_square is False: | ||||
|     is_non_singular = registrations_util.combined_non_singular_hint( | ||||
|         linop_a, linop_b) | ||||
|   elif is_square is False:  # pylint:disable=g-bool-id-comparison | ||||
|     is_non_singular = False | ||||
|     is_self_adjoint = False | ||||
|     is_positive_definite = False | ||||
| @ -165,11 +106,13 @@ def _matmul_linear_operator_zeros_left(zeros, linop): | ||||
| def _matmul_linear_operator_diag(linop_a, linop_b): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_a.diag * linop_b.diag, | ||||
|       is_non_singular=_combined_non_singular_hint(linop_a, linop_b), | ||||
|       is_self_adjoint=_combined_self_adjoint_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_positive_definite=_combined_positive_definite_hint( | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_a, linop_b)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @ -180,12 +123,13 @@ def _matmul_linear_operator_diag_scaled_identity_right( | ||||
|     linop_diag, linop_scaled_identity): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_diag.diag * linop_scaled_identity.multiplier, | ||||
|       is_non_singular=_combined_non_singular_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_self_adjoint=_combined_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=_combined_positive_definite_hint( | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_diag, linop_scaled_identity)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @ -196,12 +140,13 @@ def _matmul_linear_operator_diag_scaled_identity_left( | ||||
|     linop_scaled_identity, linop_diag): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_diag.diag * linop_scaled_identity.multiplier, | ||||
|       is_non_singular=_combined_non_singular_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_self_adjoint=_combined_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=_combined_positive_definite_hint( | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_diag, linop_scaled_identity)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @ -211,11 +156,11 @@ def _matmul_linear_operator_diag_scaled_identity_left( | ||||
| def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular): | ||||
|   return linear_operator_lower_triangular.LinearOperatorLowerTriangular( | ||||
|       tril=linop_diag.diag[..., None] * linop_triangular.to_dense(), | ||||
|       is_non_singular=_combined_non_singular_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       # This is safe to do since the Triangular matrix is only self-adjoint | ||||
|       # when it is a diagonal matrix, and hence commutes. | ||||
|       is_self_adjoint=_combined_self_adjoint_hint( | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       is_positive_definite=None, | ||||
|       is_square=True) | ||||
| @ -227,11 +172,11 @@ def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular): | ||||
| def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag): | ||||
|   return linear_operator_lower_triangular.LinearOperatorLowerTriangular( | ||||
|       tril=linop_triangular.to_dense() * linop_diag.diag, | ||||
|       is_non_singular=_combined_non_singular_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       # This is safe to do since the Triangular matrix is only self-adjoint | ||||
|       # when it is a diagonal matrix, and hence commutes. | ||||
|       is_self_adjoint=_combined_self_adjoint_hint( | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       is_positive_definite=None, | ||||
|       is_square=True) | ||||
| @ -245,8 +190,11 @@ def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag): | ||||
| def _matmul_linear_operator_circulant_circulant(linop_a, linop_b): | ||||
|   return linear_operator_circulant.LinearOperatorCirculant( | ||||
|       spectrum=linop_a.spectrum * linop_b.spectrum, | ||||
|       is_non_singular=_combined_non_singular_hint(linop_a, linop_b), | ||||
|       is_self_adjoint=_combined_self_adjoint_hint(linop_a, linop_b), | ||||
|       is_positive_definite=_combined_positive_definite_hint( | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_a, linop_b)), | ||||
|       is_square=True) | ||||
|  | ||||
							
								
								
									
										91
									
								
								tensorflow/python/ops/linalg/registrations_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								tensorflow/python/ops/linalg/registrations_util.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,91 @@ | ||||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Common utilities for registering LinearOperator methods.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| 
 | ||||
| # Note: only use this method in the commuting case. | ||||
| def combined_commuting_self_adjoint_hint(operator_a, operator_b): | ||||
|   """Get combined hint for self-adjoint-ness.""" | ||||
| 
 | ||||
|   # The property is preserved under composition when the operators commute. | ||||
|   if operator_a.is_self_adjoint and operator_b.is_self_adjoint: | ||||
|     return True | ||||
| 
 | ||||
|   # The property is not preserved when an operator with the property is composed | ||||
|   # with an operator without the property. | ||||
| 
 | ||||
|   # pylint:disable=g-bool-id-comparison | ||||
|   if ((operator_a.is_self_adjoint is True and | ||||
|        operator_b.is_self_adjoint is False) or | ||||
|       (operator_a.is_self_adjoint is False and | ||||
|        operator_b.is_self_adjoint is True)): | ||||
|     return False | ||||
|   # pylint:enable=g-bool-id-comparison | ||||
| 
 | ||||
|   # The property is not known when operators are not known to have the property | ||||
|   # or both operators don't have the property (the property for the complement | ||||
|   # class is not closed under composition). | ||||
|   return None | ||||
| 
 | ||||
| 
 | ||||
| def is_square(operator_a, operator_b): | ||||
|   """Return a hint to whether the composition is square.""" | ||||
|   if operator_a.is_square and operator_b.is_square: | ||||
|     return True | ||||
|   if operator_a.is_square is False and operator_b.is_square is False:  # pylint:disable=g-bool-id-comparison | ||||
|     # Let A have shape [B, M, N], B have shape [B, N, L]. | ||||
|     m = operator_a.range_dimension | ||||
|     l = operator_b.domain_dimension | ||||
|     if m is not None and l is not None: | ||||
|       return m == l | ||||
| 
 | ||||
|   if (operator_a.is_square != operator_b.is_square) and ( | ||||
|       operator_a.is_square is not None and operator_a.is_square is not None): | ||||
|     return False | ||||
| 
 | ||||
|   return None | ||||
| 
 | ||||
| 
 | ||||
| # Note: Positive definiteness is only guaranteed to be preserved | ||||
| # when the operators commute and are symmetric. Only use this method in | ||||
| # commuting cases. | ||||
| def combined_commuting_positive_definite_hint(operator_a, operator_b): | ||||
|   """Get combined PD hint for compositions.""" | ||||
|   # pylint:disable=g-bool-id-comparison | ||||
|   if (operator_a.is_positive_definite is True and | ||||
|       operator_a.is_self_adjoint is True and | ||||
|       operator_b.is_positive_definite is True and | ||||
|       operator_b.is_self_adjoint is True): | ||||
|     return True | ||||
|   # pylint:enable=g-bool-id-comparison | ||||
| 
 | ||||
|   return None | ||||
| 
 | ||||
| 
 | ||||
| def combined_non_singular_hint(operator_a, operator_b): | ||||
|   """Get combined hint for when .""" | ||||
|   # If either operator is not-invertible the composition isn't. | ||||
| 
 | ||||
|   # pylint:disable=g-bool-id-comparison | ||||
|   if (operator_a.is_non_singular is False or | ||||
|       operator_b.is_non_singular is False): | ||||
|     return False | ||||
|   # pylint:enable=g-bool-id-comparison | ||||
| 
 | ||||
|   return operator_a.is_non_singular and operator_b.is_non_singular | ||||
							
								
								
									
										164
									
								
								tensorflow/python/ops/linalg/solve_registrations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								tensorflow/python/ops/linalg/solve_registrations.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,164 @@ | ||||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Registrations for LinearOperator.solve.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.ops.linalg import linear_operator | ||||
| from tensorflow.python.ops.linalg import linear_operator_algebra | ||||
| from tensorflow.python.ops.linalg import linear_operator_circulant | ||||
| from tensorflow.python.ops.linalg import linear_operator_composition | ||||
| from tensorflow.python.ops.linalg import linear_operator_diag | ||||
| from tensorflow.python.ops.linalg import linear_operator_identity | ||||
| from tensorflow.python.ops.linalg import linear_operator_inversion | ||||
| from tensorflow.python.ops.linalg import linear_operator_lower_triangular | ||||
| from tensorflow.python.ops.linalg import registrations_util | ||||
| 
 | ||||
| 
 | ||||
| # By default, use a LinearOperatorComposition to delay the computation. | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator.LinearOperator, linear_operator.LinearOperator) | ||||
| def _solve_linear_operator(linop_a, linop_b): | ||||
|   """Generic solve of two `LinearOperator`s.""" | ||||
|   is_square = registrations_util.is_square(linop_a, linop_b) | ||||
|   is_non_singular = None | ||||
|   is_self_adjoint = None | ||||
|   is_positive_definite = None | ||||
| 
 | ||||
|   if is_square: | ||||
|     is_non_singular = registrations_util.combined_non_singular_hint( | ||||
|         linop_a, linop_b) | ||||
|   elif is_square is False:  # pylint:disable=g-bool-id-comparison | ||||
|     is_non_singular = False | ||||
|     is_self_adjoint = False | ||||
|     is_positive_definite = False | ||||
| 
 | ||||
|   return linear_operator_composition.LinearOperatorComposition( | ||||
|       operators=[ | ||||
|           linear_operator_inversion.LinearOperatorInversion(linop_a), | ||||
|           linop_b | ||||
|       ], | ||||
|       is_non_singular=is_non_singular, | ||||
|       is_self_adjoint=is_self_adjoint, | ||||
|       is_positive_definite=is_positive_definite, | ||||
|       is_square=is_square, | ||||
|   ) | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_inversion.LinearOperatorInversion, | ||||
|     linear_operator.LinearOperator) | ||||
| def _solve_inverse_linear_operator(linop_a, linop_b): | ||||
|   """Solve inverse of generic `LinearOperator`s.""" | ||||
|   return linop_a.operator.matmul(linop_b) | ||||
| 
 | ||||
| 
 | ||||
| # Identity | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_identity.LinearOperatorIdentity, | ||||
|     linear_operator.LinearOperator) | ||||
| def _solve_linear_operator_identity_left(identity, linop): | ||||
|   del identity | ||||
|   return linop | ||||
| 
 | ||||
| 
 | ||||
| # Diag. | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_diag.LinearOperatorDiag, | ||||
|     linear_operator_diag.LinearOperatorDiag) | ||||
| def _solve_linear_operator_diag(linop_a, linop_b): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_b.diag / linop_a.diag, | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_a, linop_b)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_diag.LinearOperatorDiag, | ||||
|     linear_operator_identity.LinearOperatorScaledIdentity) | ||||
| def _solve_linear_operator_diag_scaled_identity_right( | ||||
|     linop_diag, linop_scaled_identity): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_scaled_identity.multiplier / linop_diag.diag, | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_diag, linop_scaled_identity)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_identity.LinearOperatorScaledIdentity, | ||||
|     linear_operator_diag.LinearOperatorDiag) | ||||
| def _solve_linear_operator_diag_scaled_identity_left( | ||||
|     linop_scaled_identity, linop_diag): | ||||
|   return linear_operator_diag.LinearOperatorDiag( | ||||
|       diag=linop_diag.diag / linop_scaled_identity.multiplier, | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_scaled_identity), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_diag, linop_scaled_identity)), | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_diag.LinearOperatorDiag, | ||||
|     linear_operator_lower_triangular.LinearOperatorLowerTriangular) | ||||
| def _solve_linear_operator_diag_tril(linop_diag, linop_triangular): | ||||
|   return linear_operator_lower_triangular.LinearOperatorLowerTriangular( | ||||
|       tril=linop_triangular.to_dense() / linop_diag.diag[..., None], | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       # This is safe to do since the Triangular matrix is only self-adjoint | ||||
|       # when it is a diagonal matrix, and hence commutes. | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_diag, linop_triangular), | ||||
|       is_positive_definite=None, | ||||
|       is_square=True) | ||||
| 
 | ||||
| 
 | ||||
| # Circulant. | ||||
| 
 | ||||
| 
 | ||||
| @linear_operator_algebra.RegisterSolve( | ||||
|     linear_operator_circulant.LinearOperatorCirculant, | ||||
|     linear_operator_circulant.LinearOperatorCirculant) | ||||
| def _solve_linear_operator_circulant_circulant(linop_a, linop_b): | ||||
|   return linear_operator_circulant.LinearOperatorCirculant( | ||||
|       spectrum=linop_b.spectrum / linop_a.spectrum, | ||||
|       is_non_singular=registrations_util.combined_non_singular_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint( | ||||
|           linop_a, linop_b), | ||||
|       is_positive_definite=( | ||||
|           registrations_util.combined_commuting_positive_definite_hint( | ||||
|               linop_a, linop_b)), | ||||
|       is_square=True) | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user