Enable tf.linalg.matrix_solve tests in eager mode.

PiperOrigin-RevId: 311829192
Change-Id: I8d8c0fb2e28c6dd497a99724d4e2bcd78f2d2ed6
Authored by A. Unique TensorFlower on 2020-05-15 17:18:30 -07:00; committed by TensorFlower Gardener
parent a3746cc77a
commit c26ac449e0
1 changed file with 50 additions and 46 deletions
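Context for the change (illustrative, not part of this commit): under eager execution, matrix_solve runs as soon as it is called, so kernel failures surface as errors_impl.InvalidArgumentError raised from the op call itself, whereas graph-mode shape checking raises ValueError at graph-construction time. That is why the updated tests accept either exception type and use self.evaluate() instead of session-based .eval()/sess.run(). A minimal sketch, assuming a TF 2.x install with eager execution enabled by default:

import tensorflow as tf

# Each row of this matrix sums to zero, so the rows are linearly
# dependent and the matrix is singular.
singular = tf.constant([[1., 0., -1.], [-1., 1., 0.], [0., -1., 1.]])
try:
  tf.linalg.matrix_solve(singular, tf.eye(3))  # runs eagerly; no Session needed
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # expected to contain "Input matrix is not invertible."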


@@ -21,14 +21,16 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import test
@@ -56,19 +58,19 @@ class MatrixSolveOpTest(test.TestCase):
           a_np = np.tile(a_np, batch_dims + [1, 1])
           b = np.tile(b, batch_dims + [1, 1])
         np_ans = np.linalg.solve(a_np, b)
-        for use_placeholder in False, True:
-          with self.cached_session(use_gpu=True) as sess:
-            if use_placeholder:
-              a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
-              b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
-              tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
+        for use_placeholder in set((False, not context.executing_eagerly())):
+          if use_placeholder:
+            a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
+            b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
+            tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
+            with self.cached_session(use_gpu=True) as sess:
               out = sess.run(tf_ans, {a_ph: a, b_ph: b})
-            else:
-              tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
-              out = self.evaluate(tf_ans)
-            self.assertEqual(tf_ans.get_shape(), out.shape)
-            self.assertEqual(np_ans.shape, out.shape)
-            self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
+          else:
+            tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
+            out = self.evaluate(tf_ans)
+          self.assertEqual(tf_ans.get_shape(), out.shape)
+          self.assertEqual(np_ans.shape, out.shape)
+          self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
 
   def _generateMatrix(self, m, n):
     matrix = (np.random.normal(-5, 5,
@@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase):
                                   [m, n]))
     return matrix
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testSolve(self):
     for n in 1, 2, 4, 9:
       matrix = self._generateMatrix(n, n)
@@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase):
         rhs = self._generateMatrix(n, nrhs)
         self._verifySolve(matrix, rhs)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testSolveBatch(self):
     for n in 2, 5:
       matrix = self._generateMatrix(n, n)
@@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase):
         for batch_dims in [[2], [2, 2], [7, 4]]:
           self._verifySolve(matrix, rhs, batch_dims=batch_dims)
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testNonSquareMatrix(self):
     # When the solve of a non-square matrix is attempted we should return
     # an error
-    with self.session(use_gpu=True):
-      with self.assertRaises(ValueError):
-        matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
-        linalg_ops.matrix_solve(matrix, matrix)
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+      matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
+      self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testWrongDimensions(self):
     # The matrix and right-hand sides should have the same number of rows.
-    with self.session(use_gpu=True):
-      matrix = constant_op.constant([[1., 0.], [0., 1.]])
-      rhs = constant_op.constant([[1., 0.]])
-      with self.assertRaises(ValueError):
-        linalg_ops.matrix_solve(matrix, rhs)
+    matrix = constant_op.constant([[1., 0.], [0., 1.]])
+    rhs = constant_op.constant([[1., 0.]])
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+      self.evaluate(linalg_ops.matrix_solve(matrix, rhs))
 
   def testNotInvertible(self):
     # The input should be invertible.
-    with self.session(use_gpu=True):
-      with self.assertRaisesOpError("Input matrix is not invertible."):
-        # All rows of the matrix below add to zero
-        matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
-                                       [0., -1., 1.]])
-        linalg_ops.matrix_solve(matrix, matrix).eval()
+    with self.assertRaisesOpError("Input matrix is not invertible."):
+      # All rows of the matrix below add to zero
+      matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
+                                     [0., -1., 1.]])
+      self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
 
-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testConcurrent(self):
-    with self.session(use_gpu=True) as sess:
-      all_ops = []
-      for adjoint_ in False, True:
-        lhs1 = random_ops.random_normal([3, 3], seed=42)
-        lhs2 = random_ops.random_normal([3, 3], seed=42)
-        rhs1 = random_ops.random_normal([3, 3], seed=42)
-        rhs2 = random_ops.random_normal([3, 3], seed=42)
-        s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
-        s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
-        all_ops += [s1, s2]
-      val = self.evaluate(all_ops)
-      self.assertAllEqual(val[0], val[1])
-      self.assertAllEqual(val[2], val[3])
+    seed = [42, 24]
+    matrix_shape = [3, 3]
+    all_ops = []
+    for adjoint_ in False, True:
+      lhs1 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      lhs2 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      rhs1 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      rhs2 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
+      s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
+      all_ops += [s1, s2]
+    val = self.evaluate(all_ops)
+    for i in range(0, len(all_ops), 2):
+      self.assertAllEqual(val[i], val[i + 1])
 
 
 class MatrixSolveBenchmark(test.Benchmark):