LinearOperatorLowRankUpdate made "tape safe" for TF2 compliance.
PiperOrigin-RevId: 264521197
This commit is contained in:
		
							parent
							
								
									e9bda5601f
								
							
						
					
					
						commit
						a92016d3fe
					
				| @ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes | |||||||
| from tensorflow.python.framework import test_util | from tensorflow.python.framework import test_util | ||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import math_ops | from tensorflow.python.ops import math_ops | ||||||
|  | from tensorflow.python.ops import variables as variables_module | ||||||
| from tensorflow.python.ops.linalg import linalg as linalg_lib | from tensorflow.python.ops.linalg import linalg as linalg_lib | ||||||
| from tensorflow.python.ops.linalg import linear_operator_test_util | from tensorflow.python.ops.linalg import linear_operator_test_util | ||||||
| from tensorflow.python.platform import test | from tensorflow.python.platform import test | ||||||
| @ -155,6 +156,22 @@ class BaseLinearOperatorLowRankUpdatetest(object): | |||||||
| 
 | 
 | ||||||
|     return operator, matrix |     return operator, matrix | ||||||
| 
 | 
 | ||||||
|  |   def test_tape_safe(self): | ||||||
|  |     base_operator = linalg.LinearOperatorDiag( | ||||||
|  |         variables_module.Variable([1.], name="diag"), | ||||||
|  |         is_positive_definite=True, | ||||||
|  |         is_self_adjoint=True) | ||||||
|  | 
 | ||||||
|  |     operator = linalg.LinearOperatorLowRankUpdate( | ||||||
|  |         base_operator, | ||||||
|  |         u=variables_module.Variable([[2.]], name="u"), | ||||||
|  |         v=variables_module.Variable([[1.25]], name="v") | ||||||
|  |         if self._use_v else None, | ||||||
|  |         diag_update=variables_module.Variable([1.25], name="diag_update") | ||||||
|  |         if self._use_diag_update else None, | ||||||
|  |         is_diag_update_positive=self._is_diag_update_positive) | ||||||
|  |     self.check_tape_safe(operator) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( | class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( | ||||||
|     BaseLinearOperatorLowRankUpdatetest, |     BaseLinearOperatorLowRankUpdatetest, | ||||||
|  | |||||||
| @ -228,16 +228,16 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     with ops.name_scope(name, values=values): |     with ops.name_scope(name, values=values): | ||||||
| 
 | 
 | ||||||
|       # Create U and V. |       # Create U and V. | ||||||
|       self._u = ops.convert_to_tensor(u, name="u") |       self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u") | ||||||
|       if v is None: |       if v is None: | ||||||
|         self._v = self._u |         self._v = self._u | ||||||
|       else: |       else: | ||||||
|         self._v = ops.convert_to_tensor(v, name="v") |         self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v") | ||||||
| 
 | 
 | ||||||
|       if diag_update is None: |       if diag_update is None: | ||||||
|         self._diag_update = None |         self._diag_update = None | ||||||
|       else: |       else: | ||||||
|         self._diag_update = ops.convert_to_tensor( |         self._diag_update = linear_operator_util.convert_nonref_to_tensor( | ||||||
|             diag_update, name="diag_update") |             diag_update, name="diag_update") | ||||||
| 
 | 
 | ||||||
|       # Create base_operator L. |       # Create base_operator L. | ||||||
| @ -261,12 +261,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
| 
 | 
 | ||||||
|       self._check_shapes() |       self._check_shapes() | ||||||
| 
 | 
 | ||||||
|       # Pre-compute the so-called "capacitance" matrix |  | ||||||
|       #   C := D^{-1} + V^H L^{-1} U |  | ||||||
|       self._capacitance = self._make_capacitance() |  | ||||||
|       if self._use_cholesky: |  | ||||||
|         self._chol_capacitance = linalg_ops.cholesky(self._capacitance) |  | ||||||
| 
 |  | ||||||
|   def _check_shapes(self): |   def _check_shapes(self): | ||||||
|     """Static check that shapes are compatible.""" |     """Static check that shapes are compatible.""" | ||||||
|     # Broadcast shape also checks that u and v are compatible. |     # Broadcast shape also checks that u and v are compatible. | ||||||
| @ -291,8 +285,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     if diag_update is not None: |     if diag_update is not None: | ||||||
|       self._diag_operator = linear_operator_diag.LinearOperatorDiag( |       self._diag_operator = linear_operator_diag.LinearOperatorDiag( | ||||||
|           self._diag_update, is_positive_definite=is_diag_update_positive) |           self._diag_update, is_positive_definite=is_diag_update_positive) | ||||||
|       self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag( |  | ||||||
|           1. / self._diag_update, is_positive_definite=is_diag_update_positive) |  | ||||||
|     else: |     else: | ||||||
|       if tensor_shape.dimension_value(self.u.shape[-1]) is not None: |       if tensor_shape.dimension_value(self.u.shape[-1]) is not None: | ||||||
|         r = tensor_shape.dimension_value(self.u.shape[-1]) |         r = tensor_shape.dimension_value(self.u.shape[-1]) | ||||||
| @ -300,7 +292,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|         r = array_ops.shape(self.u)[-1] |         r = array_ops.shape(self.u)[-1] | ||||||
|       self._diag_operator = linear_operator_identity.LinearOperatorIdentity( |       self._diag_operator = linear_operator_identity.LinearOperatorIdentity( | ||||||
|           num_rows=r, dtype=self.dtype) |           num_rows=r, dtype=self.dtype) | ||||||
|       self._diag_inv_operator = self._diag_operator |  | ||||||
| 
 | 
 | ||||||
|   @property |   @property | ||||||
|   def u(self): |   def u(self): | ||||||
| @ -373,7 +364,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     #                  = det(C) det(D) det(L) |     #                  = det(C) det(D) det(L) | ||||||
|     # where C is sometimes known as the capacitance matrix, |     # where C is sometimes known as the capacitance matrix, | ||||||
|     #   C := D^{-1} + V^H L^{-1} U |     #   C := D^{-1} + V^H L^{-1} U | ||||||
|     det_c = linalg_ops.matrix_determinant(self._capacitance) |     det_c = linalg_ops.matrix_determinant(self._make_capacitance()) | ||||||
|     det_d = self.diag_operator.determinant() |     det_d = self.diag_operator.determinant() | ||||||
|     det_l = self.base_operator.determinant() |     det_l = self.base_operator.determinant() | ||||||
|     return det_c * det_d * det_l |     return det_c * det_d * det_l | ||||||
| @ -386,11 +377,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     log_abs_det_l = self.base_operator.log_abs_determinant() |     log_abs_det_l = self.base_operator.log_abs_determinant() | ||||||
| 
 | 
 | ||||||
|     if self._use_cholesky: |     if self._use_cholesky: | ||||||
|       chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance) |       chol_cap_diag = array_ops.matrix_diag_part( | ||||||
|  |           linalg_ops.cholesky(self._make_capacitance())) | ||||||
|       log_abs_det_c = 2 * math_ops.reduce_sum( |       log_abs_det_c = 2 * math_ops.reduce_sum( | ||||||
|           math_ops.log(chol_cap_diag), axis=[-1]) |           math_ops.log(chol_cap_diag), axis=[-1]) | ||||||
|     else: |     else: | ||||||
|       det_c = linalg_ops.matrix_determinant(self._capacitance) |       det_c = linalg_ops.matrix_determinant(self._make_capacitance()) | ||||||
|       log_abs_det_c = math_ops.log(math_ops.abs(det_c)) |       log_abs_det_c = math_ops.log(math_ops.abs(det_c)) | ||||||
|       if self.dtype.is_complex: |       if self.dtype.is_complex: | ||||||
|         log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype) |         log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype) | ||||||
| @ -426,10 +418,10 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     # C^{-1} V^H L^{-1} rhs |     # C^{-1} V^H L^{-1} rhs | ||||||
|     if self._use_cholesky: |     if self._use_cholesky: | ||||||
|       capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast( |       capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast( | ||||||
|           self._chol_capacitance, vh_linv_rhs) |           linalg_ops.cholesky(self._make_capacitance()), vh_linv_rhs) | ||||||
|     else: |     else: | ||||||
|       capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast( |       capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast( | ||||||
|           self._capacitance, vh_linv_rhs, adjoint=adjoint) |           self._make_capacitance(), vh_linv_rhs, adjoint=adjoint) | ||||||
|     # U C^{-1} V^H M^{-1} rhs |     # U C^{-1} V^H M^{-1} rhs | ||||||
|     u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs) |     u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs) | ||||||
|     # L^{-1} U C^{-1} V^H L^{-1} rhs |     # L^{-1} U C^{-1} V^H L^{-1} rhs | ||||||
| @ -448,5 +440,5 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): | |||||||
|     vh_linv_u = math_ops.matmul(self.v, linv_u, adjoint_a=True) |     vh_linv_u = math_ops.matmul(self.v, linv_u, adjoint_a=True) | ||||||
| 
 | 
 | ||||||
|     # D^{-1} + V^H L^{-1} V |     # D^{-1} + V^H L^{-1} V | ||||||
|     capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u) |     capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u) | ||||||
|     return capacitance |     return capacitance | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user