BUGFIX: LinearOperatorAdjoint and LinearOperatorInversion were not using
defining operator's hints correctly. PiperOrigin-RevId: 255327878
This commit is contained in:
parent
5b9fc81bff
commit
2f8c378621
@ -17,6 +17,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -459,5 +460,48 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
|
|||||||
operator, x).run() # pyformat: disable
|
operator, x).run() # pyformat: disable
|
||||||
|
|
||||||
|
|
||||||
|
class DummyOperatorWithHint(object):
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.__dict__.update(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
("none_none", None, None, None),
|
||||||
|
("none_true", None, True, True),
|
||||||
|
("true_none", True, None, True),
|
||||||
|
("true_true", True, True, True),
|
||||||
|
("none_false", None, False, False),
|
||||||
|
("false_none", False, None, False),
|
||||||
|
("false_false", False, False, False),
|
||||||
|
)
|
||||||
|
def test_computes_an_or_if_non_contradicting(self, operator_hint_value,
|
||||||
|
provided_hint_value,
|
||||||
|
expected_result):
|
||||||
|
self.assertEqual(
|
||||||
|
expected_result,
|
||||||
|
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
|
||||||
|
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
|
||||||
|
hint_attr_name="my_hint",
|
||||||
|
provided_hint_value=provided_hint_value,
|
||||||
|
message="should not be needed here"))
|
||||||
|
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
("true_false", True, False),
|
||||||
|
("false_true", False, True),
|
||||||
|
)
|
||||||
|
def test_raises_if_contradicting(self, operator_hint_value,
|
||||||
|
provided_hint_value):
|
||||||
|
with self.assertRaisesRegexp(ValueError, "my error message"):
|
||||||
|
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
|
||||||
|
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
|
||||||
|
hint_attr_name="my_hint",
|
||||||
|
provided_hint_value=provided_hint_value,
|
||||||
|
message="my error message")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
||||||
from tensorflow.python.ops.linalg import linear_operator
|
from tensorflow.python.ops.linalg import linear_operator
|
||||||
|
from tensorflow.python.ops.linalg import linear_operator_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@ -116,38 +117,28 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
|||||||
|
|
||||||
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
||||||
# base operator.
|
# base operator.
|
||||||
def _combined_hint(hint_str, provided_hint_value, message):
|
combine_hint = (
|
||||||
"""Get combined hint in the case where operator.hint should equal hint."""
|
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
|
||||||
op_hint = getattr(operator, hint_str)
|
|
||||||
if op_hint is False and provided_hint_value:
|
|
||||||
raise ValueError(message)
|
|
||||||
if op_hint and provided_hint_value is False:
|
|
||||||
raise ValueError(message)
|
|
||||||
return (op_hint or provided_hint_value) or None
|
|
||||||
|
|
||||||
is_square = _combined_hint(
|
is_square = combine_hint(
|
||||||
"is_square", is_square,
|
operator, "is_square", is_square,
|
||||||
"An operator is square if and only if its adjoint is square.")
|
"An operator is square if and only if its adjoint is square.")
|
||||||
|
|
||||||
is_non_singular = _combined_hint(
|
is_non_singular = combine_hint(
|
||||||
"is_non_singular", is_non_singular,
|
operator, "is_non_singular", is_non_singular,
|
||||||
"An operator is non-singular if and only if its adjoint is "
|
"An operator is non-singular if and only if its adjoint is "
|
||||||
"non-singular.")
|
"non-singular.")
|
||||||
|
|
||||||
is_self_adjoint = _combined_hint(
|
is_self_adjoint = combine_hint(
|
||||||
"is_self_adjoint", is_self_adjoint,
|
operator, "is_self_adjoint", is_self_adjoint,
|
||||||
"An operator is self-adjoint if and only if its adjoint is "
|
"An operator is self-adjoint if and only if its adjoint is "
|
||||||
"self-adjoint.")
|
"self-adjoint.")
|
||||||
|
|
||||||
is_positive_definite = _combined_hint(
|
is_positive_definite = combine_hint(
|
||||||
"is_positive_definite", is_positive_definite,
|
operator, "is_positive_definite", is_positive_definite,
|
||||||
"An operator is positive-definite if and only if its adjoint is "
|
"An operator is positive-definite if and only if its adjoint is "
|
||||||
"positive-definite.")
|
"positive-definite.")
|
||||||
|
|
||||||
is_square = _combined_hint(
|
|
||||||
"is_square", is_square,
|
|
||||||
"An operator is square if and only if its adjoint is square.")
|
|
||||||
|
|
||||||
# Initialization.
|
# Initialization.
|
||||||
if name is None:
|
if name is None:
|
||||||
name = operator.name + "_adjoint"
|
name = operator.name + "_adjoint"
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops.linalg import linear_operator
|
from tensorflow.python.ops.linalg import linear_operator
|
||||||
|
from tensorflow.python.ops.linalg import linear_operator_util
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
@ -129,38 +130,28 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
|
|||||||
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
# The congruency of is_non_singular and is_self_adjoint was checked in the
|
||||||
# base operator. Other hints are, in this special case of inversion, ones
|
# base operator. Other hints are, in this special case of inversion, ones
|
||||||
# that must be the same for base/derived operator.
|
# that must be the same for base/derived operator.
|
||||||
def _combined_hint(hint_str, provided_hint_value, message):
|
combine_hint = (
|
||||||
"""Get combined hint in the case where operator.hint should equal hint."""
|
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
|
||||||
op_hint = getattr(operator, hint_str)
|
|
||||||
if op_hint is False and provided_hint_value:
|
|
||||||
raise ValueError(message)
|
|
||||||
if op_hint and provided_hint_value is False:
|
|
||||||
raise ValueError(message)
|
|
||||||
return (op_hint or provided_hint_value) or None
|
|
||||||
|
|
||||||
is_square = _combined_hint(
|
is_square = combine_hint(
|
||||||
"is_square", is_square,
|
operator, "is_square", is_square,
|
||||||
"An operator is square if and only if its inverse is square.")
|
"An operator is square if and only if its inverse is square.")
|
||||||
|
|
||||||
is_non_singular = _combined_hint(
|
is_non_singular = combine_hint(
|
||||||
"is_non_singular", is_non_singular,
|
operator, "is_non_singular", is_non_singular,
|
||||||
"An operator is non-singular if and only if its inverse is "
|
"An operator is non-singular if and only if its inverse is "
|
||||||
"non-singular.")
|
"non-singular.")
|
||||||
|
|
||||||
is_self_adjoint = _combined_hint(
|
is_self_adjoint = combine_hint(
|
||||||
"is_self_adjoint", is_self_adjoint,
|
operator, "is_self_adjoint", is_self_adjoint,
|
||||||
"An operator is self-adjoint if and only if its inverse is "
|
"An operator is self-adjoint if and only if its inverse is "
|
||||||
"self-adjoint.")
|
"self-adjoint.")
|
||||||
|
|
||||||
is_positive_definite = _combined_hint(
|
is_positive_definite = combine_hint(
|
||||||
"is_positive_definite", is_positive_definite,
|
operator, "is_positive_definite", is_positive_definite,
|
||||||
"An operator is positive-definite if and only if its inverse is "
|
"An operator is positive-definite if and only if its inverse is "
|
||||||
"positive-definite.")
|
"positive-definite.")
|
||||||
|
|
||||||
is_square = _combined_hint(
|
|
||||||
"is_square", is_square,
|
|
||||||
"An operator is square if and only if its inverse is square.")
|
|
||||||
|
|
||||||
# Initialization.
|
# Initialization.
|
||||||
if name is None:
|
if name is None:
|
||||||
name = operator.name + "_inv"
|
name = operator.name + "_inv"
|
||||||
|
@ -494,3 +494,38 @@ def _reshape_for_efficiency(a,
|
|||||||
return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
|
return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
|
||||||
|
|
||||||
return a, b_squashed_end, reshape_inv, still_need_to_transpose
|
return a, b_squashed_end, reshape_inv, still_need_to_transpose
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Helpers for hints.
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def use_operator_or_provided_hint_unless_contradicting(
|
||||||
|
operator, hint_attr_name, provided_hint_value, message):
|
||||||
|
"""Get combined hint in the case where operator.hint should equal hint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operator: LinearOperator that a meta-operator was initialized with.
|
||||||
|
hint_attr_name: String name for the attribute.
|
||||||
|
provided_hint_value: Bool or None. Value passed by user in initialization.
|
||||||
|
message: Error message to print if hints contradict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True, False, or None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If hints contradict.
|
||||||
|
"""
|
||||||
|
op_hint = getattr(operator, hint_attr_name)
|
||||||
|
# pylint: disable=g-bool-id-comparison
|
||||||
|
if op_hint is False and provided_hint_value:
|
||||||
|
raise ValueError(message)
|
||||||
|
if op_hint and provided_hint_value is False:
|
||||||
|
raise ValueError(message)
|
||||||
|
if op_hint or provided_hint_value:
|
||||||
|
return True
|
||||||
|
if op_hint is False or provided_hint_value is False:
|
||||||
|
return False
|
||||||
|
# pylint: enable=g-bool-id-comparison
|
||||||
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user