diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py
index beaf0f574ca..b1c83959f27 100644
--- a/tensorflow/python/kernel_tests/eig_op_test.py
+++ b/tensorflow/python/kernel_tests/eig_op_test.py
@@ -24,9 +24,11 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.platform import test
 
 
@@ -82,7 +84,7 @@ class EigTest(test.TestCase):
             "self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
     self.assertEqual(matrix.shape, (32, 32))
     matrix_tensor = constant_op.constant(matrix)
-    with self.session(use_gpu=True) as sess:
+    with self.session(use_gpu=True) as _:
       (e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
       self.assertEqual(e.size, 32)
       self.assertAllClose(
@@ -99,9 +101,8 @@ def SortEigenValues(e):
 def SortEigenDecomposition(e, v):
   if v.ndim < 2:
     return e, v
-  else:
-    perm = np.argsort(e.real + e.imag, -1)
-    return np.take(e, perm, -1), np.take(v, perm, -1)
+  perm = np.argsort(e.real + e.imag, -1)
+  return np.take(e, perm, -1), np.take(v, perm, -1)
 
 
 def EquilibrateEigenVectorPhases(x, y):
@@ -147,17 +148,23 @@ def _GetEigTest(dtype_, shape_, compute_v_):
     n = shape_[-1]
     batch_shape = shape_[:-2]
     np_dtype = dtype_.as_numpy_dtype
-    # most of matrices are diagonalizable # TODO
-    a = np.random.uniform(
-        low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
-    if dtype_.is_complex:
-      a += 1j * np.random.uniform(
+
+    def RandomInput():
+      # Most matrices are diagonalizable
+      a = np.random.uniform(
           low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
-    a = np.tile(a, batch_shape + (1, 1))
+      if dtype_.is_complex:
+        a += 1j * np.random.uniform(
+            low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      a = np.tile(a, batch_shape + (1, 1))
+      return a
+
     if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
       atol = 1e-4
     else:
       atol = 1e-12
+
+    a = RandomInput()
     np_e, np_v = np.linalg.eig(a)
     with self.session(use_gpu=True):
       if compute_v_:
@@ -182,6 +189,72 @@ def _GetEigTest(dtype_, shape_, compute_v_):
   return Test
 
 
+class EigGradTest(test.TestCase):
+  pass  # Filled in below
+
+
+def _GetEigGradTest(dtype_, shape_, compute_v_):
+
+  def Test(self):
+    np.random.seed(1)
+    n = shape_[-1]
+    batch_shape = shape_[:-2]
+    np_dtype = dtype_.as_numpy_dtype
+
+    def RandomInput():
+      # Most matrices are diagonalizable
+      a = np.random.uniform(
+          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      if dtype_.is_complex:
+        a += 1j * np.random.uniform(
+            low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      a = np.tile(a, batch_shape + (1, 1))
+      return a
+
+    # Optimal stepsize for central difference is O(epsilon^{1/3}).
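+    # For a central difference quotient, the truncation error is O(delta**2)
+    # while floating point roundoff contributes O(epsilon / delta); balancing
+    # the two terms gives delta ~ epsilon**(1/3), hence the scaling below.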
+    epsilon = np.finfo(np_dtype).eps
+    delta = 0.1 * epsilon**(1.0 / 3.0)
+    # Tolerance obtained by looking at actual differences using
+    # np.linalg.norm(theoretical - numerical, np.inf) on a -mavx build,
+    # after discarding one random input sample.
+    _ = RandomInput()
+    if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
+      tol = 1e-2
+    else:
+      tol = 1e-7
+    with self.session(use_gpu=True):
+
+      def Compute(x):
+        e, v = linalg_ops.eig(x)
+
+        # We sort eigenvalues by e.real + e.imag to get a consistent
+        # order between runs.
+        b_dims = len(e.shape) - 1
+        idx = sort_ops.argsort(math_ops.real(e) + math_ops.imag(e), axis=-1)
+        e = array_ops.gather(e, idx, batch_dims=b_dims)
+        v = array_ops.gather(v, idx, batch_dims=b_dims)
+
+        # (Complex) eigenvectors are only unique up to an arbitrary phase.
+        # We normalize the vectors such that the first component has phase 0.
+        top_rows = v[..., 0:1, :]
+        angle = -math_ops.angle(top_rows)
+        phase = math_ops.complex(math_ops.cos(angle), math_ops.sin(angle))
+        v *= phase
+        return e, v
+
+      if compute_v_:
+        funcs = [lambda x: Compute(x)[0], lambda x: Compute(x)[1]]
+      else:
+        funcs = [linalg_ops.eigvals]
+
+      for f in funcs:
+        theoretical, numerical = gradient_checker_v2.compute_gradient(
+            f, [RandomInput()], delta=delta)
+        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+
+  return Test
+
+
 if __name__ == "__main__":
   dtypes_to_test = [
       dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
@@ -194,5 +267,8 @@ if __name__ == "__main__":
       shape = batch_dims + (size, size)
       name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
       _AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
-      # No gradient yet
+
+      if dtype not in [dtypes_lib.float32, dtypes_lib.float64]:
+        _AddTest(EigGradTest, "EigGrad", name,
+                 _GetEigGradTest(dtype, shape, compute_v))
   test.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index f456581ef60..6ba0401f334 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -633,6 +633,67 @@ def _MatrixTriangularSolveGrad(op, grad):
   return grad_a, grad_b
 
 
+# To avoid nan in cases with degenerate eigenvalues or
+# degenerate/zero singular values in calculations of
+# f and s_inv_mat, we introduce a Lorentz broadening.
+def _SafeReciprocal(x, epsilon=1E-20):
+  return x * math_ops.reciprocal(x * x + epsilon)
+
+
+@ops.RegisterGradient("Eig")
+def _EigGrad(op, grad_e, grad_v):
+  """Gradient for Eig.
+
+  Based on eq. 4.77 from the paper by
+  Christoph Boeddeker et al.
+  https://arxiv.org/abs/1701.00392
+  See also
+  "Computation of eigenvalue and eigenvector derivatives
+  for a general complex-valued eigensystem" by Nico van der Aa.
+  For now, only the case of distinct eigenvalues is considered.
+  """
+  e = op.outputs[0]
+  compute_v = op.get_attr("compute_v")
+  # a = op.inputs[0], which satisfies
+  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
+  with ops.control_dependencies([grad_e, grad_v]):
+    if compute_v:
+      v = op.outputs[1]
+      vt = _linalg.adjoint(v)
+      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
+      # Notice that because of the term involving f, the gradient becomes
+      # infinite (or NaN in practice) when eigenvalues are not unique.
+      # Mathematically this should not be surprising, since for (k-fold)
+      # degenerate eigenvalues, the corresponding eigenvectors are only defined
+      # up to arbitrary rotation in a (k-dimensional) subspace.
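+      # The lines below assemble
+      #   mid = matrix_diag(grad_e) + f * (vt @ grad_v - (vt @ v) @ d)
+      # where d = matrix_diag(real(matrix_diag_part(vt @ grad_v))), "*" is
+      # elementwise and "@" is matmul, and then compute
+      # grad_a = matrix_solve(vt, mid @ vt), i.e. inv(vt) @ mid @ vt,
+      # following eq. 4.77 cited in the docstring.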
+      f = array_ops.matrix_set_diag(
+          _SafeReciprocal(
+              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
+          array_ops.zeros_like(e))
+      f = math_ops.conj(f)
+      vgv = math_ops.matmul(vt, grad_v)
+      mid = array_ops.matrix_diag(grad_e)
+      diag_grad_part = array_ops.matrix_diag(
+          array_ops.matrix_diag_part(
+              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
+      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
+      # vt is formally invertible as long as the original matrix is
+      # diagonalizable. However, in practice, vt may be
+      # ill-conditioned when the original matrix is close to being
+      # non-diagonalizable.
+      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
+    else:
+      _, v = linalg_ops.eig(op.inputs[0])
+      vt = _linalg.adjoint(v)
+      # vt is formally invertible as long as the original matrix is
+      # diagonalizable. However, in practice, vt may be
+      # ill-conditioned when the original matrix is close to being
+      # non-diagonalizable.
+      grad_a = linalg_ops.matrix_solve(
+          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
+  return math_ops.cast(grad_a, op.inputs[0].dtype)
+
+
 @ops.RegisterGradient("SelfAdjointEigV2")
 def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
   """Gradient for SelfAdjointEigV2."""
@@ -650,7 +711,7 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
       # degenerate eigenvalues, the corresponding eigenvectors are only defined
       # up to arbitrary rotation in a (k-dimensional) subspace.
       f = array_ops.matrix_set_diag(
-          math_ops.reciprocal(
+          _SafeReciprocal(
               array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
           array_ops.zeros_like(e))
       grad_a = math_ops.matmul(
@@ -745,11 +806,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # only defined up a (k-dimensional) subspace. In practice, this can
   # lead to numerical instability when singular values are close but not
   # exactly equal.
-  # To avoid nan in cases with degenerate sigular values or zero singular values
-  # in calculating f and s_inv_mat, we introduce a Lorentz brodening.
-
-  def _SafeReciprocal(x, epsilon=1E-20):
-    return x * math_ops.reciprocal(x * x + epsilon)
 
   s_shape = array_ops.shape(s)
   f = array_ops.matrix_set_diag(
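
With the gradient registered, tf.linalg.eig and tf.linalg.eigvals become differentiable for the complex dtypes exercised by EigGradTest above. A minimal sketch of how the new gradient can be driven through tf.GradientTape, assuming eager execution and a TensorFlow build that includes this patch; the matrix values are arbitrary and only chosen so the eigenvalues are distinct, matching the non-degenerate case the gradient assumes:

import tensorflow as tf

# Small complex matrix with distinct eigenvalues.
x = tf.constant([[1.0 + 0.5j, 2.0 - 1.0j],
                 [0.3 + 0.0j, 4.0 + 0.2j]], dtype=tf.complex64)

with tf.GradientTape() as tape:
  tape.watch(x)
  e, v = tf.linalg.eig(x)
  # Any real-valued scalar function of the decomposition can serve as a loss.
  loss = tf.reduce_sum(tf.abs(e)) + tf.reduce_sum(tf.abs(v))

grad = tape.gradient(loss, x)  # complex64 tensor with the same shape as x
print(grad)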