BUGFIX: LinearOperatorAdjoint and LinearOperatorInversion were not using

defining operator's hints correctly.
PiperOrigin-RevId: 255327878
This commit is contained in:
Ian Langmore 2019-06-26 21:51:40 -07:00 committed by TensorFlower Gardener
parent 5b9fc81bff
commit 2f8c378621
4 changed files with 101 additions and 40 deletions

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
@ -459,5 +460,48 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
operator, x).run() # pyformat: disable
class DummyOperatorWithHint(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase,
parameterized.TestCase):
@parameterized.named_parameters(
("none_none", None, None, None),
("none_true", None, True, True),
("true_none", True, None, True),
("true_true", True, True, True),
("none_false", None, False, False),
("false_none", False, None, False),
("false_false", False, False, False),
)
def test_computes_an_or_if_non_contradicting(self, operator_hint_value,
provided_hint_value,
expected_result):
self.assertEqual(
expected_result,
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
hint_attr_name="my_hint",
provided_hint_value=provided_hint_value,
message="should not be needed here"))
@parameterized.named_parameters(
("true_false", True, False),
("false_true", False, True),
)
def test_raises_if_contradicting(self, operator_hint_value,
provided_hint_value):
with self.assertRaisesRegexp(ValueError, "my error message"):
linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
operator=DummyOperatorWithHint(my_hint=operator_hint_value),
hint_attr_name="my_hint",
provided_hint_value=provided_hint_value,
message="my error message")
if __name__ == "__main__":
test.main()

View File

@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export
__all__ = []
@ -116,38 +117,28 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
# The congruency of is_non_singular and is_self_adjoint was checked in the
# base operator.
def _combined_hint(hint_str, provided_hint_value, message):
"""Get combined hint in the case where operator.hint should equal hint."""
op_hint = getattr(operator, hint_str)
if op_hint is False and provided_hint_value:
raise ValueError(message)
if op_hint and provided_hint_value is False:
raise ValueError(message)
return (op_hint or provided_hint_value) or None
combine_hint = (
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
is_square = _combined_hint(
"is_square", is_square,
is_square = combine_hint(
operator, "is_square", is_square,
"An operator is square if and only if its adjoint is square.")
is_non_singular = _combined_hint(
"is_non_singular", is_non_singular,
is_non_singular = combine_hint(
operator, "is_non_singular", is_non_singular,
"An operator is non-singular if and only if its adjoint is "
"non-singular.")
is_self_adjoint = _combined_hint(
"is_self_adjoint", is_self_adjoint,
is_self_adjoint = combine_hint(
operator, "is_self_adjoint", is_self_adjoint,
"An operator is self-adjoint if and only if its adjoint is "
"self-adjoint.")
is_positive_definite = _combined_hint(
"is_positive_definite", is_positive_definite,
is_positive_definite = combine_hint(
operator, "is_positive_definite", is_positive_definite,
"An operator is positive-definite if and only if its adjoint is "
"positive-definite.")
is_square = _combined_hint(
"is_square", is_square,
"An operator is square if and only if its adjoint is square.")
# Initialization.
if name is None:
name = operator.name + "_adjoint"

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export
__all__ = []
@ -129,38 +130,28 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
# The congruency of is_non_singular and is_self_adjoint was checked in the
# base operator. Other hints are, in this special case of inversion, ones
# that must be the same for base/derived operator.
def _combined_hint(hint_str, provided_hint_value, message):
"""Get combined hint in the case where operator.hint should equal hint."""
op_hint = getattr(operator, hint_str)
if op_hint is False and provided_hint_value:
raise ValueError(message)
if op_hint and provided_hint_value is False:
raise ValueError(message)
return (op_hint or provided_hint_value) or None
combine_hint = (
linear_operator_util.use_operator_or_provided_hint_unless_contradicting)
is_square = _combined_hint(
"is_square", is_square,
is_square = combine_hint(
operator, "is_square", is_square,
"An operator is square if and only if its inverse is square.")
is_non_singular = _combined_hint(
"is_non_singular", is_non_singular,
is_non_singular = combine_hint(
operator, "is_non_singular", is_non_singular,
"An operator is non-singular if and only if its inverse is "
"non-singular.")
is_self_adjoint = _combined_hint(
"is_self_adjoint", is_self_adjoint,
is_self_adjoint = combine_hint(
operator, "is_self_adjoint", is_self_adjoint,
"An operator is self-adjoint if and only if its inverse is "
"self-adjoint.")
is_positive_definite = _combined_hint(
"is_positive_definite", is_positive_definite,
is_positive_definite = combine_hint(
operator, "is_positive_definite", is_positive_definite,
"An operator is positive-definite if and only if its inverse is "
"positive-definite.")
is_square = _combined_hint(
"is_square", is_square,
"An operator is square if and only if its inverse is square.")
# Initialization.
if name is None:
name = operator.name + "_inv"

View File

@ -494,3 +494,38 @@ def _reshape_for_efficiency(a,
return array_ops.transpose(y_extra_on_end, perm=inverse_perm)
return a, b_squashed_end, reshape_inv, still_need_to_transpose
################################################################################
# Helpers for hints.
################################################################################
def use_operator_or_provided_hint_unless_contradicting(
operator, hint_attr_name, provided_hint_value, message):
"""Get combined hint in the case where operator.hint should equal hint.
Args:
operator: LinearOperator that a meta-operator was initialized with.
hint_attr_name: String name for the attribute.
provided_hint_value: Bool or None. Value passed by user in initialization.
message: Error message to print if hints contradict.
Returns:
True, False, or None.
Raises:
ValueError: If hints contradict.
"""
op_hint = getattr(operator, hint_attr_name)
# pylint: disable=g-bool-id-comparison
if op_hint is False and provided_hint_value:
raise ValueError(message)
if op_hint and provided_hint_value is False:
raise ValueError(message)
if op_hint or provided_hint_value:
return True
if op_hint is False or provided_hint_value is False:
return False
# pylint: enable=g-bool-id-comparison
return None