From c26ac449e0c798e5527f565e95078e42c662952f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 15 May 2020 17:18:30 -0700 Subject: [PATCH] Enable tf.linalg.matrix_solve tests in eager mode. PiperOrigin-RevId: 311829192 Change-Id: I8d8c0fb2e28c6dd497a99724d4e2bcd78f2d2ed6 --- .../kernel_tests/matrix_solve_op_test.py | 96 ++++++++++--------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py index 0b6b403210c..bbd909c8e58 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py @@ -21,14 +21,16 @@ from __future__ import print_function import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -56,19 +58,19 @@ class MatrixSolveOpTest(test.TestCase): a_np = np.tile(a_np, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1]) np_ans = np.linalg.solve(a_np, b) - for use_placeholder in False, True: - with self.cached_session(use_gpu=True) as sess: - if use_placeholder: - a_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) - b_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) - tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint) + for use_placeholder in set((False, not context.executing_eagerly())): + if use_placeholder: + a_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) + b_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) + tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint) + with self.cached_session(use_gpu=True) as sess: out = sess.run(tf_ans, {a_ph: a, b_ph: b}) - else: - tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) - out = self.evaluate(tf_ans) - self.assertEqual(tf_ans.get_shape(), out.shape) - self.assertEqual(np_ans.shape, out.shape) - self.assertAllClose(np_ans, out, atol=tol, rtol=tol) + else: + tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) + out = self.evaluate(tf_ans) + self.assertEqual(tf_ans.get_shape(), out.shape) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out, atol=tol, rtol=tol) def _generateMatrix(self, m, n): matrix = (np.random.normal(-5, 5, @@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase): [m, n])) return matrix - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSolve(self): for n in 1, 2, 4, 9: matrix = self._generateMatrix(n, n) @@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase): rhs = self._generateMatrix(n, nrhs) self._verifySolve(matrix, rhs) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSolveBatch(self): for n in 2, 5: matrix = self._generateMatrix(n, n) @@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase): for batch_dims in [[2], [2, 2], [7, 4]]: self._verifySolve(matrix, rhs, batch_dims=batch_dims) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNonSquareMatrix(self): # When the solve of a non-square matrix is attempted we should return # an error - with self.session(use_gpu=True): - with self.assertRaises(ValueError): - matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]]) - linalg_ops.matrix_solve(matrix, matrix) + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]]) + self.evaluate(linalg_ops.matrix_solve(matrix, matrix)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The matrix and right-hand sides should have the same number of rows. - with self.session(use_gpu=True): - matrix = constant_op.constant([[1., 0.], [0., 1.]]) - rhs = constant_op.constant([[1., 0.]]) - with self.assertRaises(ValueError): - linalg_ops.matrix_solve(matrix, rhs) + matrix = constant_op.constant([[1., 0.], [0., 1.]]) + rhs = constant_op.constant([[1., 0.]]) + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + self.evaluate(linalg_ops.matrix_solve(matrix, rhs)) def testNotInvertible(self): # The input should be invertible. - with self.session(use_gpu=True): - with self.assertRaisesOpError("Input matrix is not invertible."): - # All rows of the matrix below add to zero - matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.], - [0., -1., 1.]]) - linalg_ops.matrix_solve(matrix, matrix).eval() + with self.assertRaisesOpError("Input matrix is not invertible."): + # All rows of the matrix below add to zero + matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.], + [0., -1., 1.]]) + self.evaluate(linalg_ops.matrix_solve(matrix, matrix)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrent(self): - with self.session(use_gpu=True) as sess: - all_ops = [] - for adjoint_ in False, True: - lhs1 = random_ops.random_normal([3, 3], seed=42) - lhs2 = random_ops.random_normal([3, 3], seed=42) - rhs1 = random_ops.random_normal([3, 3], seed=42) - rhs2 = random_ops.random_normal([3, 3], seed=42) - s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_) - s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_) - all_ops += [s1, s2] - val = self.evaluate(all_ops) - self.assertAllEqual(val[0], val[1]) - self.assertAllEqual(val[2], val[3]) + seed = [42, 24] + matrix_shape = [3, 3] + all_ops = [] + for adjoint_ in False, True: + lhs1 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + lhs2 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + rhs1 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + rhs2 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_) + s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_) + all_ops += [s1, s2] + val = self.evaluate(all_ops) + for i in range(0, len(all_ops), 2): + self.assertAllEqual(val[i], val[i + 1]) class MatrixSolveBenchmark(test.Benchmark):