diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py index 3fedb277d9f..03086e64ecf 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.framework import dtypes @@ -459,5 +460,48 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase): operator, x).run() # pyformat: disable +class DummyOperatorWithHint(object): + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase, + parameterized.TestCase): + + @parameterized.named_parameters( + ("none_none", None, None, None), + ("none_true", None, True, True), + ("true_none", True, None, True), + ("true_true", True, True, True), + ("none_false", None, False, False), + ("false_none", False, None, False), + ("false_false", False, False, False), + ) + def test_computes_an_or_if_non_contradicting(self, operator_hint_value, + provided_hint_value, + expected_result): + self.assertEqual( + expected_result, + linear_operator_util.use_operator_or_provided_hint_unless_contradicting( + operator=DummyOperatorWithHint(my_hint=operator_hint_value), + hint_attr_name="my_hint", + provided_hint_value=provided_hint_value, + message="should not be needed here")) + + @parameterized.named_parameters( + ("true_false", True, False), + ("false_true", False, True), + ) + def test_raises_if_contradicting(self, operator_hint_value, + provided_hint_value): + with self.assertRaisesRegexp(ValueError, "my error message"): + linear_operator_util.use_operator_or_provided_hint_unless_contradicting( + operator=DummyOperatorWithHint(my_hint=operator_hint_value), + hint_attr_name="my_hint", + provided_hint_value=provided_hint_value, + message="my error message") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/linalg/linear_operator_adjoint.py b/tensorflow/python/ops/linalg/linear_operator_adjoint.py index 7ee4752d264..d62762e148c 100644 --- a/tensorflow/python/ops/linalg/linear_operator_adjoint.py +++ b/tensorflow/python/ops/linalg/linear_operator_adjoint.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.util.tf_export import tf_export __all__ = [] @@ -116,38 +117,28 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator): # The congruency of is_non_singular and is_self_adjoint was checked in the # base operator. - def _combined_hint(hint_str, provided_hint_value, message): - """Get combined hint in the case where operator.hint should equal hint.""" - op_hint = getattr(operator, hint_str) - if op_hint is False and provided_hint_value: - raise ValueError(message) - if op_hint and provided_hint_value is False: - raise ValueError(message) - return (op_hint or provided_hint_value) or None + combine_hint = ( + linear_operator_util.use_operator_or_provided_hint_unless_contradicting) - is_square = _combined_hint( - "is_square", is_square, + is_square = combine_hint( + operator, "is_square", is_square, "An operator is square if and only if its adjoint is square.") - is_non_singular = _combined_hint( - "is_non_singular", is_non_singular, + is_non_singular = combine_hint( + operator, "is_non_singular", is_non_singular, "An operator is non-singular if and only if its adjoint is " "non-singular.") - is_self_adjoint = _combined_hint( - "is_self_adjoint", is_self_adjoint, + is_self_adjoint = combine_hint( + operator, "is_self_adjoint", is_self_adjoint, "An operator is self-adjoint if and only if its adjoint is " "self-adjoint.") - is_positive_definite = _combined_hint( - "is_positive_definite", is_positive_definite, + is_positive_definite = combine_hint( + operator, "is_positive_definite", is_positive_definite, "An operator is positive-definite if and only if its adjoint is " "positive-definite.") - is_square = _combined_hint( - "is_square", is_square, - "An operator is square if and only if its adjoint is square.") - # Initialization. if name is None: name = operator.name + "_adjoint" diff --git a/tensorflow/python/ops/linalg/linear_operator_inversion.py b/tensorflow/python/ops/linalg/linear_operator_inversion.py index 7aa4b40e16b..e86941fcbb0 100644 --- a/tensorflow/python/ops/linalg/linear_operator_inversion.py +++ b/tensorflow/python/ops/linalg/linear_operator_inversion.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.util.tf_export import tf_export __all__ = [] @@ -129,38 +130,28 @@ class LinearOperatorInversion(linear_operator.LinearOperator): # The congruency of is_non_singular and is_self_adjoint was checked in the # base operator. Other hints are, in this special case of inversion, ones # that must be the same for base/derived operator. - def _combined_hint(hint_str, provided_hint_value, message): - """Get combined hint in the case where operator.hint should equal hint.""" - op_hint = getattr(operator, hint_str) - if op_hint is False and provided_hint_value: - raise ValueError(message) - if op_hint and provided_hint_value is False: - raise ValueError(message) - return (op_hint or provided_hint_value) or None + combine_hint = ( + linear_operator_util.use_operator_or_provided_hint_unless_contradicting) - is_square = _combined_hint( - "is_square", is_square, + is_square = combine_hint( + operator, "is_square", is_square, "An operator is square if and only if its inverse is square.") - is_non_singular = _combined_hint( - "is_non_singular", is_non_singular, + is_non_singular = combine_hint( + operator, "is_non_singular", is_non_singular, "An operator is non-singular if and only if its inverse is " "non-singular.") - is_self_adjoint = _combined_hint( - "is_self_adjoint", is_self_adjoint, + is_self_adjoint = combine_hint( + operator, "is_self_adjoint", is_self_adjoint, "An operator is self-adjoint if and only if its inverse is " "self-adjoint.") - is_positive_definite = _combined_hint( - "is_positive_definite", is_positive_definite, + is_positive_definite = combine_hint( + operator, "is_positive_definite", is_positive_definite, "An operator is positive-definite if and only if its inverse is " "positive-definite.") - is_square = _combined_hint( - "is_square", is_square, - "An operator is square if and only if its inverse is square.") - # Initialization. if name is None: name = operator.name + "_inv" diff --git a/tensorflow/python/ops/linalg/linear_operator_util.py b/tensorflow/python/ops/linalg/linear_operator_util.py index b7be90a26d2..1057c28a4bd 100644 --- a/tensorflow/python/ops/linalg/linear_operator_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_util.py @@ -494,3 +494,38 @@ def _reshape_for_efficiency(a, return array_ops.transpose(y_extra_on_end, perm=inverse_perm) return a, b_squashed_end, reshape_inv, still_need_to_transpose + + +################################################################################ +# Helpers for hints. +################################################################################ + + +def use_operator_or_provided_hint_unless_contradicting( + operator, hint_attr_name, provided_hint_value, message): + """Get combined hint in the case where operator.hint should equal hint. + + Args: + operator: LinearOperator that a meta-operator was initialized with. + hint_attr_name: String name for the attribute. + provided_hint_value: Bool or None. Value passed by user in initialization. + message: Error message to print if hints contradict. + + Returns: + True, False, or None. + + Raises: + ValueError: If hints contradict. + """ + op_hint = getattr(operator, hint_attr_name) + # pylint: disable=g-bool-id-comparison + if op_hint is False and provided_hint_value: + raise ValueError(message) + if op_hint and provided_hint_value is False: + raise ValueError(message) + if op_hint or provided_hint_value: + return True + if op_hint is False or provided_hint_value is False: + return False + # pylint: enable=g-bool-id-comparison + return None