Merge pull request #33808 from Randl:eig_grad2

PiperOrigin-RevId: 306650050
Change-Id: I49df540bab790bb4e5be83fc4244871c2ac5321a
TensorFlower Gardener 2020-04-15 08:43:16 -07:00
commit 44547d9fd6
2 changed files with 149 additions and 17 deletions


@@ -24,9 +24,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.platform import test
@@ -82,7 +84,7 @@ class EigTest(test.TestCase):
"self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
self.assertEqual(matrix.shape, (32, 32))
matrix_tensor = constant_op.constant(matrix)
with self.session(use_gpu=True) as _:
(e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
self.assertEqual(e.size, 32)
self.assertAllClose(
@@ -99,7 +101,6 @@ def SortEigenValues(e):
def SortEigenDecomposition(e, v):
if v.ndim < 2:
return e, v
else:
perm = np.argsort(e.real + e.imag, -1)
return np.take(e, perm, -1), np.take(v, perm, -1)
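The helpers above exist because np.linalg.eig makes no ordering guarantee, so two runs (or two implementations) can return the same eigenpairs in different orders. A minimal numpy sketch of the pattern (illustrative only, not part of this change):

import numpy as np

a = np.array([[0.0, 1.0], [-2.0, -3.0]])
e, v = np.linalg.eig(a)

# Sort eigenpairs by e.real + e.imag, permuting eigenvector columns the
# same way, exactly as SortEigenDecomposition does.
perm = np.argsort(e.real + e.imag, axis=-1)
e, v = np.take(e, perm, axis=-1), np.take(v, perm, axis=-1)

# Sorting preserves the decomposition: a @ v == v @ diag(e).
np.testing.assert_allclose(a @ v, v @ np.diag(e), atol=1e-12)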
@@ -147,17 +148,23 @@ def _GetEigTest(dtype_, shape_, compute_v_):
n = shape_[-1]
batch_shape = shape_[:-2]
np_dtype = dtype_.as_numpy_dtype
def RandomInput():
# Most matrices are diagonalizable
a = np.random.uniform(
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
if dtype_.is_complex:
a += 1j * np.random.uniform(
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
a = np.tile(a, batch_shape + (1, 1))
return a
if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
atol = 1e-4
else:
atol = 1e-12
a = RandomInput()
np_e, np_v = np.linalg.eig(a)
with self.session(use_gpu=True):
if compute_v_:
@@ -182,6 +189,72 @@ def _GetEigTest(dtype_, shape_, compute_v_):
return Test
class EigGradTest(test.TestCase):
pass # Filled in below
def _GetEigGradTest(dtype_, shape_, compute_v_):
def Test(self):
np.random.seed(1)
n = shape_[-1]
batch_shape = shape_[:-2]
np_dtype = dtype_.as_numpy_dtype
def RandomInput():
# Most matrices are diagonalizable
a = np.random.uniform(
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
if dtype_.is_complex:
a += 1j * np.random.uniform(
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
a = np.tile(a, batch_shape + (1, 1))
return a
# Optimal stepsize for central differences is O(epsilon^{1/3}): the
# truncation error is O(delta^2) while the rounding error is
# O(epsilon / delta), and the two balance at delta ~ epsilon^{1/3}.
epsilon = np.finfo(np_dtype).eps
delta = 0.1 * epsilon**(1.0 / 3.0)
# Tolerance obtained by looking at actual differences using
# np.linalg.norm(theoretical - numerical, np.inf) on a -mavx build,
# after discarding one random input sample below.
_ = RandomInput()
if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
tol = 1e-2
else:
tol = 1e-7
with self.session(use_gpu=True):
def Compute(x):
e, v = linalg_ops.eig(x)
# We sort the eigenvalues by e.real + e.imag to get a consistent
# order between runs.
b_dims = len(e.shape) - 1
idx = sort_ops.argsort(math_ops.real(e) + math_ops.imag(e), axis=-1)
e = array_ops.gather(e, idx, batch_dims=b_dims)
v = array_ops.gather(v, idx, batch_dims=b_dims)
# (Complex) eigenvectors are only unique up to an arbitrary phase.
# We normalize the vectors such that the first component has phase 0.
top_rows = v[..., 0:1, :]
angle = -math_ops.angle(top_rows)
phase = math_ops.complex(math_ops.cos(angle), math_ops.sin(angle))
v *= phase
return e, v
if compute_v_:
funcs = [lambda x: Compute(x)[0], lambda x: Compute(x)[1]]
else:
funcs = [linalg_ops.eigvals]
for f in funcs:
theoretical, numerical = gradient_checker_v2.compute_gradient(
f, [RandomInput()], delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
return Test
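For context, compute_gradient perturbs each input entry and compares the resulting central-difference Jacobian against the autodiff one. A minimal, self-contained sketch of the same pattern (illustrative only, using a simpler function than eig; gradient_checker_v2.compute_gradient is also exposed publicly as tf.test.compute_gradient):

import numpy as np
import tensorflow as tf

def f(x):
  # Smooth matrix function with a well-conditioned Jacobian.
  return tf.linalg.matmul(x, x)

x = np.random.uniform(-1.0, 1.0, size=(3, 3)).astype(np.float64)
# Same stepsize rule as above: 0.1 * eps**(1/3) for float64.
delta = 0.1 * np.finfo(np.float64).eps ** (1.0 / 3.0)
theoretical, numerical = tf.test.compute_gradient(f, [x], delta=delta)
np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-7, rtol=1e-7)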
if __name__ == "__main__":
dtypes_to_test = [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
@@ -194,5 +267,8 @@ if __name__ == "__main__":
shape = batch_dims + (size, size)
name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
_AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
# No gradient for real dtypes yet.
if dtype not in [dtypes_lib.float32, dtypes_lib.float64]:
_AddTest(EigGradTest, "EigGrad", name,
_GetEigGradTest(dtype, shape, compute_v))
test.main()


@@ -633,6 +633,67 @@ def _MatrixTriangularSolveGrad(op, grad):
return grad_a, grad_b
# To avoid NaN in cases with degenerate eigenvalues or
# degenerate/zero singular values in the calculation of
# f and s_inv_mat, we introduce Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
return x * math_ops.reciprocal(x * x + epsilon)
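A quick plain-numpy sketch (illustrative; safe_reciprocal here mirrors _SafeReciprocal above) of what the broadening buys: x / (x**2 + eps) tracks 1/x whenever |x| is well above sqrt(eps), but decays to 0 instead of blowing up as x -> 0.

import numpy as np

def safe_reciprocal(x, epsilon=1e-20):
  # Same Lorentz-broadened reciprocal as _SafeReciprocal.
  return x / (x * x + epsilon)

print(safe_reciprocal(1e-3))   # ~1e3: agrees with 1/x, since 1e-3 >> sqrt(eps) = 1e-10
print(safe_reciprocal(1e-12))  # ~1e8: strongly damped relative to 1/x = 1e12
print(safe_reciprocal(0.0))    # 0.0 instead of inf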
@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
"""Gradient for Eig.
Based on eq. 4.77 from the paper by
Christoph Boeddeker et al.,
https://arxiv.org/abs/1701.00392
See also
"Computation of eigenvalue and eigenvector derivatives
for a general complex-valued eigensystem" by Nico van der Aa.
For now, only the case of distinct eigenvalues is considered.
"""
e = op.outputs[0]
compute_v = op.get_attr("compute_v")
# a = op.inputs[0], which satisfies
# a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
with ops.control_dependencies([grad_e, grad_v]):
if compute_v:
v = op.outputs[1]
vt = _linalg.adjoint(v)
# Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
# Notice that because of the term involving f, the gradient becomes
# infinite (or NaN in practice) when eigenvalues are not unique.
# Mathematically this should not be surprising, since for (k-fold)
# degenerate eigenvalues, the corresponding eigenvectors are only defined
# up to arbitrary rotation in a (k-dimensional) subspace.
f = array_ops.matrix_set_diag(
_SafeReciprocal(
array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
array_ops.zeros_like(e))
f = math_ops.conj(f)
vgv = math_ops.matmul(vt, grad_v)
mid = array_ops.matrix_diag(grad_e)
diag_grad_part = array_ops.matrix_diag(
array_ops.matrix_diag_part(
math_ops.cast(math_ops.real(vgv), vgv.dtype)))
mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
# vt is formally invertible as long as the original matrix is
# diagonalizable. However, in practice, vt may be ill-conditioned
# when the original matrix is close to a non-diagonalizable one.
grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
else:
_, v = linalg_ops.eig(op.inputs[0])
vt = _linalg.adjoint(v)
# vt is formally invertible as long as the original matrix is
# diagonalizable. However, in practice, vt may be ill-conditioned
# when the original matrix is close to a non-diagonalizable one.
grad_a = linalg_ops.matrix_solve(
vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
return math_ops.cast(grad_a, op.inputs[0].dtype)
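In matrix notation, the compute_v branch above implements (reconstructed from the code; cf. eq. 4.77 of Boeddeker et al., with \odot the elementwise product and {}^H the conjugate transpose):

\nabla_A = V^{-H} \Big( \operatorname{Diag}(\nabla_e)
    + F \odot \big( V^H \nabla_V
    - V^H V \, \operatorname{Diag}\big(\operatorname{Re} \operatorname{diag}(V^H \nabla_V)\big) \big) \Big) V^H,
\qquad
F_{ij} = \begin{cases} \overline{1 / (e_j - e_i)} & i \neq j, \\ 0 & i = j, \end{cases}

where each reciprocal 1/x is replaced by the broadened x / (x^2 + \varepsilon) from _SafeReciprocal, and the matrix_solve against vt supplies the V^{-H} factor.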
@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
"""Gradient for SelfAdjointEigV2."""
@@ -650,7 +711,7 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
# degenerate eigenvalues, the corresponding eigenvectors are only defined
# up to arbitrary rotation in a (k-dimensional) subspace.
f = array_ops.matrix_set_diag(
_SafeReciprocal(
array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
array_ops.zeros_like(e))
grad_a = math_ops.matmul(
@@ -745,11 +806,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
# only defined up a (k-dimensional) subspace. In practice, this can
# lead to numerical instability when singular values are close but not
# exactly equal.
s_shape = array_ops.shape(s)
f = array_ops.matrix_set_diag(