LinearOperatorLowRankUpdate made "tape safe" for TF2 compliance.

PiperOrigin-RevId: 264521197
Ian Langmore 2019-08-20 19:39:25 -07:00 committed by TensorFlower Gardener
parent e9bda5601f
commit a92016d3fe
2 changed files with 27 additions and 18 deletions
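
Context for the change: "tape safe" means that when an operator is built from `tf.Variable`s, a `tf.GradientTape` can differentiate the operator's methods back to those variables, because the variables are re-read at call time instead of being converted to constant tensors in `__init__`. A minimal sketch of the behavior this commit enables (shapes and values are illustrative):

    import tensorflow as tf

    base = tf.linalg.LinearOperatorDiag(
        tf.Variable([1.], name="diag"),
        is_positive_definite=True,
        is_self_adjoint=True)
    u = tf.Variable([[2.]], name="u")
    operator = tf.linalg.LinearOperatorLowRankUpdate(base, u=u)

    with tf.GradientTape() as tape:
      # The variables are read here, under the tape, so the result
      # depends on their current values.
      loss = operator.determinant()

    # With this commit, the gradient is no longer None.
    print(tape.gradient(loss, u))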


@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_test_util
 from tensorflow.python.platform import test
@@ -155,6 +156,22 @@ class BaseLinearOperatorLowRankUpdatetest(object):
     return operator, matrix

+  def test_tape_safe(self):
+    base_operator = linalg.LinearOperatorDiag(
+        variables_module.Variable([1.], name="diag"),
+        is_positive_definite=True,
+        is_self_adjoint=True)
+
+    operator = linalg.LinearOperatorLowRankUpdate(
+        base_operator,
+        u=variables_module.Variable([[2.]], name="u"),
+        v=variables_module.Variable([[1.25]], name="v")
+        if self._use_v else None,
+        diag_update=variables_module.Variable([1.25], name="diag_update")
+        if self._use_diag_update else None,
+        is_diag_update_positive=self._is_diag_update_positive)
+    self.check_tape_safe(operator)
+

 class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
     BaseLinearOperatorLowRankUpdatetest,
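`check_tape_safe` comes from the shared linear-operator test machinery in `linear_operator_test_util`. Roughly, it asserts that every variable the operator was built from receives a gradient through the operator's methods. A simplified sketch of the idea, assuming the operator exposes `variables` (LinearOperator is a `tf.Module` in current TF); the real helper exercises many methods, this shows only one:

    import tensorflow as tf

    def check_tape_safe_sketch(test, operator):
      # Illustration only: differentiate one representative method with
      # respect to each variable the operator was constructed from.
      for var in operator.variables:
        with tf.GradientTape() as tape:
          value = operator.to_dense()
        test.assertIsNotNone(tape.gradient(value, var))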


@@ -228,16 +228,16 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     with ops.name_scope(name, values=values):
       # Create U and V.
-      self._u = ops.convert_to_tensor(u, name="u")
+      self._u = linear_operator_util.convert_nonref_to_tensor(u, name="u")
       if v is None:
         self._v = self._u
       else:
-        self._v = ops.convert_to_tensor(v, name="v")
+        self._v = linear_operator_util.convert_nonref_to_tensor(v, name="v")

       if diag_update is None:
         self._diag_update = None
       else:
-        self._diag_update = ops.convert_to_tensor(
+        self._diag_update = linear_operator_util.convert_nonref_to_tensor(
             diag_update, name="diag_update")

       # Create base_operator L.
@@ -261,12 +261,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
       self._check_shapes()

-      # Pre-compute the so-called "capacitance" matrix
-      #   C := D^{-1} + V^H L^{-1} U
-      self._capacitance = self._make_capacitance()
-      if self._use_cholesky:
-        self._chol_capacitance = linalg_ops.cholesky(self._capacitance)
-
   def _check_shapes(self):
     """Static check that shapes are compatible."""
     # Broadcast shape also checks that u and v are compatible.
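The block deleted above is the crux of the change: `self._capacitance` and `self._chol_capacitance` were computed once, at construction, which freezes the then-current values of `u`, `v`, and `diag_update` into constant tensors, so a later `tf.GradientTape` sees no dependence on the variables. Calling `_make_capacitance()` at use time re-reads the variables under the tape. The anti-pattern in miniature (illustrative only):

    import tensorflow as tf

    v = tf.Variable(2.)

    frozen = tf.convert_to_tensor(v)   # value captured once, "at __init__"
    with tf.GradientTape() as tape:
      loss = frozen * frozen           # no recorded path back to v
    print(tape.gradient(loss, v))      # None

    with tf.GradientTape() as tape:
      loss = tf.convert_to_tensor(v) * tf.convert_to_tensor(v)  # re-read under the tape
    print(tape.gradient(loss, v))      # tf.Tensor(4.0, shape=(), dtype=float32)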
@@ -291,8 +285,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     if diag_update is not None:
       self._diag_operator = linear_operator_diag.LinearOperatorDiag(
           self._diag_update, is_positive_definite=is_diag_update_positive)
-      self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
-          1. / self._diag_update, is_positive_definite=is_diag_update_positive)
     else:
       if tensor_shape.dimension_value(self.u.shape[-1]) is not None:
         r = tensor_shape.dimension_value(self.u.shape[-1])
@@ -300,7 +292,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
         r = array_ops.shape(self.u)[-1]
       self._diag_operator = linear_operator_identity.LinearOperatorIdentity(
           num_rows=r, dtype=self.dtype)
-      self._diag_inv_operator = self._diag_operator

   @property
   def u(self):
@@ -373,7 +364,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     #   = det(C) det(D) det(L)
     # where C is sometimes known as the capacitance matrix,
     #   C := D^{-1} + V^H L^{-1} U
-    det_c = linalg_ops.matrix_determinant(self._capacitance)
+    det_c = linalg_ops.matrix_determinant(self._make_capacitance())
     det_d = self.diag_operator.determinant()
     det_l = self.base_operator.determinant()
     return det_c * det_d * det_l
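For reference, the computation above is the (generalized) matrix determinant lemma for M = L + U D V^H, written with the capacitance matrix C:

    \det(M) = \det(L + U D V^{H}) = \det(C)\,\det(D)\,\det(L),
    \qquad C := D^{-1} + V^{H} L^{-1} U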
@@ -386,11 +377,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     log_abs_det_l = self.base_operator.log_abs_determinant()

     if self._use_cholesky:
-      chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance)
+      chol_cap_diag = array_ops.matrix_diag_part(
+          linalg_ops.cholesky(self._make_capacitance()))
       log_abs_det_c = 2 * math_ops.reduce_sum(
           math_ops.log(chol_cap_diag), axis=[-1])
     else:
-      det_c = linalg_ops.matrix_determinant(self._capacitance)
+      det_c = linalg_ops.matrix_determinant(self._make_capacitance())
       log_abs_det_c = math_ops.log(math_ops.abs(det_c))
       if self.dtype.is_complex:
         log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)
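The Cholesky branch relies on the standard log-determinant identity for a positive definite C with factorization C = R R^H (R triangular with positive real diagonal):

    \log\lvert\det(C)\rvert = 2 \sum_{i} \log R_{ii}

which is exactly the `2 * reduce_sum(log(...))` over the diagonal of the factor above.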
@@ -426,10 +418,10 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     # C^{-1} V^H L^{-1} rhs
     if self._use_cholesky:
       capinv_vh_linv_rhs = linear_operator_util.cholesky_solve_with_broadcast(
-          self._chol_capacitance, vh_linv_rhs)
+          linalg_ops.cholesky(self._make_capacitance()), vh_linv_rhs)
     else:
       capinv_vh_linv_rhs = linear_operator_util.matrix_solve_with_broadcast(
-          self._capacitance, vh_linv_rhs, adjoint=adjoint)
+          self._make_capacitance(), vh_linv_rhs, adjoint=adjoint)
     # U C^{-1} V^H L^{-1} rhs
     u_capinv_vh_linv_rhs = math_ops.matmul(u, capinv_vh_linv_rhs)
     # L^{-1} U C^{-1} V^H L^{-1} rhs
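This is the Woodbury identity applied to rhs; the subtraction that completes it happens in the unchanged lines just after this hunk:

    (L + U D V^{H})^{-1}\,\mathrm{rhs}
      = L^{-1}\,\mathrm{rhs} - L^{-1} U C^{-1} V^{H} L^{-1}\,\mathrm{rhs}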
@@ -448,5 +440,5 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
     vh_linv_u = math_ops.matmul(self.v, linv_u, adjoint_a=True)
     # D^{-1} + V^H L^{-1} U
-    capacitance = self._diag_inv_operator.add_to_tensor(vh_linv_u)
+    capacitance = self._diag_operator.inverse().add_to_tensor(vh_linv_u)
     return capacitance
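
A quick NumPy check of the two identities the operator relies on (shapes and values are arbitrary; this is independent of the TensorFlow code above):

    import numpy as np

    rng = np.random.default_rng(0)
    n, r = 4, 2
    L = np.diag(rng.uniform(1., 2., n))   # base operator, here diagonal and PD
    U = rng.standard_normal((n, r))
    V = rng.standard_normal((n, r))
    D = np.diag(rng.uniform(1., 2., r))

    M = L + U @ D @ V.T
    C = np.linalg.inv(D) + V.T @ np.linalg.solve(L, U)   # capacitance

    # Matrix determinant lemma: det(M) = det(C) det(D) det(L)
    assert np.isclose(np.linalg.det(M),
                      np.linalg.det(C) * np.linalg.det(D) * np.linalg.det(L))

    # Woodbury: M^{-1} rhs = L^{-1} rhs - L^{-1} U C^{-1} V^T L^{-1} rhs
    rhs = rng.standard_normal((n, 1))
    linv_rhs = np.linalg.solve(L, rhs)
    woodbury = linv_rhs - np.linalg.solve(L, U @ np.linalg.solve(C, V.T @ linv_rhs))
    assert np.allclose(np.linalg.solve(M, rhs), woodbury)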