Merge pull request #33808 from Randl:eig_grad2
PiperOrigin-RevId: 306650050
Change-Id: I49df540bab790bb4e5be83fc4244871c2ac5321a
commit 44547d9fd6
@@ -24,9 +24,11 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.platform import test
 
 
@@ -82,7 +84,7 @@ class EigTest(test.TestCase):
             "self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
     self.assertEqual(matrix.shape, (32, 32))
     matrix_tensor = constant_op.constant(matrix)
-    with self.session(use_gpu=True) as sess:
+    with self.session(use_gpu=True) as _:
       (e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
       self.assertEqual(e.size, 32)
       self.assertAllClose(
@@ -99,9 +101,8 @@ def SortEigenValues(e):
 def SortEigenDecomposition(e, v):
   if v.ndim < 2:
     return e, v
-  else:
-    perm = np.argsort(e.real + e.imag, -1)
-    return np.take(e, perm, -1), np.take(v, perm, -1)
+  perm = np.argsort(e.real + e.imag, -1)
+  return np.take(e, perm, -1), np.take(v, perm, -1)
 
 
 def EquilibrateEigenVectorPhases(x, y):
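For illustration (not part of the commit), here is how the new `SortEigenDecomposition` behaves on a hand-made decomposition; the sort key `e.real + e.imag` is the one used above:

```python
import numpy as np

e = np.array([0.5 + 1.0j, -1.0 + 0.0j, 0.5 - 2.0j])
v = np.eye(3, dtype=np.complex128)      # placeholder eigenvector columns
perm = np.argsort(e.real + e.imag, -1)  # keys: 1.5, -1.0, -1.5 -> perm = [2, 1, 0]
e_sorted = np.take(e, perm, -1)         # [0.5-2j, -1+0j, 0.5+1j]
v_sorted = np.take(v, perm, -1)         # columns permuted the same way
```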
@@ -147,17 +148,23 @@ def _GetEigTest(dtype_, shape_, compute_v_):
     n = shape_[-1]
     batch_shape = shape_[:-2]
     np_dtype = dtype_.as_numpy_dtype
-    # most of matrices are diagonalizable # TODO
-    a = np.random.uniform(
-        low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
-    if dtype_.is_complex:
-      a += 1j * np.random.uniform(
-          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
-    a = np.tile(a, batch_shape + (1, 1))
+
+    def RandomInput():
+      # Most matrices are diagonalizable
+      a = np.random.uniform(
+          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      if dtype_.is_complex:
+        a += 1j * np.random.uniform(
+            low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      a = np.tile(a, batch_shape + (1, 1))
+      return a
+
     if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
       atol = 1e-4
     else:
       atol = 1e-12
+
+    a = RandomInput()
     np_e, np_v = np.linalg.eig(a)
     with self.session(use_gpu=True):
       if compute_v_:
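A quick NumPy sanity check (an illustration, not from the commit) of the "most matrices are diagonalizable" assumption behind `RandomInput`: a random complex matrix almost surely has distinct eigenvalues, hence a full eigenvector basis:

```python
import numpy as np

np.random.seed(1)
n = 4
a = (np.random.uniform(-1.0, 1.0, (n, n)) +
     1j * np.random.uniform(-1.0, 1.0, (n, n)))
e = np.linalg.eigvals(a)
# Smallest pairwise eigenvalue gap; positive with probability 1.
gaps = np.abs(e[:, None] - e[None, :])[~np.eye(n, dtype=bool)]
print(gaps.min() > 0.0)  # True
```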
@@ -182,6 +189,72 @@ def _GetEigTest(dtype_, shape_, compute_v_):
   return Test
 
 
+class EigGradTest(test.TestCase):
+  pass  # Filled in below
+
+
+def _GetEigGradTest(dtype_, shape_, compute_v_):
+
+  def Test(self):
+    np.random.seed(1)
+    n = shape_[-1]
+    batch_shape = shape_[:-2]
+    np_dtype = dtype_.as_numpy_dtype
+
+    def RandomInput():
+      # Most matrices are diagonalizable
+      a = np.random.uniform(
+          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      if dtype_.is_complex:
+        a += 1j * np.random.uniform(
+            low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+      a = np.tile(a, batch_shape + (1, 1))
+      return a
+
+    # Optimal stepsize for central difference is O(epsilon^{1/3}).
+    epsilon = np.finfo(np_dtype).eps
+    delta = 0.1 * epsilon**(1.0 / 3.0)
+    # Tolerance obtained by looking at actual differences using
+    # np.linalg.norm(theoretical - numerical, np.inf) on a -mavx build,
+    # after discarding one random input sample.
+    _ = RandomInput()
+    if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
+      tol = 1e-2
+    else:
+      tol = 1e-7
+    with self.session(use_gpu=True):
+
+      def Compute(x):
+        e, v = linalg_ops.eig(x)
+
+        # We sort eigenvalues by e.real + e.imag to have a consistent
+        # order between runs.
+        b_dims = len(e.shape) - 1
+        idx = sort_ops.argsort(math_ops.real(e) + math_ops.imag(e), axis=-1)
+        e = array_ops.gather(e, idx, batch_dims=b_dims)
+        v = array_ops.gather(v, idx, batch_dims=b_dims)
+
+        # (Complex) eigenvectors are only unique up to an arbitrary phase.
+        # We normalize the vectors such that the first component has phase 0.
+        top_rows = v[..., 0:1, :]
+        angle = -math_ops.angle(top_rows)
+        phase = math_ops.complex(math_ops.cos(angle), math_ops.sin(angle))
+        v *= phase
+        return e, v
+
+      if compute_v_:
+        funcs = [lambda x: Compute(x)[0], lambda x: Compute(x)[1]]
+      else:
+        funcs = [linalg_ops.eigvals]
+
+      for f in funcs:
+        theoretical, numerical = gradient_checker_v2.compute_gradient(
+            f, [RandomInput()], delta=delta)
+        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+
+  return Test
+
+
 if __name__ == "__main__":
   dtypes_to_test = [
       dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
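For concreteness, the `delta` heuristic above works out to the following stepsizes (a standalone sketch; the printed values are approximate):

```python
import numpy as np

for np_dtype in (np.float32, np.float64):
  epsilon = np.finfo(np_dtype).eps
  delta = 0.1 * epsilon**(1.0 / 3.0)
  print(np_dtype.__name__, delta)
# float32: delta ~ 4.9e-04, float64: delta ~ 6.1e-07
```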
@@ -194,5 +267,8 @@ if __name__ == "__main__":
           shape = batch_dims + (size, size)
           name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
           _AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
-          # No gradient yet
+          if dtype not in [dtypes_lib.float32, dtypes_lib.float64]:
+            _AddTest(EigGradTest, "EigGrad", name,
+                     _GetEigGradTest(dtype, shape, compute_v))
   test.main()
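Stepping back to the `Compute` helper in the previous hunk: because complex eigenvectors are only unique up to a phase, the test rotates each eigenvector so its first component is real before comparing gradients. A standalone NumPy rendering of that normalization (an illustration, not from the commit):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.uniform(-1.0, 1.0, (3, 3)) + 1j * rng.uniform(-1.0, 1.0, (3, 3))
_, v = np.linalg.eig(a)
angle = -np.angle(v[0:1, :])  # minus the phase of each column's first entry
v = v * np.exp(1j * angle)    # same rotation as math_ops.complex(cos, sin)
print(np.allclose(v[0, :].imag, 0.0))  # True: first components now have phase 0
```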

@@ -633,6 +633,67 @@ def _MatrixTriangularSolveGrad(op, grad):
   return grad_a, grad_b
 
 
+# To avoid NaN in cases with degenerate eigenvalues or
+# degenerate/zero singular values in calculations of
+# f and s_inv_mat, we introduce a Lorentz broadening.
+def _SafeReciprocal(x, epsilon=1E-20):
+  return x * math_ops.reciprocal(x * x + epsilon)
+
+
+@ops.RegisterGradient("Eig")
+def _EigGrad(op, grad_e, grad_v):
+  """Gradient for Eig.
+
+  Based on eq. 4.77 from paper by
+  Christoph Boeddeker et al.
+  https://arxiv.org/abs/1701.00392
+  See also
+  "Computation of eigenvalue and eigenvector derivatives
+  for a general complex-valued eigensystem" by Nico van der Aa.
+  For now, only the case of distinct eigenvalues is considered.
+  """
+  e = op.outputs[0]
+  compute_v = op.get_attr("compute_v")
+  # a = op.inputs[0], which satisfies
+  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
+  with ops.control_dependencies([grad_e, grad_v]):
+    if compute_v:
+      v = op.outputs[1]
+      vt = _linalg.adjoint(v)
+      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
+      # Notice that because of the term involving f, the gradient becomes
+      # infinite (or NaN in practice) when eigenvalues are not unique.
+      # Mathematically this should not be surprising, since for (k-fold)
+      # degenerate eigenvalues, the corresponding eigenvectors are only defined
+      # up to arbitrary rotation in a (k-dimensional) subspace.
+      f = array_ops.matrix_set_diag(
+          _SafeReciprocal(
+              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
+          array_ops.zeros_like(e))
+      f = math_ops.conj(f)
+      vgv = math_ops.matmul(vt, grad_v)
+      mid = array_ops.matrix_diag(grad_e)
+      diag_grad_part = array_ops.matrix_diag(
+          array_ops.matrix_diag_part(
+              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
+      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
+      # vt is formally invertible as long as the original matrix is
+      # diagonalizable. However, in practice, vt may
+      # be ill-conditioned when the original matrix is close to a
+      # non-diagonalizable one.
+      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
+    else:
+      _, v = linalg_ops.eig(op.inputs[0])
+      vt = _linalg.adjoint(v)
+      # vt is formally invertible as long as the original matrix is
+      # diagonalizable. However, in practice, vt may
+      # be ill-conditioned when the original matrix is close to a
+      # non-diagonalizable one.
+      grad_a = linalg_ops.matrix_solve(
+          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
+    return math_ops.cast(grad_a, op.inputs[0].dtype)
+
+
 @ops.RegisterGradient("SelfAdjointEigV2")
 def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
   """Gradient for SelfAdjointEigV2."""
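For reference, my reading of the `compute_v` branch above in matrix form (a transcription for readability, not text from the commit; bars denote incoming gradients, `^H` the conjugate transpose, `∘` the elementwise product):

```latex
\bar{A} = V^{-H}\Big(\mathrm{diag}(\bar{e})
        + F \circ \big(V^{H}\bar{V}
        - V^{H}V\,\mathrm{diag}\big(\mathrm{Re}\,\mathrm{diag}(V^{H}\bar{V})\big)\big)\Big)V^{H},
\qquad
F_{ij} = \begin{cases} \overline{1/(e_{j}-e_{i})} & i \neq j \\ 0 & i = j \end{cases}
```

Here the reciprocal in F is Lorentz-broadened by `_SafeReciprocal`, and `matrix_solve(vt, ...)` supplies the `V^{-H}` factor.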
@@ -650,7 +711,7 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
       # degenerate eigenvalues, the corresponding eigenvectors are only defined
       # up to arbitrary rotation in a (k-dimensional) subspace.
       f = array_ops.matrix_set_diag(
-          math_ops.reciprocal(
+          _SafeReciprocal(
               array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
           array_ops.zeros_like(e))
       grad_a = math_ops.matmul(
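The swap from `math_ops.reciprocal` to `_SafeReciprocal` here is what keeps `f` finite for (near-)degenerate eigenvalues. A NumPy sketch of the difference (illustration only; `safe_reciprocal` mirrors the definition above):

```python
import numpy as np

def safe_reciprocal(x, epsilon=1e-20):
  # Lorentz broadening: ~1/x for |x| >> sqrt(epsilon), -> 0 as x -> 0.
  return x / (x * x + epsilon)

gaps = np.array([1e-3, 1e-12, 0.0])  # eigenvalue differences e_i - e_j
with np.errstate(divide="ignore"):
  print(1.0 / gaps)           # [1e+03 1e+12 inf] -> inf/NaN pollutes the gradient
print(safe_reciprocal(gaps))  # [~1e+03 ~1e+08 0.] -> bounded everywhere
```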
@@ -745,11 +806,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # only defined up a (k-dimensional) subspace. In practice, this can
   # lead to numerical instability when singular values are close but not
   # exactly equal.
-  # To avoid nan in cases with degenerate sigular values or zero singular values
-  # in calculating f and s_inv_mat, we introduce a Lorentz brodening.
-
-  def _SafeReciprocal(x, epsilon=1E-20):
-    return x * math_ops.reciprocal(x * x + epsilon)
-
   s_shape = array_ops.shape(s)
   f = array_ops.matrix_set_diag(