diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py index c076a5b3724..c03203f02e5 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.ops.linalg import linear_operator_adjoint from tensorflow.python.ops.linalg import linear_operator_test_util @@ -31,6 +33,7 @@ linalg = linalg_lib LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint # pylint: disable=invalid-name +@test_util.run_all_in_graph_and_eager_modes class LinearOperatorAdjointTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" @@ -239,7 +242,13 @@ class LinearOperatorAdjointTest( self.assertAllClose( inv_matrix.T.dot(x), self.evaluate(operator.H.solvevec(x))) + def test_tape_safe(self): + matrix = variables_module.Variable([[1., 2.], [3., 4.]]) + operator = LinearOperatorAdjoint(linalg.LinearOperatorFullMatrix(matrix)) + self.check_tape_safe(operator) + +@test_util.run_all_in_graph_and_eager_modes class LinearOperatorAdjointNonSquareTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): """Tests done in the base class NonSquareLinearOperatorDerivedClassTest.""" diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py index bab2c9b9d6c..4b2ce3d9da7 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_inversion_test.py @@ -18,7 +18,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.ops.linalg import linear_operator_inversion from tensorflow.python.ops.linalg import linear_operator_test_util @@ -29,6 +31,7 @@ linalg = linalg_lib LinearOperatorInversion = linear_operator_inversion.LinearOperatorInversion # pylint: disable=invalid-name +@test_util.run_all_in_graph_and_eager_modes class LinearOperatorInversionTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" @@ -125,6 +128,11 @@ class LinearOperatorInversionTest( self.assertEqual("my_operator_inv", operator.name) + def test_tape_safe(self): + matrix = variables_module.Variable([[1., 2.], [3., 4.]]) + operator = LinearOperatorInversion(linalg.LinearOperatorFullMatrix(matrix)) + self.check_tape_safe(operator) + if __name__ == "__main__": linear_operator_test_util.add_tests(LinearOperatorInversionTest) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py index 1dc296b3534..04d8ab2938a 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables as variables_module from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular @@ -53,9 +54,9 @@ def _kronecker_dense(factors): class KroneckerDenseTest(test.TestCase): + """Test of `_kronecker_dense` function.""" - @test_util.run_deprecated_v1 - def testKroneckerDenseMatrix(self): + def test_kronecker_dense_matrix(self): x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32) y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32) # From explicitly writing out the kronecker product of x and y. @@ -71,11 +72,13 @@ class KroneckerDenseTest(test.TestCase): [10., 15., -2., -3.], [5., 10., -1., -2.]], dtype=dtypes.float32) - with self.cached_session(): - self.assertAllClose(_kronecker_dense([x, y]).eval(), self.evaluate(z)) - self.assertAllClose(_kronecker_dense([y, x]).eval(), self.evaluate(w)) + self.assertAllClose( + self.evaluate(_kronecker_dense([x, y])), self.evaluate(z)) + self.assertAllClose( + self.evaluate(_kronecker_dense([y, x])), self.evaluate(w)) +@test_util.run_all_in_graph_and_eager_modes class SquareLinearOperatorKroneckerTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" @@ -252,6 +255,20 @@ class SquareLinearOperatorKroneckerTest( kronecker.LinearOperatorKronecker) self.assertEqual(2, len(inverse.operators)) + def test_tape_safe(self): + matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]]) + matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]]) + operator = kronecker.LinearOperatorKronecker( + [ + linalg.LinearOperatorFullMatrix( + matrix_1, is_non_singular=True), + linalg.LinearOperatorFullMatrix( + matrix_2, is_non_singular=True), + ], + is_non_singular=True, + ) + self.check_tape_safe(operator) + if __name__ == "__main__": linear_operator_test_util.add_tests(SquareLinearOperatorKroneckerTest)