diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 0e44fceeaf3..205c16e5197 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -148,6 +148,7 @@ class LinearOperator(module.Module): way. """ + # TODO(b/143910018) Remove graph_parents in V3. @deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will " " no longer be used.", "graph_parents") def __init__(self, @@ -201,13 +202,11 @@ class LinearOperator(module.Module): self._is_square_set_or_implied_by_hints = is_square - graph_parents = [] if graph_parents is None else graph_parents - for i, t in enumerate(graph_parents): - if t is None or not (linear_operator_util.is_ref(t) or - tensor_util.is_tensor(t)): - raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) + if graph_parents is not None: + self._set_graph_parents(graph_parents) + else: + self._graph_parents = [] self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype - self._graph_parents = graph_parents self._is_non_singular = is_non_singular self._is_self_adjoint = is_self_adjoint self._is_positive_definite = is_positive_definite @@ -1077,6 +1076,24 @@ class LinearOperator(module.Module): def _can_use_cholesky(self): return self.is_self_adjoint and self.is_positive_definite + def _set_graph_parents(self, graph_parents): + """Set self._graph_parents. Called during derived class init. + + This method allows derived classes to set graph_parents, without triggering + a deprecation warning (which is invoked if `graph_parents` is passed during + `__init__`. + + Args: + graph_parents: Iterable over Tensors. + """ + # TODO(b/143910018) Remove this function in V3. + graph_parents = [] if graph_parents is None else graph_parents + for i, t in enumerate(graph_parents): + if t is None or not (linear_operator_util.is_ref(t) or + tensor_util.is_tensor(t)): + raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) + self._graph_parents = graph_parents + # Overrides for tf.linalg functions. This allows a LinearOperator to be used in # place of a Tensor. diff --git a/tensorflow/python/ops/linalg/linear_operator_adjoint.py b/tensorflow/python/ops/linalg/linear_operator_adjoint.py index 803fbe9e903..eb5af872773 100644 --- a/tensorflow/python/ops/linalg/linear_operator_adjoint.py +++ b/tensorflow/python/ops/linalg/linear_operator_adjoint.py @@ -145,12 +145,14 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator): with ops.name_scope(name, values=operator.graph_parents): super(LinearOperatorAdjoint, self).__init__( dtype=operator.dtype, - graph_parents=operator.graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents(operator.graph_parents) @property def operator(self): diff --git a/tensorflow/python/ops/linalg/linear_operator_block_diag.py b/tensorflow/python/ops/linalg/linear_operator_block_diag.py index 5744420db64..8b4ab0dc5e5 100644 --- a/tensorflow/python/ops/linalg/linear_operator_block_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_block_diag.py @@ -204,13 +204,16 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator): with ops.name_scope(name, values=graph_parents): super(LinearOperatorBlockDiag, self).__init__( dtype=dtype, - graph_parents=graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=True, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents(graph_parents) + @property def operators(self): return self._operators diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py index 8d0e1d88b38..847eda9f7d5 100644 --- a/tensorflow/python/ops/linalg/linear_operator_circulant.py +++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py @@ -116,12 +116,14 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator): super(_BaseLinearOperatorCirculant, self).__init__( dtype=dtypes.as_dtype(input_output_dtype), - graph_parents=[self.spectrum], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents([self.spectrum]) def _check_spectrum_and_return_tensor(self, spectrum): """Static check of spectrum. Then return `Tensor` version.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_composition.py b/tensorflow/python/ops/linalg/linear_operator_composition.py index 7e6f79e830a..00ef86d5aba 100644 --- a/tensorflow/python/ops/linalg/linear_operator_composition.py +++ b/tensorflow/python/ops/linalg/linear_operator_composition.py @@ -177,12 +177,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator): with ops.name_scope(name, values=graph_parents): super(LinearOperatorComposition, self).__init__( dtype=dtype, - graph_parents=graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents(graph_parents) @property def operators(self): diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index 01cd796c5e6..3d2a47a05de 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -158,12 +158,14 @@ class LinearOperatorDiag(linear_operator.LinearOperator): super(LinearOperatorDiag, self).__init__( dtype=self._diag.dtype, - graph_parents=[self._diag], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents([self._diag]) def _check_diag(self, diag): """Static check of diag.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py index 15e8fb6fdcf..8fe68919250 100644 --- a/tensorflow/python/ops/linalg/linear_operator_full_matrix.py +++ b/tensorflow/python/ops/linalg/linear_operator_full_matrix.py @@ -141,12 +141,14 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): super(LinearOperatorFullMatrix, self).__init__( dtype=self._matrix.dtype, - graph_parents=[self._matrix], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents([self._matrix]) def _check_matrix(self, matrix): """Static check of the `matrix` argument.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_householder.py b/tensorflow/python/ops/linalg/linear_operator_householder.py index 1c5b9abb8b0..03a5e560a75 100644 --- a/tensorflow/python/ops/linalg/linear_operator_householder.py +++ b/tensorflow/python/ops/linalg/linear_operator_householder.py @@ -146,12 +146,14 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator): super(LinearOperatorHouseholder, self).__init__( dtype=self._reflection_axis.dtype, - graph_parents=[self._reflection_axis], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents([self._reflection_axis]) def _check_reflection_axis(self, reflection_axis): """Static check of reflection_axis.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_inversion.py b/tensorflow/python/ops/linalg/linear_operator_inversion.py index 8518bdb3d3b..deb0ed6f92b 100644 --- a/tensorflow/python/ops/linalg/linear_operator_inversion.py +++ b/tensorflow/python/ops/linalg/linear_operator_inversion.py @@ -158,12 +158,14 @@ class LinearOperatorInversion(linear_operator.LinearOperator): with ops.name_scope(name, values=operator.graph_parents): super(LinearOperatorInversion, self).__init__( dtype=operator.dtype, - graph_parents=operator.graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents(operator.graph_parents) @property def operator(self): diff --git a/tensorflow/python/ops/linalg/linear_operator_kronecker.py b/tensorflow/python/ops/linalg/linear_operator_kronecker.py index c86facfd007..1fe68885bfe 100644 --- a/tensorflow/python/ops/linalg/linear_operator_kronecker.py +++ b/tensorflow/python/ops/linalg/linear_operator_kronecker.py @@ -221,12 +221,14 @@ class LinearOperatorKronecker(linear_operator.LinearOperator): with ops.name_scope(name, values=graph_parents): super(LinearOperatorKronecker, self).__init__( dtype=dtype, - graph_parents=graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + # TODO(b/143910018) Remove graph_parents in V3. + self._set_graph_parents(graph_parents) @property def operators(self): diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index f4c75c1b9fb..019adc052ae 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -248,12 +248,13 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): super(LinearOperatorLowRankUpdate, self).__init__( dtype=self._base_operator.dtype, - graph_parents=graph_parents, + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + self._set_graph_parents(graph_parents) # Create the diagonal operator D. self._set_diag_operators(diag_update, is_diag_update_positive) diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index 50a4b681534..37695b5323c 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -150,12 +150,13 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): super(LinearOperatorLowerTriangular, self).__init__( dtype=self._tril.dtype, - graph_parents=[self._tril], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + self._set_graph_parents([self._tril]) def _check_tril(self, tril): """Static check of the `tril` argument.""" diff --git a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py index be95ce4beec..71fff44da44 100644 --- a/tensorflow/python/ops/linalg/linear_operator_toeplitz.py +++ b/tensorflow/python/ops/linalg/linear_operator_toeplitz.py @@ -150,12 +150,13 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator): super(LinearOperatorToeplitz, self).__init__( dtype=self._row.dtype, - graph_parents=[self._row, self._col], + graph_parents=None, is_non_singular=is_non_singular, is_self_adjoint=is_self_adjoint, is_positive_definite=is_positive_definite, is_square=is_square, name=name) + self._set_graph_parents([self._row, self._col]) def _check_row_col(self, row, col): """Static check of row and column."""