Suppress 'conversion to a dense matrix' warning from LinearOperatorFullMatrix.solve().

The current warning is inappropriate: since a LinearOperatorFullMatrix is inherently dense, no efficiency is lost when we treat it as dense.

PiperOrigin-RevId: 296305093
Change-Id: Id3b7e2a00f05d1e516374c4241cd84529844a056
This commit is contained in:
A. Unique TensorFlower 2020-02-20 14:59:42 -08:00 committed by TensorFlower Gardener
parent 3ba8bd697f
commit f120f7d514
2 changed files with 13 additions and 6 deletions

View File

@ -751,14 +751,11 @@ class LinearOperator(module.Module):
with self._name_scope(name):
return self._log_abs_determinant()
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
"""Default implementation of _solve."""
if self.is_square is False:
def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
"""Solve by conversion to a dense matrix."""
if self.is_square is False: # pylint: disable=g-bool-id-comparison
raise NotImplementedError(
"Solve is not yet implemented for non-square operators.")
logging.warn(
"Using (possibly slow) default implementation of solve."
" Requires conversion to a dense matrix and O(N^3) operations.")
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
if self._can_use_cholesky():
return linalg_ops.cholesky_solve(
@ -766,6 +763,13 @@ class LinearOperator(module.Module):
return linear_operator_util.matrix_solve_with_broadcast(
self.to_dense(), rhs, adjoint=adjoint)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
"""Default implementation of _solve."""
logging.warn(
"Using (possibly slow) default implementation of solve."
" Requires conversion to a dense matrix and O(N^3) operations.")
return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

View File

@ -183,5 +183,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
return math_ops.matmul(
self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
def _to_dense(self):
return self._matrix