Enable tf.linalg.matrix_solve tests in eager mode.
PiperOrigin-RevId: 311829192
Change-Id: I8d8c0fb2e28c6dd497a99724d4e2bcd78f2d2ed6
commit c26ac449e0
parent a3746cc77a
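The change follows one pattern throughout the test file: the graph-only `@test_util.run_deprecated_v1` decorator becomes `@test_util.run_in_graph_and_eager_modes(use_gpu=True)`, explicit session scopes are dropped, and results are fetched with `self.evaluate(...)` instead of `.eval()` or `sess.run(...)`. Below is a minimal sketch of that pattern using the same TF-internal test utilities; the class name, test name, and matrix values are illustrative and are not part of the commit.

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test


class MatrixSolveEagerSketch(test.TestCase):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSolveAgainstIdentity(self):
    # Solving I x = b returns the right-hand side unchanged.
    matrix = constant_op.constant([[1., 0.], [0., 1.]])
    rhs = constant_op.constant([[2.], [3.]])
    # self.evaluate() materializes the result in graph mode (via a session)
    # and simply returns the computed value in eager mode, so no explicit
    # session scope or .eval() call is needed.
    result = self.evaluate(linalg_ops.matrix_solve(matrix, rhs))
    self.assertAllClose(result, [[2.], [3.]])


if __name__ == "__main__":
  test.main()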
@@ -21,14 +21,16 @@ from __future__ import print_function
 import numpy as np

 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import test
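Each new or changed import above corresponds to one behavioral change further down: `context` is used to detect eager execution and skip the placeholder path, `errors_impl` lets the shape-error tests accept the eager-mode `InvalidArgumentError` alongside the graph-mode `ValueError`, and `stateless_random_ops` replaces `random_ops` so that `testConcurrent` gets identical pseudo-random inputs on every call. A small standalone sketch of the stateless behaviour the test now relies on, assuming a TF 2.x runtime with eager execution enabled (not part of the commit):

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.ops import stateless_random_ops

# A stateless RNG is a pure function of its seed: the same seed yields the
# same values on every call, in graph and eager mode alike. A stateful
# random_ops.random_normal with an op-level seed generally returns different
# values on each eager call, which would break equality assertions.
seed = [42, 24]
a = stateless_random_ops.stateless_random_normal([3, 3], seed=seed)
b = stateless_random_ops.stateless_random_normal([3, 3], seed=seed)
if context.executing_eagerly():
  print(np.array_equal(a.numpy(), b.numpy()))  # True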
@@ -56,12 +58,12 @@ class MatrixSolveOpTest(test.TestCase):
           a_np = np.tile(a_np, batch_dims + [1, 1])
           b = np.tile(b, batch_dims + [1, 1])
         np_ans = np.linalg.solve(a_np, b)
-        for use_placeholder in False, True:
-          with self.cached_session(use_gpu=True) as sess:
+        for use_placeholder in set((False, not context.executing_eagerly())):
           if use_placeholder:
             a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
             b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
             tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
+            with self.cached_session(use_gpu=True) as sess:
               out = sess.run(tf_ans, {a_ph: a, b_ph: b})
           else:
             tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
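In `_verifySolve`, placeholders and feed_dict-style feeding only exist in graph mode, so the placeholder branch is disabled under eager execution and the cached session is now opened only on that branch, where feeding actually happens. The loop expression collapses accordingly; the snippet below just restates the changed line with its two outcomes and is not additional code from the commit.

from tensorflow.python.eager import context

# Graph mode: {False, True}, so both the constant and the placeholder paths
# are exercised. Eager mode: {False}, because array_ops.placeholder raises
# RuntimeError when executing eagerly, so only the constant path runs.
use_placeholder_values = set((False, not context.executing_eagerly()))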
@@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase):
        [m, n]))
     return matrix

-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testSolve(self):
     for n in 1, 2, 4, 9:
       matrix = self._generateMatrix(n, n)
@@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase):
       rhs = self._generateMatrix(n, nrhs)
       self._verifySolve(matrix, rhs)

-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testSolveBatch(self):
     for n in 2, 5:
       matrix = self._generateMatrix(n, n)
@@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase):
       for batch_dims in [[2], [2, 2], [7, 4]]:
         self._verifySolve(matrix, rhs, batch_dims=batch_dims)

-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testNonSquareMatrix(self):
     # When the solve of a non-square matrix is attempted we should return
     # an error
-    with self.session(use_gpu=True):
-      with self.assertRaises(ValueError):
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
       matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
-        linalg_ops.matrix_solve(matrix, matrix)
+      self.evaluate(linalg_ops.matrix_solve(matrix, matrix))

-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testWrongDimensions(self):
     # The matrix and right-hand sides should have the same number of rows.
-    with self.session(use_gpu=True):
     matrix = constant_op.constant([[1., 0.], [0., 1.]])
     rhs = constant_op.constant([[1., 0.]])
-    with self.assertRaises(ValueError):
-      linalg_ops.matrix_solve(matrix, rhs)
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+      self.evaluate(linalg_ops.matrix_solve(matrix, rhs))

   def testNotInvertible(self):
     # The input should be invertible.
-    with self.session(use_gpu=True):
     with self.assertRaisesOpError("Input matrix is not invertible."):
       # All rows of the matrix below add to zero
       matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
                                      [0., -1., 1.]])
-      linalg_ops.matrix_solve(matrix, matrix).eval()
+      self.evaluate(linalg_ops.matrix_solve(matrix, matrix))

-  @test_util.run_deprecated_v1
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testConcurrent(self):
-    with self.session(use_gpu=True) as sess:
+    seed = [42, 24]
+    matrix_shape = [3, 3]
     all_ops = []
     for adjoint_ in False, True:
-      lhs1 = random_ops.random_normal([3, 3], seed=42)
-      lhs2 = random_ops.random_normal([3, 3], seed=42)
-      rhs1 = random_ops.random_normal([3, 3], seed=42)
-      rhs2 = random_ops.random_normal([3, 3], seed=42)
+      lhs1 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      lhs2 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      rhs1 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
+      rhs2 = stateless_random_ops.stateless_random_normal(
+          matrix_shape, seed=seed)
       s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
       s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
       all_ops += [s1, s2]
     val = self.evaluate(all_ops)
-    self.assertAllEqual(val[0], val[1])
-    self.assertAllEqual(val[2], val[3])
+    for i in range(0, len(all_ops), 2):
+      self.assertAllEqual(val[i], val[i + 1])


 class MatrixSolveBenchmark(test.Benchmark):

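The error tests change for the same reason as the rest of the file: in graph mode an ill-shaped `matrix_solve` fails at op-construction time with a `ValueError` from static shape checking, while in eager mode the kernel itself runs and fails with `errors_impl.InvalidArgumentError`, so the assertion accepts either type and execution is forced through `self.evaluate(...)` rather than `.eval()`, which needs a default session. `testConcurrent` additionally switches to stateless random inputs, which are identical for every call with the same seed, and compares each pair of identical solves (`val[i]` against `val[i + 1]`) in a loop instead of hard-coding indices. A short sketch of the dual-exception pattern under the same assumptions as above (illustrative, not part of the commit):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test


class MatrixSolveErrorSketch(test.TestCase):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testMismatchedRows(self):
    matrix = constant_op.constant([[1., 0.], [0., 1.]])
    rhs = constant_op.constant([[1., 0.]])
    # Graph mode: static shape inference rejects the op when it is built and
    # raises ValueError. Eager mode: the kernel executes immediately and
    # raises errors_impl.InvalidArgumentError. Accepting either exception
    # keeps one test body valid in both modes; self.evaluate() forces
    # execution so the error actually surfaces.
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(linalg_ops.matrix_solve(matrix, rhs))


if __name__ == "__main__":
  test.main()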