Get around deprecation warning in setting of graph_parents in LinearOperator
subclasses. PiperOrigin-RevId: 278713002 Change-Id: Ia39f86a0395242b954eaee268df1d2591cf0827e
This commit is contained in:
parent
735aff9256
commit
804ce50226
@ -148,6 +148,7 @@ class LinearOperator(module.Module):
|
||||
way.
|
||||
"""
|
||||
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
@deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
|
||||
" no longer be used.", "graph_parents")
|
||||
def __init__(self,
|
||||
@ -201,13 +202,11 @@ class LinearOperator(module.Module):
|
||||
|
||||
self._is_square_set_or_implied_by_hints = is_square
|
||||
|
||||
graph_parents = [] if graph_parents is None else graph_parents
|
||||
for i, t in enumerate(graph_parents):
|
||||
if t is None or not (linear_operator_util.is_ref(t) or
|
||||
tensor_util.is_tensor(t)):
|
||||
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
|
||||
if graph_parents is not None:
|
||||
self._set_graph_parents(graph_parents)
|
||||
else:
|
||||
self._graph_parents = []
|
||||
self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
|
||||
self._graph_parents = graph_parents
|
||||
self._is_non_singular = is_non_singular
|
||||
self._is_self_adjoint = is_self_adjoint
|
||||
self._is_positive_definite = is_positive_definite
|
||||
@ -1077,6 +1076,24 @@ class LinearOperator(module.Module):
|
||||
def _can_use_cholesky(self):
|
||||
return self.is_self_adjoint and self.is_positive_definite
|
||||
|
||||
def _set_graph_parents(self, graph_parents):
|
||||
"""Set self._graph_parents. Called during derived class init.
|
||||
|
||||
This method allows derived classes to set graph_parents, without triggering
|
||||
a deprecation warning (which is invoked if `graph_parents` is passed during
|
||||
`__init__`.
|
||||
|
||||
Args:
|
||||
graph_parents: Iterable over Tensors.
|
||||
"""
|
||||
# TODO(b/143910018) Remove this function in V3.
|
||||
graph_parents = [] if graph_parents is None else graph_parents
|
||||
for i, t in enumerate(graph_parents):
|
||||
if t is None or not (linear_operator_util.is_ref(t) or
|
||||
tensor_util.is_tensor(t)):
|
||||
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
|
||||
self._graph_parents = graph_parents
|
||||
|
||||
|
||||
# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
|
||||
# place of a Tensor.
|
||||
|
@ -145,12 +145,14 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=operator.graph_parents):
|
||||
super(LinearOperatorAdjoint, self).__init__(
|
||||
dtype=operator.dtype,
|
||||
graph_parents=operator.graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents(operator.graph_parents)
|
||||
|
||||
@property
|
||||
def operator(self):
|
||||
|
@ -204,13 +204,16 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=graph_parents):
|
||||
super(LinearOperatorBlockDiag, self).__init__(
|
||||
dtype=dtype,
|
||||
graph_parents=graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=True,
|
||||
name=name)
|
||||
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents(graph_parents)
|
||||
|
||||
@property
|
||||
def operators(self):
|
||||
return self._operators
|
||||
|
@ -116,12 +116,14 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
super(_BaseLinearOperatorCirculant, self).__init__(
|
||||
dtype=dtypes.as_dtype(input_output_dtype),
|
||||
graph_parents=[self.spectrum],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents([self.spectrum])
|
||||
|
||||
def _check_spectrum_and_return_tensor(self, spectrum):
|
||||
"""Static check of spectrum. Then return `Tensor` version."""
|
||||
|
@ -177,12 +177,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=graph_parents):
|
||||
super(LinearOperatorComposition, self).__init__(
|
||||
dtype=dtype,
|
||||
graph_parents=graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents(graph_parents)
|
||||
|
||||
@property
|
||||
def operators(self):
|
||||
|
@ -158,12 +158,14 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorDiag, self).__init__(
|
||||
dtype=self._diag.dtype,
|
||||
graph_parents=[self._diag],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents([self._diag])
|
||||
|
||||
def _check_diag(self, diag):
|
||||
"""Static check of diag."""
|
||||
|
@ -141,12 +141,14 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorFullMatrix, self).__init__(
|
||||
dtype=self._matrix.dtype,
|
||||
graph_parents=[self._matrix],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents([self._matrix])
|
||||
|
||||
def _check_matrix(self, matrix):
|
||||
"""Static check of the `matrix` argument."""
|
||||
|
@ -146,12 +146,14 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorHouseholder, self).__init__(
|
||||
dtype=self._reflection_axis.dtype,
|
||||
graph_parents=[self._reflection_axis],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents([self._reflection_axis])
|
||||
|
||||
def _check_reflection_axis(self, reflection_axis):
|
||||
"""Static check of reflection_axis."""
|
||||
|
@ -158,12 +158,14 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=operator.graph_parents):
|
||||
super(LinearOperatorInversion, self).__init__(
|
||||
dtype=operator.dtype,
|
||||
graph_parents=operator.graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents(operator.graph_parents)
|
||||
|
||||
@property
|
||||
def operator(self):
|
||||
|
@ -221,12 +221,14 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
|
||||
with ops.name_scope(name, values=graph_parents):
|
||||
super(LinearOperatorKronecker, self).__init__(
|
||||
dtype=dtype,
|
||||
graph_parents=graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
# TODO(b/143910018) Remove graph_parents in V3.
|
||||
self._set_graph_parents(graph_parents)
|
||||
|
||||
@property
|
||||
def operators(self):
|
||||
|
@ -248,12 +248,13 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorLowRankUpdate, self).__init__(
|
||||
dtype=self._base_operator.dtype,
|
||||
graph_parents=graph_parents,
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
self._set_graph_parents(graph_parents)
|
||||
|
||||
# Create the diagonal operator D.
|
||||
self._set_diag_operators(diag_update, is_diag_update_positive)
|
||||
|
@ -150,12 +150,13 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorLowerTriangular, self).__init__(
|
||||
dtype=self._tril.dtype,
|
||||
graph_parents=[self._tril],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
self._set_graph_parents([self._tril])
|
||||
|
||||
def _check_tril(self, tril):
|
||||
"""Static check of the `tril` argument."""
|
||||
|
@ -150,12 +150,13 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
||||
|
||||
super(LinearOperatorToeplitz, self).__init__(
|
||||
dtype=self._row.dtype,
|
||||
graph_parents=[self._row, self._col],
|
||||
graph_parents=None,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
self._set_graph_parents([self._row, self._col])
|
||||
|
||||
def _check_row_col(self, row, col):
|
||||
"""Static check of row and column."""
|
||||
|
Loading…
Reference in New Issue
Block a user