Enable (non-gradient) tests of tf.linalg.cholesky in eager mode.
PiperOrigin-RevId: 312102967 Change-Id: Icefc46a8268413dfaec42109d4f57dd07f602a54
This commit is contained in:
parent
1b2a65c15f
commit
0bf90cb2a8
@ -32,7 +32,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.linalg import linalg
|
||||
from tensorflow.python.platform import benchmark
|
||||
@ -91,7 +91,7 @@ def TriAngInvCompositeGrad(l, grad):
|
||||
|
||||
class CholeskyOpTest(test.TestCase):
|
||||
|
||||
def _verifyCholeskyBase(self, sess, x, chol, verification):
|
||||
def _verifyCholeskyBase(self, x, chol, verification):
|
||||
chol_np, verification_np = self.evaluate([chol, verification])
|
||||
self.assertAllClose(x, verification_np)
|
||||
self.assertShapeEqual(x, chol)
|
||||
@ -106,11 +106,11 @@ class CholeskyOpTest(test.TestCase):
|
||||
|
||||
def _verifyCholesky(self, x):
|
||||
# Verify that LL^T == x.
|
||||
with self.cached_session(use_gpu=True) as sess:
|
||||
chol = linalg_ops.cholesky(x)
|
||||
verification = math_ops.matmul(chol, chol, adjoint_b=True)
|
||||
self._verifyCholeskyBase(sess, x, chol, verification)
|
||||
chol = linalg_ops.cholesky(x)
|
||||
verification = math_ops.matmul(chol, chol, adjoint_b=True)
|
||||
self._verifyCholeskyBase(x, chol, verification)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testBasic(self):
|
||||
data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
|
||||
for dtype in (np.float32, np.float64):
|
||||
@ -123,6 +123,7 @@ class CholeskyOpTest(test.TestCase):
|
||||
complex_data += data
|
||||
self._verifyCholesky(complex_data)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testBatch(self):
|
||||
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
|
||||
self._verifyCholesky(simple_array)
|
||||
@ -144,21 +145,21 @@ class CholeskyOpTest(test.TestCase):
|
||||
matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
|
||||
self._verifyCholesky(matrices)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testNonSquareMatrix(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
linalg_ops.cholesky(
|
||||
np.array([[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]
|
||||
]))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testWrongDimensions(self):
|
||||
tensor3 = constant_op.constant([1., 2.])
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
linalg_ops.cholesky(tensor3)
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||
linalg_ops.cholesky(tensor3)
|
||||
|
||||
# The below invalid Cholesky call returns an error with TF Classic and just
|
||||
@ -175,21 +176,23 @@ class CholeskyOpTest(test.TestCase):
|
||||
self._verifyCholesky(
|
||||
np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||
def testEmpty(self):
|
||||
self._verifyCholesky(np.empty([0, 2, 2]))
|
||||
self._verifyCholesky(np.empty([2, 0, 0]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testConcurrentExecutesWithoutError(self):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
matrix1 = random_ops.random_normal([5, 5], seed=42)
|
||||
matrix2 = random_ops.random_normal([5, 5], seed=42)
|
||||
matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True)
|
||||
matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True)
|
||||
c1 = linalg_ops.cholesky(matrix1)
|
||||
c2 = linalg_ops.cholesky(matrix2)
|
||||
c1_val, c2_val = self.evaluate([c1, c2])
|
||||
self.assertAllClose(c1_val, c2_val)
|
||||
seed = [42, 24]
|
||||
matrix_shape = [5, 5]
|
||||
matrix1 = stateless_random_ops.stateless_random_normal(matrix_shape, seed)
|
||||
matrix2 = stateless_random_ops.stateless_random_normal(matrix_shape, seed)
|
||||
matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True)
|
||||
matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True)
|
||||
c1 = linalg_ops.cholesky(matrix1)
|
||||
c2 = linalg_ops.cholesky(matrix2)
|
||||
c1_val, c2_val = self.evaluate([c1, c2])
|
||||
self.assertAllClose(c1_val, c2_val)
|
||||
|
||||
|
||||
class CholeskyGradTest(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user