From 4eaf597cbaf1ef1f2a216a1a83289d06419f87fd Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 17 Jul 2018 11:27:23 -0700 Subject: [PATCH] [TF:XLA] Add a 2000x2000 test case to Cholesky and QR decomposition tests. PiperOrigin-RevId: 204943689 --- tensorflow/compiler/tests/cholesky_op_test.py | 8 ++------ tensorflow/compiler/tests/qr_op_test.py | 5 ++++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index d2867278af9..ed532db0ee5 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest - import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -103,9 +101,8 @@ class CholeskyOpTest(xla_test.XLATestCase): with self.assertRaises(ValueError): linalg_ops.cholesky(tensor3) - @unittest.skip("Test is slow") - def testLarge(self): - n = 200 + def testLarge2000x2000(self): + n = 2000 shape = (n, n) data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( np.ones(n).astype(np.float32)) @@ -128,6 +125,5 @@ class CholeskyOpTest(xla_test.XLATestCase): matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) self._verifyCholesky(matrix, atol=1e-4) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 93752a21db4..1b969ee2b38 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -57,7 +57,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): def CheckApproximation(self, a, q, r): # Tests that a ~= q*r. precision = self.AdjustedNorm(a - np.matmul(q, r)) - self.assertTrue(np.all(precision < 5.0)) + self.assertTrue(np.all(precision < 10.0)) def CheckUnitary(self, x): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. @@ -107,6 +107,9 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): self._test(dtype, batch_dims + (rows, cols), full_matrices) + def testLarge2000x2000(self): + self._test(np.float32, (2000, 2000), full_matrices=True) + if __name__ == "__main__": test.main()