[XLA] Add a larger QR test and slightly relax tolerances for unitary test.
PiperOrigin-RevId: 333351927 Change-Id: I60f390599f1784d533d6a21697d8c4dfde8cb781
This commit is contained in:
parent
8c30da064f
commit
d712e534ef
@ -70,9 +70,9 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
xx = math_ops.matmul(x, x, adjoint_a=True)
|
||||
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
|
||||
precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity))
|
||||
self.assertTrue(np.all(precision < 5.0))
|
||||
self.assertTrue(np.all(precision < 6.0))
|
||||
|
||||
def _test(self, dtype, shape, full_matrices):
|
||||
def _random_matrix(self, dtype, shape):
|
||||
np.random.seed(1)
|
||||
|
||||
def rng():
|
||||
@ -82,7 +82,11 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
x_np = rng()
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
x_np += rng() * dtype(1j)
|
||||
return x_np
|
||||
|
||||
def _test(self, x_np, full_matrices, full_rank=True):
|
||||
dtype = x_np.dtype
|
||||
shape = x_np.shape
|
||||
with self.session() as sess:
|
||||
x_tf = array_ops.placeholder(dtype)
|
||||
with self.device_scope():
|
||||
@ -103,24 +107,39 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np_q_reshape[i, :, :], _ = np.linalg.qr(
|
||||
x_reshape[i, :, :], mode="reduced")
|
||||
np_q = np.reshape(np_q_reshape, q_dims)
|
||||
self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
|
||||
if full_rank:
|
||||
# Q is unique up to sign/phase if the matrix is full-rank.
|
||||
self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
|
||||
self.CheckApproximation(x_np, q_tf_val, r_tf_val)
|
||||
self.CheckUnitary(q_tf_val)
|
||||
|
||||
SIZES = [1, 2, 5, 10, 32, 100, 300]
|
||||
SIZES = [1, 2, 5, 10, 32, 100, 300, 603]
|
||||
DTYPES = [np.float32, np.complex64]
|
||||
PARAMS = itertools.product(SIZES, SIZES, DTYPES)
|
||||
|
||||
@parameterized.parameters(*PARAMS)
|
||||
def testQR(self, rows, cols, dtype):
|
||||
# TODO(b/111317468): Test other types.
|
||||
for full_matrices in [True, False]:
|
||||
# Only tests the (3, 2) case for small numbers of rows/columns.
|
||||
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
|
||||
self._test(dtype, batch_dims + (rows, cols), full_matrices)
|
||||
x_np = self._random_matrix(dtype, batch_dims + (rows, cols))
|
||||
self._test(x_np, full_matrices)
|
||||
|
||||
def testLarge2000x2000(self):
|
||||
self._test(np.float32, (2000, 2000), full_matrices=True)
|
||||
x_np = self._random_matrix(np.float32, (2000, 2000))
|
||||
self._test(x_np, full_matrices=True)
|
||||
|
||||
@parameterized.parameters((23, 25), (513, 23))
|
||||
def testZeroColumn(self, rows, cols):
|
||||
x_np = self._random_matrix(np.complex64, (rows, cols))
|
||||
x_np[:, 7] = 0.
|
||||
self._test(x_np, full_matrices=True)
|
||||
|
||||
@parameterized.parameters((4, 4), (514, 20))
|
||||
def testRepeatedColumn(self, rows, cols):
|
||||
x_np = self._random_matrix(np.complex64, (rows, cols))
|
||||
x_np[:, 1] = x_np[:, 2]
|
||||
self._test(x_np, full_matrices=True, full_rank=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user