LinearOperatorAdjoint added as a meta-class. Computes the adjoint of

an operator by flipping the adjoint arg in .matmul and .solve,
and other appropriate modifications.

This is hidden from the public API, with an upcoming method, .adjoint, the preferred way to get an adjoint, since it has more operator specific information.

PiperOrigin-RevId: 221869871
This commit is contained in:
A. Unique TensorFlower 2018-11-16 16:16:00 -08:00 committed by TensorFlower Gardener
parent 2121365e74
commit de8ebdb12f
3 changed files with 347 additions and 0 deletions

View File

@ -40,6 +40,28 @@ cuda_py_test(
],
)
cuda_py_test(
name = "linear_operator_adjoint_test",
size = "medium",
srcs = ["linear_operator_adjoint_test.py"],
additional_deps = [
"//tensorflow/python/ops/linalg",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
shard_count = 5,
tags = [
"noasan", # times out, b/63678675
"optonly", # times out
],
)
cuda_py_test(
name = "linear_operator_algebra_test",
size = "small",

View File

@ -0,0 +1,118 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_adjoint
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
linalg = linalg_lib
LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint # pylint: disable=invalid-name
class LinearOperatorAdjointTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
def setUp(self):
self._atol[dtypes.complex64] = 1e-5
self._rtol[dtypes.complex64] = 1e-5
def _operator_and_matrix(self,
build_info,
dtype,
use_placeholder,
ensure_self_adjoint_and_pd=False):
shape = list(build_info.shape)
if ensure_self_adjoint_and_pd:
matrix = linear_operator_test_util.random_positive_definite_matrix(
shape, dtype, force_well_conditioned=True)
else:
matrix = linear_operator_test_util.random_tril_matrix(
shape, dtype, force_well_conditioned=True, remove_upper=True)
lin_op_matrix = matrix
if use_placeholder:
lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
if ensure_self_adjoint_and_pd:
operator = LinearOperatorAdjoint(
linalg.LinearOperatorFullMatrix(
lin_op_matrix, is_positive_definite=True, is_self_adjoint=True))
else:
operator = LinearOperatorAdjoint(
linalg.LinearOperatorLowerTriangular(lin_op_matrix))
return operator, linalg.adjoint(matrix)
def test_base_operator_hint_used(self):
# The matrix values do not effect auto-setting of the flags.
matrix = [[1., 0.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix,
is_positive_definite=True,
is_non_singular=True,
is_self_adjoint=False)
operator_adjoint = LinearOperatorAdjoint(operator)
self.assertTrue(operator_adjoint.is_positive_definite)
self.assertTrue(operator_adjoint.is_non_singular)
self.assertFalse(operator_adjoint.is_self_adjoint)
def test_supplied_hint_used(self):
# The matrix values do not effect auto-setting of the flags.
matrix = [[1., 0.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(matrix)
operator_adjoint = LinearOperatorAdjoint(
operator,
is_positive_definite=True,
is_non_singular=True,
is_self_adjoint=False)
self.assertTrue(operator_adjoint.is_positive_definite)
self.assertTrue(operator_adjoint.is_non_singular)
self.assertFalse(operator_adjoint.is_self_adjoint)
def test_contradicting_hints_raise(self):
# The matrix values do not effect auto-setting of the flags.
matrix = [[1., 0.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_positive_definite=False)
with self.assertRaisesRegexp(ValueError, "positive-definite"):
LinearOperatorAdjoint(operator, is_positive_definite=True)
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=False)
with self.assertRaisesRegexp(ValueError, "self-adjoint"):
LinearOperatorAdjoint(operator, is_self_adjoint=True)
def test_name(self):
matrix = [[11., 0.], [1., 8.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, name="my_operator", is_non_singular=True)
operator = LinearOperatorAdjoint(operator)
self.assertEqual("my_operator_adjoint", operator.name)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,207 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Takes the adjoint of a `LinearOperator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.util.tf_export import tf_export
__all__ = []
@tf_export("linalg.LinearOperatorAdjoint")
class LinearOperatorAdjoint(linear_operator.LinearOperator):
"""`LinearOperator` representing the adjoint of another operator.
This operator represents the adjoint of another operator.
```python
# Create a 2 x 2 linear operator.
operator = LinearOperatorFullMatrix([[1 - i., 3.], [0., 1. + i]])
operator_adjoint = LinearOperatorAdjoint(operator)
operator_adjoint.to_dense()
==> [[1. + i, 0.]
[3., 1 - i]]
operator_adjoint.shape
==> [2, 2]
operator_adjoint.log_abs_determinant()
==> - log(2)
x = ... Shape [2, 4] Tensor
operator_adjoint.matmul(x)
==> Shape [2, 4] Tensor, equal to operator.matmul(x, adjoint=True)
```
#### Performance
The performance of `LinearOperatorAdjoint` depends on the underlying
operators performance.
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. For example, finite floating point precision may result
in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
"""
def __init__(self,
operator,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None,
is_square=None,
name=None):
r"""Initialize a `LinearOperatorAdjoint`.
`LinearOperatorAdjoint` is initialized with an operator `A`. The `solve`
and `matmul` methods effectively flip the `adjoint` argument. E.g.
```
A = MyLinearOperator(...)
B = LinearOperatorAdjoint(A)
x = [....] # a vector
assert A.matvec(x, adjoint=True) == B.matvec(x, adjoint=False)
```
Args:
operator: `LinearOperator` object.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`. Default is `operator.name +
"_adjoint"`.
Raises:
ValueError: If `operator.is_non_singular` is False.
"""
self._operator = operator
# The congruency of is_non_singular and is_self_adjoint was checked in the
# base operator.
def _combined_hint(hint_str, provided_hint_value, message):
"""Get combined hint in the case where operator.hint should equal hint."""
op_hint = getattr(operator, hint_str)
if op_hint is False and provided_hint_value:
raise ValueError(message)
if op_hint and provided_hint_value is False:
raise ValueError(message)
return (op_hint or provided_hint_value) or None
is_square = _combined_hint(
"is_square", is_square,
"An operator is square if and only if its adjoint is square.")
is_non_singular = _combined_hint(
"is_non_singular", is_non_singular,
"An operator is non-singular if and only if its adjoint is "
"non-singular.")
is_self_adjoint = _combined_hint(
"is_self_adjoint", is_self_adjoint,
"An operator is self-adjoint if and only if its adjoint is "
"self-adjoint.")
is_positive_definite = _combined_hint(
"is_positive_definite", is_positive_definite,
"An operator is positive-definite if and only if its adjoint is "
"positive-definite.")
is_square = _combined_hint(
"is_square", is_square,
"An operator is square if and only if its adjoint is square.")
# Initialization.
if name is None:
name = operator.name + "_adjoint"
with ops.name_scope(name, values=operator.graph_parents):
super(LinearOperatorAdjoint, self).__init__(
dtype=operator.dtype,
graph_parents=operator.graph_parents,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
@property
def operator(self):
"""The operator before taking the adjoint."""
return self._operator
def _assert_non_singular(self):
return self.operator.assert_non_singular()
def _assert_positive_definite(self):
return self.operator.assert_positive_definite()
def _assert_self_adjoint(self):
return self.operator.assert_self_adjoint()
def _shape(self):
return self.operator.shape
def _shape_tensor(self):
return self.operator.shape_tensor()
def _matmul(self, x, adjoint=False, adjoint_arg=False):
return self.operator.matmul(
x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _determinant(self):
if self.is_self_adjoint:
return self.operator.determinant()
return math_ops.conj(self.operator.determinant())
def _log_abs_determinant(self):
return self.operator.log_abs_determinant()
def _trace(self):
if self.is_self_adjoint:
return self.operator.trace()
return math_ops.conj(self.operator.trace())
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
return self.operator.solve(
rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _to_dense(self):
if self.is_self_adjoint:
return self.operator.to_dense()
return linalg.adjoint(self.operator.to_dense())