Meta linear operators tested for "tape safe"

PiperOrigin-RevId: 266182803
This commit is contained in:
Ian Langmore 2019-08-29 11:11:41 -07:00 committed by TensorFlower Gardener
parent 4f8a6dd61c
commit 709d160772
3 changed files with 39 additions and 5 deletions

View File

@ -20,7 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_adjoint
from tensorflow.python.ops.linalg import linear_operator_test_util
@ -31,6 +33,7 @@ linalg = linalg_lib
LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint # pylint: disable=invalid-name
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorAdjointTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -239,7 +242,13 @@ class LinearOperatorAdjointTest(
self.assertAllClose(
inv_matrix.T.dot(x), self.evaluate(operator.H.solvevec(x)))
def test_tape_safe(self):
matrix = variables_module.Variable([[1., 2.], [3., 4.]])
operator = LinearOperatorAdjoint(linalg.LinearOperatorFullMatrix(matrix))
self.check_tape_safe(operator)
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorAdjointNonSquareTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
"""Tests done in the base class NonSquareLinearOperatorDerivedClassTest."""

View File

@ -18,7 +18,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_inversion
from tensorflow.python.ops.linalg import linear_operator_test_util
@ -29,6 +31,7 @@ linalg = linalg_lib
LinearOperatorInversion = linear_operator_inversion.LinearOperatorInversion # pylint: disable=invalid-name
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorInversionTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -125,6 +128,11 @@ class LinearOperatorInversionTest(
self.assertEqual("my_operator_inv", operator.name)
def test_tape_safe(self):
matrix = variables_module.Variable([[1., 2.], [3., 4.]])
operator = LinearOperatorInversion(linalg.LinearOperatorFullMatrix(matrix))
self.check_tape_safe(operator)
if __name__ == "__main__":
linear_operator_test_util.add_tests(LinearOperatorInversionTest)

View File

@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
@ -53,9 +54,9 @@ def _kronecker_dense(factors):
class KroneckerDenseTest(test.TestCase):
"""Test of `_kronecker_dense` function."""
@test_util.run_deprecated_v1
def testKroneckerDenseMatrix(self):
def test_kronecker_dense_matrix(self):
x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32)
y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32)
# From explicitly writing out the kronecker product of x and y.
@ -71,11 +72,13 @@ class KroneckerDenseTest(test.TestCase):
[10., 15., -2., -3.],
[5., 10., -1., -2.]], dtype=dtypes.float32)
with self.cached_session():
self.assertAllClose(_kronecker_dense([x, y]).eval(), self.evaluate(z))
self.assertAllClose(_kronecker_dense([y, x]).eval(), self.evaluate(w))
self.assertAllClose(
self.evaluate(_kronecker_dense([x, y])), self.evaluate(z))
self.assertAllClose(
self.evaluate(_kronecker_dense([y, x])), self.evaluate(w))
@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorKroneckerTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -252,6 +255,20 @@ class SquareLinearOperatorKroneckerTest(
kronecker.LinearOperatorKronecker)
self.assertEqual(2, len(inverse.operators))
def test_tape_safe(self):
matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
operator = kronecker.LinearOperatorKronecker(
[
linalg.LinearOperatorFullMatrix(
matrix_1, is_non_singular=True),
linalg.LinearOperatorFullMatrix(
matrix_2, is_non_singular=True),
],
is_non_singular=True,
)
self.check_tape_safe(operator)
if __name__ == "__main__":
linear_operator_test_util.add_tests(SquareLinearOperatorKroneckerTest)