BUGFIX: LinearOperatorAdjoint and LinearOperatorInversion were not using
defining operator's hints correctly. PiperOrigin-RevId: 255327878
This commit is contained in:
parent
5b9fc81bff
commit
2f8c378621
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -459,5 +460,48 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
|
||||
operator, x).run() # pyformat: disable
|
||||
|
||||
|
||||
class DummyOperatorWithHint(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("none_none", None, None, None),
|
||||
("none_true", None, True, True),
|
||||
("true_none", True, None, True),
|
||||
("true_true", True, True, True),
|
||||
("none_false", None, False, False),
|
||||
("false_none", False, None, False),
|
||||
("false_false", False, False, False),
|
||||
)
|
||||
def test_computes_an_or_if_non_contradicting(self, operator_hint_value,
|
||||
provided_hint_value,
|
||||
expected_result):
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
|
||||
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
|
||||
hint_attr_name="my_hint",
|
||||
provided_hint_value=provided_hint_value,
|
||||
message="should not be needed here"))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("true_false", True, False),
|
||||
("false_true", False, True),
|
||||
)
|
||||
def test_raises_if_contradicting(self, operator_hint_value,
|
||||
provided_hint_value):
|
||||
with self.assertRaisesRegexp(ValueError, "my error message"):
|
||||
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
|
||||
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
|
||||
hint_attr_name="my_hint",
|
||||
provided_hint_value=provided_hint_value,
|
||||
message="my error message")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
__all__ = []
|
||||
@ -116,38 +117,28 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
||||
|
||||
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
||||
# base operator.
|
||||
def _combined_hint(hint_str, provided_hint_value, message):
|
||||
"""Get combined hint in the case where operator.hint should equal hint."""
|
||||
op_hint = getattr(operator, hint_str)
|
||||
if op_hint is False and provided_hint_value:
|
||||
raise ValueError(message)
|
||||
if op_hint and provided_hint_value is False:
|
||||
raise ValueError(message)
|
||||
return (op_hint or provided_hint_value) or None
|
||||
combine_hint = (
|
||||
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
is_square = combine_hint(
|
||||
operator, "is_square", is_square,
|
||||
"An operator is square if and only if its adjoint is square.")
|
||||
|
||||
is_non_singular = _combined_hint(
|
||||
"is_non_singular", is_non_singular,
|
||||
is_non_singular = combine_hint(
|
||||
operator, "is_non_singular", is_non_singular,
|
||||
"An operator is non-singular if and only if its adjoint is "
|
||||
"non-singular.")
|
||||
|
||||
is_self_adjoint = _combined_hint(
|
||||
"is_self_adjoint", is_self_adjoint,
|
||||
is_self_adjoint = combine_hint(
|
||||
operator, "is_self_adjoint", is_self_adjoint,
|
||||
"An operator is self-adjoint if and only if its adjoint is "
|
||||
"self-adjoint.")
|
||||
|
||||
is_positive_definite = _combined_hint(
|
||||
"is_positive_definite", is_positive_definite,
|
||||
is_positive_definite = combine_hint(
|
||||
operator, "is_positive_definite", is_positive_definite,
|
||||
"An operator is positive-definite if and only if its adjoint is "
|
||||
"positive-definite.")
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
"An operator is square if and only if its adjoint is square.")
|
||||
|
||||
# Initialization.
|
||||
if name is None:
|
||||
name = operator.name + "_adjoint"
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
__all__ = []
|
||||
@ -129,38 +130,28 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
|
||||
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
||||
# base operator. Other hints are, in this special case of inversion, ones
|
||||
# that must be the same for base/derived operator.
|
||||
def _combined_hint(hint_str, provided_hint_value, message):
|
||||
"""Get combined hint in the case where operator.hint should equal hint."""
|
||||
op_hint = getattr(operator, hint_str)
|
||||
if op_hint is False and provided_hint_value:
|
||||
raise ValueError(message)
|
||||
if op_hint and provided_hint_value is False:
|
||||
raise ValueError(message)
|
||||
return (op_hint or provided_hint_value) or None
|
||||
combine_hint = (
|
||||
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
is_square = combine_hint(
|
||||
operator, "is_square", is_square,
|
||||
"An operator is square if and only if its inverse is square.")
|
||||
|
||||
is_non_singular = _combined_hint(
|
||||
"is_non_singular", is_non_singular,
|
||||
is_non_singular = combine_hint(
|
||||
operator, "is_non_singular", is_non_singular,
|
||||
"An operator is non-singular if and only if its inverse is "
|
||||
"non-singular.")
|
||||
|
||||
is_self_adjoint = _combined_hint(
|
||||
"is_self_adjoint", is_self_adjoint,
|
||||
is_self_adjoint = combine_hint(
|
||||
operator, "is_self_adjoint", is_self_adjoint,
|
||||
"An operator is self-adjoint if and only if its inverse is "
|
||||
"self-adjoint.")
|
||||
|
||||
is_positive_definite = _combined_hint(
|
||||
"is_positive_definite", is_positive_definite,
|
||||
is_positive_definite = combine_hint(
|
||||
operator, "is_positive_definite", is_positive_definite,
|
||||
"An operator is positive-definite if and only if its inverse is "
|
||||
"positive-definite.")
|
||||
|
||||
is_square = _combined_hint(
|
||||
"is_square", is_square,
|
||||
"An operator is square if and only if its inverse is square.")
|
||||
|
||||
# Initialization.
|
||||
if name is None:
|
||||
name = operator.name + "_inv"
|
||||
|
@ -494,3 +494,38 @@ def _reshape_for_efficiency(a,
|
||||
return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
|
||||
|
||||
return a, b_squashed_end, reshape_inv, still_need_to_transpose
|
||||
|
||||
|
||||
################################################################################
|
||||
# Helpers for hints.
|
||||
################################################################################
|
||||
|
||||
|
||||
def use_operator_or_provided_hint_unless_contradicting(
|
||||
operator, hint_attr_name, provided_hint_value, message):
|
||||
"""Get combined hint in the case where operator.hint should equal hint.
|
||||
|
||||
Args:
|
||||
operator: LinearOperator that a meta-operator was initialized with.
|
||||
hint_attr_name: String name for the attribute.
|
||||
provided_hint_value: Bool or None. Value passed by user in initialization.
|
||||
message: Error message to print if hints contradict.
|
||||
|
||||
Returns:
|
||||
True, False, or None.
|
||||
|
||||
Raises:
|
||||
ValueError: If hints contradict.
|
||||
"""
|
||||
op_hint = getattr(operator, hint_attr_name)
|
||||
# pylint: disable=g-bool-id-comparison
|
||||
if op_hint is False and provided_hint_value:
|
||||
raise ValueError(message)
|
||||
if op_hint and provided_hint_value is False:
|
||||
raise ValueError(message)
|
||||
if op_hint or provided_hint_value:
|
||||
return True
|
||||
if op_hint is False or provided_hint_value is False:
|
||||
return False
|
||||
# pylint: enable=g-bool-id-comparison
|
||||
return None
|
||||
|
Loading…
Reference in New Issue
Block a user