LinearOperatorLowRankUpdate made "tape safe" for TF2 compliance.
PiperOrigin-RevId: 264521197
This commit is contained in:
parent
e9bda5601f
commit
a92016d3fe
@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.ops.linalg import linear_operator_test_util
|
||||
from tensorflow.python.platform import test
|
||||
@ -155,6 +156,22 @@ class BaseLinearOperatorLowRankUpdatetest(object):
|
||||
|
||||
return operator, matrix
|
||||
|
||||
def test_tape_safe(self):
|
||||
base_operator = linalg.LinearOperatorDiag(
|
||||
variables_module.Variable([1.], name="diag"),
|
||||
is_positive_definite=True,
|
||||
is_self_adjoint=True)
|
||||
|
||||
operator = linalg.LinearOperatorLowRankUpdate(
|
||||
base_operator,
|
||||
u=variables_module.Variable([[2.]], name="u"),
|
||||
v=variables_module.Variable([[1.25]], name="v")
|
||||
if self._use_v else None,
|
||||
diag_update=variables_module.Variable([1.25], name="diag_update")
|
||||
if self._use_diag_update else None,
|
||||
is_diag_update_positive=self._is_diag_update_positive)
|
||||
self.check_tape_safe(operator)
|
||||
|
||||
|
||||
class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
|
||||
BaseLinearOperatorLowRankUpdatetest,
|
||||
|
@ -228,16 +228,16 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=values):
|
||||
|
||||
# Create U and V.
|
||||
self._u = ops.convert_to_tensor(u, name="u")
|
||||
self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u")
|
||||
if v is None:
|
||||
self._v = self._u
|
||||
else:
|
||||
self._v = ops.convert_to_tensor(v, name="v")
|
||||
self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v")
|
||||
|
||||
if diag_update is None:
|
||||
self._diag_update = None
|
||||
else:
|
||||
self._diag_update = ops.convert_to_tensor(
|
||||
self._diag_update = linear_operator_util.convert_nonref_to_tensor(
|
||||
diag_update, name="diag_update")
|
||||
|
||||
# Create base_operator L.
|
||||
@ -261,12 +261,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
|
||||
self._check_shapes()
|
||||
|
||||
# Pre-compute the so-called "capacitance" matrix
|
||||
# C := D^{-1} + V^H L^{-1} U
|
||||
self._capacitance = self._make_capacitance()
|
||||
if self._use_cholesky:
|
||||
self._chol_capacitance = linalg_ops.cholesky(self._capacitance)
|
||||
|
||||
def _check_shapes(self):
|
||||
"""Static check that shapes are compatible."""
|
||||
# Broadcast shape also checks that u and v are compatible.
|
||||
@ -291,8 +285,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
if diag_update is not None:
|
||||
self._diag_operator = linear_operator_diag.LinearOperatorDiag(
|
||||
self._diag_update, is_positive_definite=is_diag_update_positive)
|
||||
self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
|
||||
1. / self._diag_update, is_positive_definite=is_diag_update_positive)
|
||||
else:
|
||||
if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
|
||||
r = tensor_shape.dimension_value(self.u.shape[-1])
|
||||
@ -300,7 +292,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
r = array_ops.shape(self.u)[-1]
|
||||
self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
|
||||
num_rows=r, dtype=self.dtype)
|
||||
self._diag_inv_operator = self._diag_operator
|
||||
|
||||
@property
|
||||
def u(self):
|
||||
@ -373,7 +364,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
# = det(C) det(D) det(L)
|
||||
# where C is sometimes known as the capacitance matrix,
|
||||
# C := D^{-1} + V^H L^{-1} U
|
||||
det_c = linalg_ops.matrix_determinant(self._capacitance)
|
||||
det_c = linalg_ops.matrix_determinant(self._make_capacitance())
|
||||
det_d = self.diag_operator.determinant()
|
||||
det_l = self.base_operator.determinant()
|
||||
return det_c * det_d * det_l
|
||||
@ -386,11 +377,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
log_abs_det_l = self.base_operator.log_abs_determinant()
|
||||
|
||||
if self._use_cholesky:
|
||||
chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance)
|
||||
chol_cap_diag = array_ops.matrix_diag_part(
|
||||
linalg_ops.cholesky(self._make_capacitance()))
|
||||
log_abs_det_c = 2 * math_ops.reduce_sum(
|
||||
math_ops.log(chol_cap_diag), axis=[-1])
|
||||
else:
|
||||
det_c = linalg_ops.matrix_determinant(self._capacitance)
|
||||
det_c = linalg_ops.matrix_determinant(self._make_capacitance())
|
||||
log_abs_det_c = math_ops.log(math_ops.abs(det_c))
|
||||
if self.dtype.is_complex:
|
||||
log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)
|
||||
@ -426,10 +418,10 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
# C^{-1} V^H L^{-1} rhs
|
||||
if self._use_cholesky:
|
||||
capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
|
||||
self._chol_capacitance, vh_linv_rhs)
|
||||
linalg_ops.cholesky(self._make_capacitance()), vh_linv_rhs)
|
||||
else:
|
||||
capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
|
||||
self._capacitance, vh_linv_rhs, adjoint=adjoint)
|
||||
self._make_capacitance(), vh_linv_rhs, adjoint=adjoint)
|
||||
# U C^{-1} V^H M^{-1} rhs
|
||||
u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
|
||||
# L^{-1} U C^{-1} V^H L^{-1} rhs
|
||||
@ -448,5 +440,5 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
vh_linv_u = math_ops.matmul(self.v, linv_u, adjoint_a=True)
|
||||
|
||||
# D^{-1} + V^H L^{-1} V
|
||||
capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u)
|
||||
capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u)
|
||||
return capacitance
|
||||
|
Loading…
Reference in New Issue
Block a user