Add support for complex SVD. (#3976)

* Added support for complex SVD.

* Added complex support for batch_svd.

* Added some complex SVD tests.

* Made things look nice.
Rasmus Munk Larsen 2016-08-30 17:35:50 -07:00 committed by GitHub
commit 06737cbf79
5 changed files with 70 additions and 18 deletions


@@ -232,5 +232,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
 // Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
 template class LinearAlgebraOp<float>;
 template class LinearAlgebraOp<double>;
+template class LinearAlgebraOp<complex64>;
+template class LinearAlgebraOp<complex128>;
 } // namespace tensorflow


@@ -166,6 +166,8 @@ class LinearAlgebraOp : public OpKernel {
 // linalg_ops_common.cc for float and double.
 extern template class LinearAlgebraOp<float>;
 extern template class LinearAlgebraOp<double>;
+extern template class LinearAlgebraOp<complex64>;
+extern template class LinearAlgebraOp<complex128>;
 } // namespace tensorflow


@@ -99,7 +99,11 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
 REGISTER_LINALG_OP("Svd", (SvdOp<float>), float);
 REGISTER_LINALG_OP("Svd", (SvdOp<double>), double);
+REGISTER_LINALG_OP("Svd", (SvdOp<complex64>), complex64);
+REGISTER_LINALG_OP("Svd", (SvdOp<complex128>), complex128);
 REGISTER_LINALG_OP("BatchSvd", (SvdOp<float>), float);
 REGISTER_LINALG_OP("BatchSvd", (SvdOp<double>), double);
+REGISTER_LINALG_OP("BatchSvd", (SvdOp<complex64>), complex64);
+REGISTER_LINALG_OP("BatchSvd", (SvdOp<complex128>), complex128);
 } // namespace tensorflow


@@ -732,7 +732,7 @@ REGISTER_OP("Svd")
     .Output("v: T")
     .Attr("compute_uv: bool = True")
     .Attr("full_matrices: bool = False")
-    .Attr("T: {double, float}")
+    .Attr("T: {double, float, complex64, complex128}")
     .SetShapeFn(SvdShapeFn)
     .Doc(R"doc(
 Computes the singular value decomposition of a matrix.
@@ -771,7 +771,7 @@ REGISTER_OP("BatchSvd")
     .Output("v: T")
     .Attr("compute_uv: bool = True")
     .Attr("full_matrices: bool = False")
-    .Attr("T: {double, float}")
+    .Attr("T: {double, float, complex64, complex128}")
     .SetShapeFn(BatchSvdShapeFn)
     .Doc(R"doc(
 Computes the singular value decompositions of a batch of matrices.
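With complex64 and complex128 added to the T attr, both ops accept complex tensors from Python. A minimal usage sketch, assuming the 2016-era API in which tf.svd returns the singular values first (the matrix a is a made-up example):

    import numpy as np
    import tensorflow as tf

    # Hypothetical complex64 input, purely for illustration.
    a = np.array([[1. + 1.j, 2. - 1.j],
                  [0. + 2.j, 3. + 0.j]]).astype(np.complex64)

    with tf.Session():
      s, u, v = tf.svd(tf.constant(a), compute_uv=True, full_matrices=False)
      # s comes back with the op's dtype T (here complex64) even though
      # singular values are mathematically real; this is why the tests
      # below compare real and imaginary parts separately.
      print(s.eval())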


@@ -34,6 +34,16 @@ class SvdOpTest(tf.test.TestCase):
     tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
     with self.assertRaises(ValueError):
       tf.svd(tensor)
+    scalar = tf.constant(1. + 1.0j)
+    with self.assertRaises(ValueError):
+      tf.svd(scalar)
+    vector = tf.constant([1. + 1.0j, 2. + 2.0j])
+    with self.assertRaises(ValueError):
+      tf.svd(vector)
+    tensor = tf.constant([[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]],
+                          [[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]]])
+    with self.assertRaises(ValueError):
+      tf.svd(tensor)
     # The input to batch_svd should be a tensor of at least rank 2.
     scalar = tf.constant(1.)
@@ -42,19 +52,28 @@ class SvdOpTest(tf.test.TestCase):
     vector = tf.constant([1., 2.])
     with self.assertRaises(ValueError):
       tf.batch_svd(vector)
+    scalar = tf.constant(1. + 1.0j)
+    with self.assertRaises(ValueError):
+      tf.batch_svd(scalar)
+    vector = tf.constant([1. + 1.0j, 2. + 2.0j])
+    with self.assertRaises(ValueError):
+      tf.batch_svd(vector)


 def _GetSvdOpTest(dtype_, shape_):

   def CompareSingularValues(self, x, y):
-    if dtype_ == np.float32:
+    if dtype_ in (np.float32, np.complex64):
       tol = 5e-5
     else:
       tol = 1e-14
-    self.assertAllClose(x, y, atol=(x[0] + y[0]) * tol)
+    self.assertAllClose(np.real(x), np.real(y),
+                        atol=(np.real(x)[0] + np.real(y)[0]) * tol)
+    self.assertAllClose(np.imag(x), np.imag(y),
+                        atol=(np.imag(x)[0] + np.imag(y)[0]) * tol)

   def CompareSingularVectors(self, x, y, rank):
-    if dtype_ == np.float32:
+    if dtype_ in (np.float32, np.complex64):
       atol = 5e-4
     else:
       atol = 1e-14
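Two details in the comparison helpers above: complex64 shares float32's tolerance (and complex128 float64's) because accuracy is governed by the component precision, and values are compared through their real and imaginary parts because s is returned with dtype T. Singular values are nonetheless mathematically real and non-negative, as a quick NumPy reference check shows (a sketch, not part of the change):

    import numpy as np

    rng = np.random.RandomState(1)
    a = (rng.uniform(-1., 1., (4, 3)) +
         1j * rng.uniform(-1., 1., (4, 3))).astype(np.complex64)
    s = np.linalg.svd(a, compute_uv=False)
    # NumPy returns real singular values (float32 here), sorted descending.
    assert s.dtype == np.float32
    assert np.all(s >= 0) and np.all(np.diff(s) <= 0)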
@@ -69,12 +88,19 @@ def _GetSvdOpTest(dtype_, shape_):
       y = y[..., 0:rank]
     # Singular vectors are only unique up to sign (complex phase factor for
     # complex matrices), so we normalize the signs first.
-    signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
-    x *= signs
-    self.assertAllClose(x, y, atol=atol)
+    if dtype_ in (np.float32, np.float64):
+      signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
+      x *= signs
+      self.assertAllClose(x, y, atol=atol)
+    else:
+      phases = np.divide(np.sum(np.divide(y, x), -2, keepdims=True),
+                         np.abs(np.sum(np.divide(y, x), -2, keepdims=True)))
+      x *= phases
+      self.assertAllClose(np.real(x), np.real(y), atol=atol)
+      self.assertAllClose(np.imag(x), np.imag(y), atol=atol)

   def CheckApproximation(self, a, u, s, v, full_matrices):
-    if dtype_ == np.float32:
+    if dtype_ in (np.float32, np.complex64):
       tol = 1e-5
     else:
       tol = 1e-14
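The else branch generalizes the sign fix-up: a singular vector is only determined up to a unit-modulus phase factor, so the test estimates that phase per column from the ratio y/x and rotates x onto y before comparing. The same idea as a standalone NumPy sketch (normalize_phase is a hypothetical helper, not from the change):

    import numpy as np

    def normalize_phase(x, y):
      # Per-column estimate of the unit phase relating x to y.
      ratio = np.sum(np.divide(y, x), -2, keepdims=True)
      return x * (ratio / np.abs(ratio))

    rng = np.random.RandomState(0)
    y = rng.randn(3, 2) + 1j * rng.randn(3, 2)
    x = y * np.exp(1j * np.array([0.3, -1.2]))  # same columns, rotated phases
    np.testing.assert_allclose(normalize_phase(x, y), y, atol=1e-12)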
@@ -82,7 +108,7 @@ def _GetSvdOpTest(dtype_, shape_):
     batch_shape = a.shape[:-2]
     m = a.shape[-2]
     n = a.shape[-1]
-    diag_s = tf.batch_matrix_diag(s)
+    diag_s = tf.cast(tf.batch_matrix_diag(s), dtype=dtype_)
     if full_matrices:
       if m > n:
         zeros = tf.zeros(batch_shape + (m - n, n), dtype=dtype_)
@@ -90,24 +116,42 @@ def _GetSvdOpTest(dtype_, shape_):
       elif n > m:
         zeros = tf.zeros(batch_shape + (m, n - m), dtype=dtype_)
         diag_s = tf.concat(a.ndim - 1, [diag_s, zeros])
-    a_recon = tf.batch_matmul(u, diag_s)
-    a_recon = tf.batch_matmul(a_recon, v, adj_y=True)
-    self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)
+    a_recon = tf.batch_matmul(tf.cast(u, dtype=dtype_),
+                              tf.cast(diag_s, dtype=dtype_))
+    a_recon = tf.batch_matmul(a_recon, tf.cast(v, dtype=dtype_), adj_y=True)
+    self.assertAllClose(np.real(a_recon.eval()),
+                        np.real(a), rtol=tol, atol=tol)
+    self.assertAllClose(np.imag(a_recon.eval()),
+                        np.imag(a), rtol=tol, atol=tol)

   def CheckUnitary(self, x):
     # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
     xx = tf.batch_matmul(x, x, adj_x=True)
     identity = tf.batch_matrix_band_part(tf.ones_like(xx), 0, 0)
-    if dtype_ == np.float32:
+    if dtype_ in (np.float32, np.complex64):
       tol = 1e-5
     else:
       tol = 1e-14
-    self.assertAllClose(identity.eval(), xx.eval(), atol=tol)
+    self.assertAllClose(np.real(identity.eval()),
+                        np.real(xx.eval()), atol=tol)
+    self.assertAllClose(np.imag(identity.eval()),
+                        np.imag(xx.eval()), atol=tol)

   def Test(self):
     np.random.seed(1)
-    x = np.random.uniform(
-        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
+    if dtype_ in (np.float32, np.float64):
+      x = np.random.uniform(low=-1.0, high=1.0,
+                            size=np.prod(shape_)).reshape(shape_).astype(dtype_)
+    elif dtype_ == np.complex64:
+      x = (np.random.uniform(low=-1.0, high=1.0,
+                             size=np.prod(shape_)).reshape(shape_).astype(np.float32)
+           + 1j * np.random.uniform(low=-1.0, high=1.0,
+                                    size=np.prod(shape_)).reshape(shape_).astype(np.float32))
+    else:
+      x = (np.random.uniform(low=-1.0, high=1.0,
+                             size=np.prod(shape_)).reshape(shape_).astype(np.float64)
+           + 1j * np.random.uniform(low=-1.0, high=1.0,
+                                    size=np.prod(shape_)).reshape(shape_).astype(np.float64))
     for compute_uv in False, True:
       for full_matrices in False, True:
         with self.test_session():
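Together, CheckApproximation and CheckUnitary verify the defining properties U diag(s) V^H ~ A (with diag(s) padded to m x n when full_matrices is set) and U^H U ~ I, where adj_x/adj_y now mean conjugate transpose. An equivalent standalone NumPy sketch (check_svd_properties is a hypothetical helper; note np.linalg.svd returns V^H directly, whereas tf.svd returns V, hence adj_y=True above):

    import numpy as np

    def check_svd_properties(a, tol):
      u, s, vh = np.linalg.svd(a, full_matrices=True)
      m, n = a.shape
      sigma = np.zeros((m, n), dtype=a.dtype)
      k = min(m, n)
      sigma[:k, :k] = np.diag(s)  # pad diag(s) out to m x n
      a_recon = np.dot(u, np.dot(sigma, vh))
      np.testing.assert_allclose(a_recon, a, rtol=tol, atol=tol)
      # Unitarity: u^H u should be the identity, up to tol.
      np.testing.assert_allclose(np.dot(u.conj().T, u), np.eye(m), atol=tol)

    rng = np.random.RandomState(1)
    a = (rng.uniform(-1., 1., (5, 3)) +
         1j * rng.uniform(-1., 1., (5, 3))).astype(np.complex64)
    check_svd_properties(a, tol=1e-5)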
@@ -152,7 +196,7 @@ def _GetSvdOpTest(dtype_, shape_):
 if __name__ == '__main__':
-  for dtype in np.float32, np.float64:
+  for dtype in np.float32, np.float64, np.complex64, np.complex128:
     for rows in 1, 2, 5, 10, 32, 100:
       for cols in 1, 2, 5, 10, 32, 100:
         for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
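The __main__ block now sweeps the same shape grid over the two complex dtypes as well. The dtype branches in Test above can be collapsed into one helper; a sketch of equivalent logic (random_test_matrix is hypothetical, not from the change):

    import numpy as np

    def random_test_matrix(dtype, shape, seed=1):
      # Real and imaginary parts are drawn independently from U(-1, 1).
      rng = np.random.RandomState(seed)
      x = rng.uniform(-1.0, 1.0, size=shape)
      if np.issubdtype(dtype, np.complexfloating):
        x = x + 1j * rng.uniform(-1.0, 1.0, size=shape)
      return x.astype(dtype)

    a = random_test_matrix(np.complex128, (3, 5, 4))
    assert a.dtype == np.complex128 and a.shape == (3, 5, 4)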