From a92016d3fea53414a727177953710c23870fdab1 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Tue, 20 Aug 2019 19:39:25 -0700 Subject: [PATCH] LinearOperatorLowRankUpdate made "tape safe" for TF2 compliance. PiperOrigin-RevId: 264521197 --- .../linear_operator_low_rank_update_test.py | 17 +++++++++++ .../linalg/linear_operator_low_rank_update.py | 28 +++++++------------ 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index c438187e35f..0438120a66c 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.ops.linalg import linear_operator_test_util from tensorflow.python.platform import test @@ -155,6 +156,22 @@ class BaseLinearOperatorLowRankUpdatetest(object): return operator, matrix + def test_tape_safe(self): + base_operator = linalg.LinearOperatorDiag( + variables_module.Variable([1.], name="diag"), + is_positive_definite=True, + is_self_adjoint=True) + + operator = linalg.LinearOperatorLowRankUpdate( + base_operator, + u=variables_module.Variable([[2.]], name="u"), + v=variables_module.Variable([[1.25]], name="v") + if self._use_v else None, + diag_update=variables_module.Variable([1.25], name="diag_update") + if self._use_diag_update else None, + is_diag_update_positive=self._is_diag_update_positive) + self.check_tape_safe(operator) + class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( BaseLinearOperatorLowRankUpdatetest, diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index 2d9626ab7a0..8803d58c15f 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -228,16 +228,16 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): with ops.name_scope(name, values=values): # Create U and V. - self._u = ops.convert_to_tensor(u, name="u") + self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u") if v is None: self._v = self._u else: - self._v = ops.convert_to_tensor(v, name="v") + self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v") if diag_update is None: self._diag_update = None else: - self._diag_update = ops.convert_to_tensor( + self._diag_update = linear_operator_util.convert_nonref_to_tensor( diag_update, name="diag_update") # Create base_operator L. @@ -261,12 +261,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): self._check_shapes() - # Pre-compute the so-called "capacitance" matrix - # C := D^{-1} + V^H L^{-1} U - self._capacitance = self._make_capacitance() - if self._use_cholesky: - self._chol_capacitance = linalg_ops.cholesky(self._capacitance) - def _check_shapes(self): """Static check that shapes are compatible.""" # Broadcast shape also checks that u and v are compatible. @@ -291,8 +285,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): if diag_update is not None: self._diag_operator = linear_operator_diag.LinearOperatorDiag( self._diag_update, is_positive_definite=is_diag_update_positive) - self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag( - 1. / self._diag_update, is_positive_definite=is_diag_update_positive) else: if tensor_shape.dimension_value(self.u.shape[-1]) is not None: r = tensor_shape.dimension_value(self.u.shape[-1]) @@ -300,7 +292,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): r = array_ops.shape(self.u)[-1] self._diag_operator = linear_operator_identity.LinearOperatorIdentity( num_rows=r, dtype=self.dtype) - self._diag_inv_operator = self._diag_operator @property def u(self): @@ -373,7 +364,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): # = det(C) det(D) det(L) # where C is sometimes known as the capacitance matrix, # C := D^{-1} + V^H L^{-1} U - det_c = linalg_ops.matrix_determinant(self._capacitance) + det_c = linalg_ops.matrix_determinant(self._make_capacitance()) det_d = self.diag_operator.determinant() det_l = self.base_operator.determinant() return det_c * det_d * det_l @@ -386,11 +377,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): log_abs_det_l = self.base_operator.log_abs_determinant() if self._use_cholesky: - chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance) + chol_cap_diag = array_ops.matrix_diag_part( + linalg_ops.cholesky(self._make_capacitance())) log_abs_det_c = 2 * math_ops.reduce_sum( math_ops.log(chol_cap_diag), axis=[-1]) else: - det_c = linalg_ops.matrix_determinant(self._capacitance) + det_c = linalg_ops.matrix_determinant(self._make_capacitance()) log_abs_det_c = math_ops.log(math_ops.abs(det_c)) if self.dtype.is_complex: log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype) @@ -426,10 +418,10 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): # C^{-1} V^H L^{-1} rhs if self._use_cholesky: capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast( - self._chol_capacitance, vh_linv_rhs) + linalg_ops.cholesky(self._make_capacitance()), vh_linv_rhs) else: capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast( - self._capacitance, vh_linv_rhs, adjoint=adjoint) + self._make_capacitance(), vh_linv_rhs, adjoint=adjoint) # U C^{-1} V^H M^{-1} rhs u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs) # L^{-1} U C^{-1} V^H L^{-1} rhs @@ -448,5 +440,5 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): vh_linv_u = math_ops.matmul(self.v, linv_u, adjoint_a=True) # D^{-1} + V^H L^{-1} V - capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u) + capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u) return capacitance