Suppress 'conversion to a dense matrix' warning from LinearOperatorFullMatrix.solve().
The current warning is inappropriate: since a LinearOperatorFullMatrix is inherently dense, no efficiency is lost when we treat it as dense. PiperOrigin-RevId: 296305093 Change-Id: Id3b7e2a00f05d1e516374c4241cd84529844a056
This commit is contained in:
parent
3ba8bd697f
commit
f120f7d514
@ -751,14 +751,11 @@ class LinearOperator(module.Module):
|
|||||||
with self._name_scope(name):
|
with self._name_scope(name):
|
||||||
return self._log_abs_determinant()
|
return self._log_abs_determinant()
|
||||||
|
|
||||||
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||||
"""Default implementation of _solve."""
|
"""Solve by conversion to a dense matrix."""
|
||||||
if self.is_square is False:
|
if self.is_square is False: # pylint: disable=g-bool-id-comparison
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Solve is not yet implemented for non-square operators.")
|
"Solve is not yet implemented for non-square operators.")
|
||||||
logging.warn(
|
|
||||||
"Using (possibly slow) default implementation of solve."
|
|
||||||
" Requires conversion to a dense matrix and O(N^3) operations.")
|
|
||||||
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
|
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
|
||||||
if self._can_use_cholesky():
|
if self._can_use_cholesky():
|
||||||
return linalg_ops.cholesky_solve(
|
return linalg_ops.cholesky_solve(
|
||||||
@ -766,6 +763,13 @@ class LinearOperator(module.Module):
|
|||||||
return linear_operator_util.matrix_solve_with_broadcast(
|
return linear_operator_util.matrix_solve_with_broadcast(
|
||||||
self.to_dense(), rhs, adjoint=adjoint)
|
self.to_dense(), rhs, adjoint=adjoint)
|
||||||
|
|
||||||
|
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||||
|
"""Default implementation of _solve."""
|
||||||
|
logging.warn(
|
||||||
|
"Using (possibly slow) default implementation of solve."
|
||||||
|
" Requires conversion to a dense matrix and O(N^3) operations.")
|
||||||
|
return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||||
|
|
||||||
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
|
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
|
||||||
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
|
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
|
||||||
|
|
||||||
|
@ -183,5 +183,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
|||||||
return math_ops.matmul(
|
return math_ops.matmul(
|
||||||
self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
|
self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
|
||||||
|
|
||||||
|
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||||
|
return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||||
|
|
||||||
def _to_dense(self):
|
def _to_dense(self):
|
||||||
return self._matrix
|
return self._matrix
|
||||||
|
Loading…
Reference in New Issue
Block a user