Enable tf.linalg.matrix_solve tests in eager mode.

PiperOrigin-RevId: 311829192
Change-Id: I8d8c0fb2e28c6dd497a99724d4e2bcd78f2d2ed6
This commit is contained in:
A. Unique TensorFlower 2020-05-15 17:18:30 -07:00 committed by TensorFlower Gardener
parent a3746cc77a
commit c26ac449e0
1 changed files with 50 additions and 46 deletions

View File

@ -21,14 +21,16 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -56,12 +58,12 @@ class MatrixSolveOpTest(test.TestCase):
a_np = np.tile(a_np, batch_dims + [1, 1]) a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1])
np_ans = np.linalg.solve(a_np, b) np_ans = np.linalg.solve(a_np, b)
for use_placeholder in False, True: for use_placeholder in set((False, not context.executing_eagerly())):
with self.cached_session(use_gpu=True) as sess:
if use_placeholder: if use_placeholder:
a_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
b_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint) tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
with self.cached_session(use_gpu=True) as sess:
out = sess.run(tf_ans, {a_ph: a, b_ph: b}) out = sess.run(tf_ans, {a_ph: a, b_ph: b})
else: else:
tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase):
[m, n])) [m, n]))
return matrix return matrix
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testSolve(self): def testSolve(self):
for n in 1, 2, 4, 9: for n in 1, 2, 4, 9:
matrix = self._generateMatrix(n, n) matrix = self._generateMatrix(n, n)
@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase):
rhs = self._generateMatrix(n, nrhs) rhs = self._generateMatrix(n, nrhs)
self._verifySolve(matrix, rhs) self._verifySolve(matrix, rhs)
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testSolveBatch(self): def testSolveBatch(self):
for n in 2, 5: for n in 2, 5:
matrix = self._generateMatrix(n, n) matrix = self._generateMatrix(n, n)
@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase):
for batch_dims in [[2], [2, 2], [7, 4]]: for batch_dims in [[2], [2, 2], [7, 4]]:
self._verifySolve(matrix, rhs, batch_dims=batch_dims) self._verifySolve(matrix, rhs, batch_dims=batch_dims)
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testNonSquareMatrix(self): def testNonSquareMatrix(self):
# When the solve of a non-square matrix is attempted we should return # When the solve of a non-square matrix is attempted we should return
# an error # an error
with self.session(use_gpu=True): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
with self.assertRaises(ValueError):
matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]]) matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]])
linalg_ops.matrix_solve(matrix, matrix) self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testWrongDimensions(self): def testWrongDimensions(self):
# The matrix and right-hand sides should have the same number of rows. # The matrix and right-hand sides should have the same number of rows.
with self.session(use_gpu=True):
matrix = constant_op.constant([[1., 0.], [0., 1.]]) matrix = constant_op.constant([[1., 0.], [0., 1.]])
rhs = constant_op.constant([[1., 0.]]) rhs = constant_op.constant([[1., 0.]])
with self.assertRaises(ValueError): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
linalg_ops.matrix_solve(matrix, rhs) self.evaluate(linalg_ops.matrix_solve(matrix, rhs))
def testNotInvertible(self): def testNotInvertible(self):
# The input should be invertible. # The input should be invertible.
with self.session(use_gpu=True):
with self.assertRaisesOpError("Input matrix is not invertible."): with self.assertRaisesOpError("Input matrix is not invertible."):
# All rows of the matrix below add to zero # All rows of the matrix below add to zero
matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.], matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
[0., -1., 1.]]) [0., -1., 1.]])
linalg_ops.matrix_solve(matrix, matrix).eval() self.evaluate(linalg_ops.matrix_solve(matrix, matrix))
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testConcurrent(self): def testConcurrent(self):
with self.session(use_gpu=True) as sess: seed = [42, 24]
matrix_shape = [3, 3]
all_ops = [] all_ops = []
for adjoint_ in False, True: for adjoint_ in False, True:
lhs1 = random_ops.random_normal([3, 3], seed=42) lhs1 = stateless_random_ops.stateless_random_normal(
lhs2 = random_ops.random_normal([3, 3], seed=42) matrix_shape, seed=seed)
rhs1 = random_ops.random_normal([3, 3], seed=42) lhs2 = stateless_random_ops.stateless_random_normal(
rhs2 = random_ops.random_normal([3, 3], seed=42) matrix_shape, seed=seed)
rhs1 = stateless_random_ops.stateless_random_normal(
matrix_shape, seed=seed)
rhs2 = stateless_random_ops.stateless_random_normal(
matrix_shape, seed=seed)
s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_) s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_)
s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_) s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_)
all_ops += [s1, s2] all_ops += [s1, s2]
val = self.evaluate(all_ops) val = self.evaluate(all_ops)
self.assertAllEqual(val[0], val[1]) for i in range(0, len(all_ops), 2):
self.assertAllEqual(val[2], val[3]) self.assertAllEqual(val[i], val[i + 1])
class MatrixSolveBenchmark(test.Benchmark): class MatrixSolveBenchmark(test.Benchmark):