Add support for complex SVD. (#3976)
* Added support for complex SVD. * Added complex support for batch_svd. * Added some complex SVD tests. * Made things look nice.
This commit is contained in:
commit
06737cbf79
@ -232,5 +232,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
|
||||
// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
|
||||
template class LinearAlgebraOp<float>;
|
||||
template class LinearAlgebraOp<double>;
|
||||
template class LinearAlgebraOp<complex64>;
|
||||
template class LinearAlgebraOp<complex128>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -166,6 +166,8 @@ class LinearAlgebraOp : public OpKernel {
|
||||
// linalg_ops_common.cc for float and double.
|
||||
extern template class LinearAlgebraOp<float>;
|
||||
extern template class LinearAlgebraOp<double>;
|
||||
extern template class LinearAlgebraOp<complex64>;
|
||||
extern template class LinearAlgebraOp<complex128>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -99,7 +99,11 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
|
||||
|
||||
REGISTER_LINALG_OP("Svd", (SvdOp<float>), float);
|
||||
REGISTER_LINALG_OP("Svd", (SvdOp<double>), double);
|
||||
REGISTER_LINALG_OP("Svd", (SvdOp<complex64>), complex64);
|
||||
REGISTER_LINALG_OP("Svd", (SvdOp<complex128>), complex128);
|
||||
REGISTER_LINALG_OP("BatchSvd", (SvdOp<float>), float);
|
||||
REGISTER_LINALG_OP("BatchSvd", (SvdOp<double>), double);
|
||||
REGISTER_LINALG_OP("BatchSvd", (SvdOp<complex64>), complex64);
|
||||
REGISTER_LINALG_OP("BatchSvd", (SvdOp<complex128>), complex128);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -732,7 +732,7 @@ REGISTER_OP("Svd")
|
||||
.Output("v: T")
|
||||
.Attr("compute_uv: bool = True")
|
||||
.Attr("full_matrices: bool = False")
|
||||
.Attr("T: {double, float}")
|
||||
.Attr("T: {double, float, complex64, complex128}")
|
||||
.SetShapeFn(SvdShapeFn)
|
||||
.Doc(R"doc(
|
||||
Computes the singular value decomposition of a matrix.
|
||||
@ -771,7 +771,7 @@ REGISTER_OP("BatchSvd")
|
||||
.Output("v: T")
|
||||
.Attr("compute_uv: bool = True")
|
||||
.Attr("full_matrices: bool = False")
|
||||
.Attr("T: {double, float}")
|
||||
.Attr("T: {double, float, complex64, complex128}")
|
||||
.SetShapeFn(BatchSvdShapeFn)
|
||||
.Doc(R"doc(
|
||||
Computes the singular value decompositions of a batch of matrices.
|
||||
|
@ -34,6 +34,16 @@ class SvdOpTest(tf.test.TestCase):
|
||||
tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.svd(tensor)
|
||||
scalar = tf.constant(1. + 1.0j)
|
||||
with self.assertRaises(ValueError):
|
||||
tf.svd(scalar)
|
||||
vector = tf.constant([1. + 1.0j, 2. + 2.0j])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.svd(vector)
|
||||
tensor = tf.constant([[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]],
|
||||
[[1. + 1.0j, 2. + 2.0j], [3. + 3.0j, 4. + 4.0j]]])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.svd(tensor)
|
||||
|
||||
# The input to batch_svd should be a tensor of at least rank 2.
|
||||
scalar = tf.constant(1.)
|
||||
@ -42,19 +52,28 @@ class SvdOpTest(tf.test.TestCase):
|
||||
vector = tf.constant([1., 2.])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.batch_svd(vector)
|
||||
scalar = tf.constant(1. + 1.0j)
|
||||
with self.assertRaises(ValueError):
|
||||
tf.batch_svd(scalar)
|
||||
vector = tf.constant([1. + 1.0j, 2. + 2.0j])
|
||||
with self.assertRaises(ValueError):
|
||||
tf.batch_svd(vector)
|
||||
|
||||
|
||||
def _GetSvdOpTest(dtype_, shape_):
|
||||
|
||||
def CompareSingularValues(self, x, y):
|
||||
if dtype_ == np.float32:
|
||||
if dtype_ in (np.float32, np.complex64):
|
||||
tol = 5e-5
|
||||
else:
|
||||
tol = 1e-14
|
||||
self.assertAllClose(x, y, atol=(x[0] + y[0]) * tol)
|
||||
self.assertAllClose(np.real(x), np.real(y),
|
||||
atol=(np.real(x)[0] + np.real(y)[0]) * tol)
|
||||
self.assertAllClose(np.imag(x), np.imag(y),
|
||||
atol=(np.imag(x)[0] + np.imag(y)[0]) * tol)
|
||||
|
||||
def CompareSingularVectors(self, x, y, rank):
|
||||
if dtype_ == np.float32:
|
||||
if dtype_ in (np.float32, np.complex64):
|
||||
atol = 5e-4
|
||||
else:
|
||||
atol = 1e-14
|
||||
@ -69,12 +88,19 @@ def _GetSvdOpTest(dtype_, shape_):
|
||||
y = y[..., 0:rank]
|
||||
# Singular vectors are only unique up to sign (complex phase factor for
|
||||
# complex matrices), so we normalize the signs first.
|
||||
signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
|
||||
x *= signs
|
||||
self.assertAllClose(x, y, atol=atol)
|
||||
if dtype_ in (np.float32, np.float64):
|
||||
signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
|
||||
x *= signs
|
||||
self.assertAllClose(x, y, atol=atol)
|
||||
else:
|
||||
phases = np.divide(np.sum(np.divide(y, x), -2, keepdims=True),
|
||||
np.abs(np.sum(np.divide(y, x), -2, keepdims=True)))
|
||||
x *= phases
|
||||
self.assertAllClose(np.real(x), np.real(y), atol=atol)
|
||||
self.assertAllClose(np.imag(x), np.imag(y), atol=atol)
|
||||
|
||||
def CheckApproximation(self, a, u, s, v, full_matrices):
|
||||
if dtype_ == np.float32:
|
||||
if dtype_ in (np.float32, np.complex64):
|
||||
tol = 1e-5
|
||||
else:
|
||||
tol = 1e-14
|
||||
@ -82,7 +108,7 @@ def _GetSvdOpTest(dtype_, shape_):
|
||||
batch_shape = a.shape[:-2]
|
||||
m = a.shape[-2]
|
||||
n = a.shape[-1]
|
||||
diag_s = tf.batch_matrix_diag(s)
|
||||
diag_s = tf.cast(tf.batch_matrix_diag(s), dtype=dtype_)
|
||||
if full_matrices:
|
||||
if m > n:
|
||||
zeros = tf.zeros(batch_shape + (m - n, n), dtype=dtype_)
|
||||
@ -90,24 +116,42 @@ def _GetSvdOpTest(dtype_, shape_):
|
||||
elif n > m:
|
||||
zeros = tf.zeros(batch_shape + (m, n - m), dtype=dtype_)
|
||||
diag_s = tf.concat(a.ndim - 1, [diag_s, zeros])
|
||||
a_recon = tf.batch_matmul(u, diag_s)
|
||||
a_recon = tf.batch_matmul(a_recon, v, adj_y=True)
|
||||
self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)
|
||||
a_recon = tf.batch_matmul(tf.cast(u, dtype=dtype_),
|
||||
tf.cast(diag_s, dtype=dtype_))
|
||||
a_recon = tf.batch_matmul(a_recon, tf.cast(v, dtype=dtype_), adj_y=True)
|
||||
self.assertAllClose(np.real(a_recon.eval()),
|
||||
np.real(a), rtol=tol, atol=tol)
|
||||
self.assertAllClose(np.imag(a_recon.eval()),
|
||||
np.imag(a), rtol=tol, atol=tol)
|
||||
|
||||
def CheckUnitary(self, x):
|
||||
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
|
||||
xx = tf.batch_matmul(x, x, adj_x=True)
|
||||
identity = tf.batch_matrix_band_part(tf.ones_like(xx), 0, 0)
|
||||
if dtype_ == np.float32:
|
||||
if dtype_ in (np.float32, np.complex64):
|
||||
tol = 1e-5
|
||||
else:
|
||||
tol = 1e-14
|
||||
self.assertAllClose(identity.eval(), xx.eval(), atol=tol)
|
||||
self.assertAllClose(np.real(identity.eval()),
|
||||
np.real(xx.eval()), atol=tol)
|
||||
self.assertAllClose(np.imag(identity.eval()),
|
||||
np.imag(xx.eval()), atol=tol)
|
||||
|
||||
def Test(self):
|
||||
np.random.seed(1)
|
||||
x = np.random.uniform(
|
||||
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
|
||||
if dtype_ in (np.float32, np.float64):
|
||||
x = np.random.uniform(low=-1.0, high=1.0,
|
||||
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
|
||||
elif dtype == np.complex64:
|
||||
x = np.random.uniform(low=-1.0, high=1.0,
|
||||
size=np.prod(shape_)).reshape(shape_).astype(np.float32)
|
||||
+ 1j * np.random.uniform(low=-1.0, high=1.0,
|
||||
size=np.prod(shape_)).reshape(shape_).astype(np.float32)
|
||||
else:
|
||||
x = np.random.uniform(low=-1.0, high=1.0,
|
||||
size=np.prod(shape_)).reshape(shape_).astype(np.float64)
|
||||
+ 1j * np.random.uniform(low=-1.0, high=1.0,
|
||||
size=np.prod(shape_)).reshape(shape_).astype(np.float64)
|
||||
for compute_uv in False, True:
|
||||
for full_matrices in False, True:
|
||||
with self.test_session():
|
||||
@ -152,7 +196,7 @@ def _GetSvdOpTest(dtype_, shape_):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
for dtype in np.float32, np.float64:
|
||||
for dtype in np.float32, np.float64, np.complex64, np.complex128:
|
||||
for rows in 1, 2, 5, 10, 32, 100:
|
||||
for cols in 1, 2, 5, 10, 32, 100:
|
||||
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
|
||||
|
Loading…
Reference in New Issue
Block a user