Meta linear operators tested for "tape safe"
PiperOrigin-RevId: 266182803
This commit is contained in:
parent
4f8a6dd61c
commit
709d160772
tensorflow/python/kernel_tests/linalg
@ -20,7 +20,9 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.ops.linalg import linear_operator_adjoint
|
||||
from tensorflow.python.ops.linalg import linear_operator_test_util
|
||||
@ -31,6 +33,7 @@ linalg = linalg_lib
|
||||
LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LinearOperatorAdjointTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
@ -239,7 +242,13 @@ class LinearOperatorAdjointTest(
|
||||
self.assertAllClose(
|
||||
inv_matrix.T.dot(x), self.evaluate(operator.H.solvevec(x)))
|
||||
|
||||
def test_tape_safe(self):
|
||||
matrix = variables_module.Variable([[1., 2.], [3., 4.]])
|
||||
operator = LinearOperatorAdjoint(linalg.LinearOperatorFullMatrix(matrix))
|
||||
self.check_tape_safe(operator)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LinearOperatorAdjointNonSquareTest(
|
||||
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
|
||||
"""Tests done in the base class NonSquareLinearOperatorDerivedClassTest."""
|
||||
|
@ -18,7 +18,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.ops.linalg import linear_operator_inversion
|
||||
from tensorflow.python.ops.linalg import linear_operator_test_util
|
||||
@ -29,6 +31,7 @@ linalg = linalg_lib
|
||||
LinearOperatorInversion = linear_operator_inversion.LinearOperatorInversion # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LinearOperatorInversionTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
@ -125,6 +128,11 @@ class LinearOperatorInversionTest(
|
||||
|
||||
self.assertEqual("my_operator_inv", operator.name)
|
||||
|
||||
def test_tape_safe(self):
|
||||
matrix = variables_module.Variable([[1., 2.], [3., 4.]])
|
||||
operator = LinearOperatorInversion(linalg.LinearOperatorFullMatrix(matrix))
|
||||
self.check_tape_safe(operator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
linear_operator_test_util.add_tests(LinearOperatorInversionTest)
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker
|
||||
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
|
||||
@ -53,9 +54,9 @@ def _kronecker_dense(factors):
|
||||
|
||||
|
||||
class KroneckerDenseTest(test.TestCase):
|
||||
"""Test of `_kronecker_dense` function."""
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testKroneckerDenseMatrix(self):
|
||||
def test_kronecker_dense_matrix(self):
|
||||
x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32)
|
||||
y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32)
|
||||
# From explicitly writing out the kronecker product of x and y.
|
||||
@ -71,11 +72,13 @@ class KroneckerDenseTest(test.TestCase):
|
||||
[10., 15., -2., -3.],
|
||||
[5., 10., -1., -2.]], dtype=dtypes.float32)
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(_kronecker_dense([x, y]).eval(), self.evaluate(z))
|
||||
self.assertAllClose(_kronecker_dense([y, x]).eval(), self.evaluate(w))
|
||||
self.assertAllClose(
|
||||
self.evaluate(_kronecker_dense([x, y])), self.evaluate(z))
|
||||
self.assertAllClose(
|
||||
self.evaluate(_kronecker_dense([y, x])), self.evaluate(w))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class SquareLinearOperatorKroneckerTest(
|
||||
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
|
||||
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
|
||||
@ -252,6 +255,20 @@ class SquareLinearOperatorKroneckerTest(
|
||||
kronecker.LinearOperatorKronecker)
|
||||
self.assertEqual(2, len(inverse.operators))
|
||||
|
||||
def test_tape_safe(self):
|
||||
matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
|
||||
matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
|
||||
operator = kronecker.LinearOperatorKronecker(
|
||||
[
|
||||
linalg.LinearOperatorFullMatrix(
|
||||
matrix_1, is_non_singular=True),
|
||||
linalg.LinearOperatorFullMatrix(
|
||||
matrix_2, is_non_singular=True),
|
||||
],
|
||||
is_non_singular=True,
|
||||
)
|
||||
self.check_tape_safe(operator)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
linear_operator_test_util.add_tests(SquareLinearOperatorKroneckerTest)
|
||||
|
Loading…
Reference in New Issue
Block a user