Enable tf.linalg.matrix_solve tests in eager mode.
PiperOrigin-RevId: 311829192 Change-Id: I8d8c0fb2e28c6dd497a99724d4e2bcd78f2d2ed6
This commit is contained in:
parent
a3746cc77a
commit
c26ac449e0
|
@ -21,14 +21,16 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import test
|
||||
|
@ -56,19 +58,19 @@ class MatrixSolveOpTest(test.TestCase):
|
|||
a_np = np.tile(a_np, batch_dims + [1, 1])
|
||||
b = np.tile(b, batch_dims + [1, 1])
|
||||
np_ans = np.linalg.solve(a_np, b)
|
||||
for use_placeholder in False, True:
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
if use_placeholder:
|
||||
a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
|
||||
b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
|
||||
tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
|
||||
for use_placeholder in set((False, not context.executing_eagerly())):
|
||||
if use_placeholder:
|
||||
a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
|
||||
b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
|
||||
tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
out = sess.run(tf_ans, {a_ph: a, b_ph: b})
|
||||
else:
|
||||
tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
|
||||
out = self.evaluate(tf_ans)
|
||||
self.assertEqual(tf_ans.get_shape(), out.shape)
|
||||
self.assertEqual(np_ans.shape, out.shape)
|
||||
self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
|
||||
else:
|
||||
tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
|
||||
out = self.evaluate(tf_ans)
|
||||
self.assertEqual(tf_ans.get_shape(), out.shape)
|
||||
self.assertEqual(np_ans.shape, out.shape)
|
||||
self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
|
||||
|
||||
def _generateMatrix(self, m, n):
|
||||
matrix = (np.random.normal(-5, 5,
|
||||
|
@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase):
|
|||
[m, n]))
|
||||
return matrix
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testSolve(self):
|
||||
for n in 1, 2, 4, 9:
|
||||
matrix = self._generateMatrix(n, n)
|
||||
|
@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase):
|
|||
rhs = self._generateMatrix(n, nrhs)
|
||||
self._verifySolve(matrix, rhs)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testSolveBatch(self):
|
||||
for n in 2, 5:
|
||||
matrix = self._generateMatrix(n, n)
|
||||
|
@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase):
|
|||
for batch_dims in [[2], [2, 2], [7, 4]]:
|
||||
self._verifySolve(matrix, rhs, batch_dims=batch_dims)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testNonSquareMatrix(self):
|
||||
# When the solve of a non-square matrix is attempted we should return
|
||||
# an error
|
||||
with self.session(use_gpu=True):
|
||||
with self.assertRaises(ValueError):
|
||||
matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
|
||||
linalg_ops.matrix_solve(matrix, matrix)
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
|
||||
self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testWrongDimensions(self):
|
||||
# The matrix and right-hand sides should have the same number of rows.
|
||||
with self.session(use_gpu=True):
|
||||
matrix = constant_op.constant([[1., 0.], [0., 1.]])
|
||||
rhs = constant_op.constant([[1., 0.]])
|
||||
with self.assertRaises(ValueError):
|
||||
linalg_ops.matrix_solve(matrix, rhs)
|
||||
matrix = constant_op.constant([[1., 0.], [0., 1.]])
|
||||
rhs = constant_op.constant([[1., 0.]])
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
self.evaluate(linalg_ops.matrix_solve(matrix, rhs))
|
||||
|
||||
def testNotInvertible(self):
|
||||
# The input should be invertible.
|
||||
with self.session(use_gpu=True):
|
||||
with self.assertRaisesOpError("Input matrix is not invertible."):
|
||||
# All rows of the matrix below add to zero
|
||||
matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
|
||||
[0., -1., 1.]])
|
||||
linalg_ops.matrix_solve(matrix, matrix).eval()
|
||||
with self.assertRaisesOpError("Input matrix is not invertible."):
|
||||
# All rows of the matrix below add to zero
|
||||
matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
|
||||
[0., -1., 1.]])
|
||||
self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testConcurrent(self):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
all_ops = []
|
||||
for adjoint_ in False, True:
|
||||
lhs1 = random_ops.random_normal([3, 3], seed=42)
|
||||
lhs2 = random_ops.random_normal([3, 3], seed=42)
|
||||
rhs1 = random_ops.random_normal([3, 3], seed=42)
|
||||
rhs2 = random_ops.random_normal([3, 3], seed=42)
|
||||
s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
|
||||
s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
|
||||
all_ops += [s1, s2]
|
||||
val = self.evaluate(all_ops)
|
||||
self.assertAllEqual(val[0], val[1])
|
||||
self.assertAllEqual(val[2], val[3])
|
||||
seed = [42, 24]
|
||||
matrix_shape = [3, 3]
|
||||
all_ops = []
|
||||
for adjoint_ in False, True:
|
||||
lhs1 = stateless_random_ops.stateless_random_normal(
|
||||
matrix_shape, seed=seed)
|
||||
lhs2 = stateless_random_ops.stateless_random_normal(
|
||||
matrix_shape, seed=seed)
|
||||
rhs1 = stateless_random_ops.stateless_random_normal(
|
||||
matrix_shape, seed=seed)
|
||||
rhs2 = stateless_random_ops.stateless_random_normal(
|
||||
matrix_shape, seed=seed)
|
||||
s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
|
||||
s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
|
||||
all_ops += [s1, s2]
|
||||
val = self.evaluate(all_ops)
|
||||
for i in range(0, len(all_ops), 2):
|
||||
self.assertAllEqual(val[i], val[i + 1])
|
||||
|
||||
|
||||
class MatrixSolveBenchmark(test.Benchmark):
|
||||
|
|
Loading…
Reference in New Issue