LinearOperatorAdjoint added as a meta-class. Computes the adjoint of
an operator by flipping the adjoint arg in .matmul and .solve, and other appropriate modifications. This is hidden from the public API, with an upcoming method, .adjoint, the preferred way to get an adjoint, since it has more operator specific information. PiperOrigin-RevId: 221869871
This commit is contained in:
parent
2121365e74
commit
de8ebdb12f
@ -40,6 +40,28 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "linear_operator_adjoint_test",
|
||||
size = "medium",
|
||||
srcs = ["linear_operator_adjoint_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"noasan", # times out, b/63678675
|
||||
"optonly", # times out
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "linear_operator_algebra_test",
|
||||
size = "small",
|
||||
|
@ -0,0 +1,118 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.ops.linalg import linear_operator_adjoint
|
||||
from tensorflow.python.ops.linalg import linear_operator_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
linalg = linalg_lib
|
||||
|
||||
LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LinearOperatorAdjointTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
|
||||
def setUp(self):
|
||||
self._atol[dtypes.complex64] = 1e-5
|
||||
self._rtol[dtypes.complex64] = 1e-5
|
||||
|
||||
def _operator_and_matrix(self,
|
||||
build_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
shape = list(build_info.shape)
|
||||
|
||||
if ensure_self_adjoint_and_pd:
|
||||
matrix = linear_operator_test_util.random_positive_definite_matrix(
|
||||
shape, dtype, force_well_conditioned=True)
|
||||
else:
|
||||
matrix = linear_operator_test_util.random_tril_matrix(
|
||||
shape, dtype, force_well_conditioned=True, remove_upper=True)
|
||||
|
||||
lin_op_matrix = matrix
|
||||
|
||||
if use_placeholder:
|
||||
lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
|
||||
|
||||
if ensure_self_adjoint_and_pd:
|
||||
operator = LinearOperatorAdjoint(
|
||||
linalg.LinearOperatorFullMatrix(
|
||||
lin_op_matrix, is_positive_definite=True, is_self_adjoint=True))
|
||||
else:
|
||||
operator = LinearOperatorAdjoint(
|
||||
linalg.LinearOperatorLowerTriangular(lin_op_matrix))
|
||||
|
||||
return operator, linalg.adjoint(matrix)
|
||||
|
||||
def test_base_operator_hint_used(self):
|
||||
# The matrix values do not effect auto-setting of the flags.
|
||||
matrix = [[1., 0.], [1., 1.]]
|
||||
operator = linalg.LinearOperatorFullMatrix(
|
||||
matrix,
|
||||
is_positive_definite=True,
|
||||
is_non_singular=True,
|
||||
is_self_adjoint=False)
|
||||
operator_adjoint = LinearOperatorAdjoint(operator)
|
||||
self.assertTrue(operator_adjoint.is_positive_definite)
|
||||
self.assertTrue(operator_adjoint.is_non_singular)
|
||||
self.assertFalse(operator_adjoint.is_self_adjoint)
|
||||
|
||||
def test_supplied_hint_used(self):
|
||||
# The matrix values do not effect auto-setting of the flags.
|
||||
matrix = [[1., 0.], [1., 1.]]
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||
operator_adjoint = LinearOperatorAdjoint(
|
||||
operator,
|
||||
is_positive_definite=True,
|
||||
is_non_singular=True,
|
||||
is_self_adjoint=False)
|
||||
self.assertTrue(operator_adjoint.is_positive_definite)
|
||||
self.assertTrue(operator_adjoint.is_non_singular)
|
||||
self.assertFalse(operator_adjoint.is_self_adjoint)
|
||||
|
||||
def test_contradicting_hints_raise(self):
|
||||
# The matrix values do not effect auto-setting of the flags.
|
||||
matrix = [[1., 0.], [1., 1.]]
|
||||
operator = linalg.LinearOperatorFullMatrix(
|
||||
matrix, is_positive_definite=False)
|
||||
with self.assertRaisesRegexp(ValueError, "positive-definite"):
|
||||
LinearOperatorAdjoint(operator, is_positive_definite=True)
|
||||
|
||||
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=False)
|
||||
with self.assertRaisesRegexp(ValueError, "self-adjoint"):
|
||||
LinearOperatorAdjoint(operator, is_self_adjoint=True)
|
||||
|
||||
def test_name(self):
|
||||
matrix = [[11., 0.], [1., 8.]]
|
||||
operator = linalg.LinearOperatorFullMatrix(
|
||||
matrix, name="my_operator", is_non_singular=True)
|
||||
|
||||
operator = LinearOperatorAdjoint(operator)
|
||||
|
||||
self.assertEqual("my_operator_adjoint", operator.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
207
tensorflow/python/ops/linalg/linear_operator_adjoint.py
Normal file
207
tensorflow/python/ops/linalg/linear_operator_adjoint.py
Normal file
@ -0,0 +1,207 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Takes the adjoint of a `LinearOperator`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
@tf_export("linalg.LinearOperatorAdjoint")
|
||||
class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
||||
"""`LinearOperator` representing the adjoint of another operator.
|
||||
|
||||
This operator represents the adjoint of another operator.
|
||||
|
||||
```python
|
||||
# Create a 2 x 2 linear operator.
|
||||
operator = LinearOperatorFullMatrix([[1 - i., 3.], [0., 1. + i]])
|
||||
operator_adjoint = LinearOperatorAdjoint(operator)
|
||||
|
||||
operator_adjoint.to_dense()
|
||||
==> [[1. + i, 0.]
|
||||
[3., 1 - i]]
|
||||
|
||||
operator_adjoint.shape
|
||||
==> [2, 2]
|
||||
|
||||
operator_adjoint.log_abs_determinant()
|
||||
==> - log(2)
|
||||
|
||||
x = ... Shape [2, 4] Tensor
|
||||
operator_adjoint.matmul(x)
|
||||
==> Shape [2, 4] Tensor, equal to operator.matmul(x, adjoint=True)
|
||||
```
|
||||
|
||||
#### Performance
|
||||
|
||||
The performance of `LinearOperatorAdjoint` depends on the underlying
|
||||
operators performance.
|
||||
|
||||
#### Matrix property hints
|
||||
|
||||
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
|
||||
for `X = non_singular, self_adjoint, positive_definite, square`.
|
||||
These have the following meaning:
|
||||
|
||||
* If `is_X == True`, callers should expect the operator to have the
|
||||
property `X`. This is a promise that should be fulfilled, but is *not* a
|
||||
runtime assert. For example, finite floating point precision may result
|
||||
in these promises being violated.
|
||||
* If `is_X == False`, callers should expect the operator to not have `X`.
|
||||
* If `is_X == None` (the default), callers should have no expectation either
|
||||
way.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
operator,
|
||||
is_non_singular=None,
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
is_square=None,
|
||||
name=None):
|
||||
r"""Initialize a `LinearOperatorAdjoint`.
|
||||
|
||||
`LinearOperatorAdjoint` is initialized with an operator `A`. The `solve`
|
||||
and `matmul` methods effectively flip the `adjoint` argument. E.g.
|
||||
|
||||
```
|
||||
A = MyLinearOperator(...)
|
||||
B = LinearOperatorAdjoint(A)
|
||||
x = [....] # a vector
|
||||
|
||||
assert A.matvec(x, adjoint=True) == B.matvec(x, adjoint=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
operator: `LinearOperator` object.
|
||||
is_non_singular: Expect that this operator is non-singular.
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
|
||||
is_square: Expect that this operator acts like square [batch] matrices.
|
||||
name: A name for this `LinearOperator`. Default is `operator.name +
|
||||
"_adjoint"`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `operator.is_non_singular` is False.
|
||||
"""
|
||||
|
||||
self._operator = operator
|
||||
|
||||
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
||||
# base operator.
|
||||
def _combined_hint(hint_str, provided_hint_value, message):
|
||||
"""Get combined hint in the case where operator.hint should equal hint."""
|
||||
op_hint = getattr(operator, hint_str)
|
||||
if op_hint is False and provided_hint_value:
|
||||
raise ValueError(message)
|
||||
if op_hint and provided_hint_value is False:
|
||||
raise ValueError(message)
|
||||
return (op_hint or provided_hint_value) or None
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
"An operator is square if and only if its adjoint is square.")
|
||||
|
||||
is_non_singular = _combined_hint(
|
||||
"is_non_singular", is_non_singular,
|
||||
"An operator is non-singular if and only if its adjoint is "
|
||||
"non-singular.")
|
||||
|
||||
is_self_adjoint = _combined_hint(
|
||||
"is_self_adjoint", is_self_adjoint,
|
||||
"An operator is self-adjoint if and only if its adjoint is "
|
||||
"self-adjoint.")
|
||||
|
||||
is_positive_definite = _combined_hint(
|
||||
"is_positive_definite", is_positive_definite,
|
||||
"An operator is positive-definite if and only if its adjoint is "
|
||||
"positive-definite.")
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
"An operator is square if and only if its adjoint is square.")
|
||||
|
||||
# Initialization.
|
||||
if name is None:
|
||||
name = operator.name + "_adjoint"
|
||||
with ops.name_scope(name, values=operator.graph_parents):
|
||||
super(LinearOperatorAdjoint, self).__init__(
|
||||
dtype=operator.dtype,
|
||||
graph_parents=operator.graph_parents,
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite,
|
||||
is_square=is_square,
|
||||
name=name)
|
||||
|
||||
@property
|
||||
def operator(self):
|
||||
"""The operator before taking the adjoint."""
|
||||
return self._operator
|
||||
|
||||
def _assert_non_singular(self):
|
||||
return self.operator.assert_non_singular()
|
||||
|
||||
def _assert_positive_definite(self):
|
||||
return self.operator.assert_positive_definite()
|
||||
|
||||
def _assert_self_adjoint(self):
|
||||
return self.operator.assert_self_adjoint()
|
||||
|
||||
def _shape(self):
|
||||
return self.operator.shape
|
||||
|
||||
def _shape_tensor(self):
|
||||
return self.operator.shape_tensor()
|
||||
|
||||
def _matmul(self, x, adjoint=False, adjoint_arg=False):
|
||||
return self.operator.matmul(
|
||||
x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
|
||||
|
||||
def _determinant(self):
|
||||
if self.is_self_adjoint:
|
||||
return self.operator.determinant()
|
||||
return math_ops.conj(self.operator.determinant())
|
||||
|
||||
def _log_abs_determinant(self):
|
||||
return self.operator.log_abs_determinant()
|
||||
|
||||
def _trace(self):
|
||||
if self.is_self_adjoint:
|
||||
return self.operator.trace()
|
||||
return math_ops.conj(self.operator.trace())
|
||||
|
||||
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||
return self.operator.solve(
|
||||
rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
|
||||
|
||||
def _to_dense(self):
|
||||
if self.is_self_adjoint:
|
||||
return self.operator.to_dense()
|
||||
return linalg.adjoint(self.operator.to_dense())
|
Loading…
Reference in New Issue
Block a user