Get around deprecation warning in setting of graph_parents in LinearOperator

subclasses.

PiperOrigin-RevId: 278713002
Change-Id: Ia39f86a0395242b954eaee268df1d2591cf0827e
This commit is contained in:
Ian Langmore 2019-11-05 14:47:51 -08:00 committed by TensorFlower Gardener
parent 735aff9256
commit 804ce50226
13 changed files with 57 additions and 18 deletions

View File

@ -148,6 +148,7 @@ class LinearOperator(module.Module):
way.
"""
# TODO(b/143910018) Remove graph_parents in V3.
@deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will "
" no longer be used.", "graph_parents")
def __init__(self,
@ -201,13 +202,11 @@ class LinearOperator(module.Module):
self._is_square_set_or_implied_by_hints = is_square
graph_parents = [] if graph_parents is None else graph_parents
for i, t in enumerate(graph_parents):
if t is None or not (linear_operator_util.is_ref(t) or
tensor_util.is_tensor(t)):
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
if graph_parents is not None:
self._set_graph_parents(graph_parents)
else:
self._graph_parents = []
self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
self._graph_parents = graph_parents
self._is_non_singular = is_non_singular
self._is_self_adjoint = is_self_adjoint
self._is_positive_definite = is_positive_definite
@ -1077,6 +1076,24 @@ class LinearOperator(module.Module):
def _can_use_cholesky(self):
return self.is_self_adjoint and self.is_positive_definite
def _set_graph_parents(self, graph_parents):
"""Set self._graph_parents. Called during derived class init.
This method allows derived classes to set graph_parents, without triggering
a deprecation warning (which is invoked if `graph_parents` is passed during
`__init__`.
Args:
graph_parents: Iterable over Tensors.
"""
# TODO(b/143910018) Remove this function in V3.
graph_parents = [] if graph_parents is None else graph_parents
for i, t in enumerate(graph_parents):
if t is None or not (linear_operator_util.is_ref(t) or
tensor_util.is_tensor(t)):
raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
self._graph_parents = graph_parents
# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
# place of a Tensor.

View File

@ -145,12 +145,14 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
with ops.name_scope(name, values=operator.graph_parents):
super(LinearOperatorAdjoint, self).__init__(
dtype=operator.dtype,
graph_parents=operator.graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(operator.graph_parents)
@property
def operator(self):

View File

@ -204,13 +204,16 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
with ops.name_scope(name, values=graph_parents):
super(LinearOperatorBlockDiag, self).__init__(
dtype=dtype,
graph_parents=graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=True,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(graph_parents)
@property
def operators(self):
return self._operators

View File

@ -116,12 +116,14 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
super(_BaseLinearOperatorCirculant, self).__init__(
dtype=dtypes.as_dtype(input_output_dtype),
graph_parents=[self.spectrum],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self.spectrum])
def _check_spectrum_and_return_tensor(self, spectrum):
"""Static check of spectrum. Then return `Tensor` version."""

View File

@ -177,12 +177,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
with ops.name_scope(name, values=graph_parents):
super(LinearOperatorComposition, self).__init__(
dtype=dtype,
graph_parents=graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(graph_parents)
@property
def operators(self):

View File

@ -158,12 +158,14 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
super(LinearOperatorDiag, self).__init__(
dtype=self._diag.dtype,
graph_parents=[self._diag],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._diag])
def _check_diag(self, diag):
"""Static check of diag."""

View File

@ -141,12 +141,14 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
super(LinearOperatorFullMatrix, self).__init__(
dtype=self._matrix.dtype,
graph_parents=[self._matrix],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._matrix])
def _check_matrix(self, matrix):
"""Static check of the `matrix` argument."""

View File

@ -146,12 +146,14 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
super(LinearOperatorHouseholder, self).__init__(
dtype=self._reflection_axis.dtype,
graph_parents=[self._reflection_axis],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._reflection_axis])
def _check_reflection_axis(self, reflection_axis):
"""Static check of reflection_axis."""

View File

@ -158,12 +158,14 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
with ops.name_scope(name, values=operator.graph_parents):
super(LinearOperatorInversion, self).__init__(
dtype=operator.dtype,
graph_parents=operator.graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(operator.graph_parents)
@property
def operator(self):

View File

@ -221,12 +221,14 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
with ops.name_scope(name, values=graph_parents):
super(LinearOperatorKronecker, self).__init__(
dtype=dtype,
graph_parents=graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(graph_parents)
@property
def operators(self):

View File

@ -248,12 +248,13 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
super(LinearOperatorLowRankUpdate, self).__init__(
dtype=self._base_operator.dtype,
graph_parents=graph_parents,
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
self._set_graph_parents(graph_parents)
# Create the diagonal operator D.
self._set_diag_operators(diag_update, is_diag_update_positive)

View File

@ -150,12 +150,13 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
super(LinearOperatorLowerTriangular, self).__init__(
dtype=self._tril.dtype,
graph_parents=[self._tril],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
self._set_graph_parents([self._tril])
def _check_tril(self, tril):
"""Static check of the `tril` argument."""

View File

@ -150,12 +150,13 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
super(LinearOperatorToeplitz, self).__init__(
dtype=self._row.dtype,
graph_parents=[self._row, self._col],
graph_parents=None,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
self._set_graph_parents([self._row, self._col])
def _check_row_col(self, row, col):
"""Static check of row and column."""