Merge pull request #32226 from refraction-ray:complex_svd
PiperOrigin-RevId: 267420351
This commit is contained in:
commit
eee9afd6db
@ -406,7 +406,7 @@ if __name__ == "__main__":
|
|||||||
_AddTest(SvdGradOpTest, "SvdGrad", name,
|
_AddTest(SvdGradOpTest, "SvdGrad", name,
|
||||||
_GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
|
_GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
|
||||||
# The results are too inaccurate for float32.
|
# The results are too inaccurate for float32.
|
||||||
if dtype == np.float64:
|
if dtype in (np.float64, np.complex128):
|
||||||
_AddTest(
|
_AddTest(
|
||||||
SvdGradGradOpTest, "SvdGradGrad", name,
|
SvdGradGradOpTest, "SvdGradGrad", name,
|
||||||
_GetSvdGradGradOpTest(dtype, shape, compute_uv,
|
_GetSvdGradGradOpTest(dtype, shape, compute_uv,
|
||||||
|
|||||||
@ -352,8 +352,12 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
# Giles' paper (see reference at top of file). A derivation for
|
# Giles' paper (see reference at top of file). A derivation for
|
||||||
# the full_matrices=False case is available at
|
# the full_matrices=False case is available at
|
||||||
# https://j-towns.github.io/papers/svd-derivative.pdf
|
# https://j-towns.github.io/papers/svd-derivative.pdf
|
||||||
|
# The derivation for complex valued SVD can be found in
|
||||||
|
# https://re-ra.xyz/misc/complexsvd.pdf or
|
||||||
|
# https://giggleliu.github.io/2019/04/02/einsumbp.html
|
||||||
a = op.inputs[0]
|
a = op.inputs[0]
|
||||||
a_shape = a.get_shape().with_rank_at_least(2)
|
a_shape = a.get_shape().with_rank_at_least(2)
|
||||||
|
grad_s = math_ops.cast(grad_s, a.dtype)
|
||||||
grad_s_mat = array_ops.matrix_diag(grad_s)
|
grad_s_mat = array_ops.matrix_diag(grad_s)
|
||||||
|
|
||||||
if not op.get_attr("compute_uv"):
|
if not op.get_attr("compute_uv"):
|
||||||
@ -364,11 +368,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
|
|
||||||
full_matrices = op.get_attr("full_matrices")
|
full_matrices = op.get_attr("full_matrices")
|
||||||
|
|
||||||
# TODO(rmlarsen): Make this work with complex types.
|
|
||||||
if a.dtype.is_complex:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"SVD gradient is not implemented for complex types and "
|
|
||||||
"compute_uv=True.")
|
|
||||||
grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
|
grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
|
||||||
grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
|
grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
|
||||||
m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
|
m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
|
||||||
@ -388,6 +387,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
s = op.outputs[0]
|
s = op.outputs[0]
|
||||||
u = op.outputs[1]
|
u = op.outputs[1]
|
||||||
v = op.outputs[2]
|
v = op.outputs[2]
|
||||||
|
s = math_ops.cast(s, a.dtype)
|
||||||
|
|
||||||
use_adjoint = False
|
use_adjoint = False
|
||||||
if m > n:
|
if m > n:
|
||||||
@ -413,17 +413,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
# only defined up to a (k-dimensional) subspace. In practice, this can
|
# only defined up to a (k-dimensional) subspace. In practice, this can
|
||||||
# lead to numerical instability when singular values are close but not
|
# lead to numerical instability when singular values are close but not
|
||||||
# exactly equal.
|
# exactly equal.
|
||||||
# Also, even with distinct singular values, the diagonal of f can have Inf
|
# To avoid NaN in cases with degenerate singular values or zero singular values
|
||||||
# values before setting to zero, which hurt when differentiating through
|
# in calculating f and s_inv_mat, we introduce a Lorentzian broadening.
|
||||||
# this op. To avoid that, we add eye to the matrix before taking
|
|
||||||
# the reciprocal.
|
def _SafeReciprocal(x, epsilon=1E-20):
|
||||||
|
return x * math_ops.reciprocal(x * x + epsilon)
|
||||||
|
|
||||||
s_shape = array_ops.shape(s)
|
s_shape = array_ops.shape(s)
|
||||||
eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=s.dtype)
|
|
||||||
f = array_ops.matrix_set_diag(
|
f = array_ops.matrix_set_diag(
|
||||||
math_ops.reciprocal(
|
_SafeReciprocal(
|
||||||
array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1) +
|
array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
|
||||||
eye), array_ops.zeros_like(s))
|
array_ops.zeros_like(s))
|
||||||
s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
|
s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
|
||||||
|
|
||||||
v1 = v[..., :, :m]
|
v1 = v[..., :, :m]
|
||||||
grad_v1 = grad_v[..., :, :m]
|
grad_v1 = grad_v[..., :, :m]
|
||||||
@ -443,7 +444,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
if m == n:
|
if m == n:
|
||||||
grad_a_before_transpose = term1
|
grad_a_before_transpose = term1
|
||||||
else:
|
else:
|
||||||
gv1t = array_ops.matrix_transpose(grad_v1)
|
gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
|
||||||
gv1t_v1 = math_ops.matmul(gv1t, v1)
|
gv1t_v1 = math_ops.matmul(gv1t, v1)
|
||||||
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
|
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
|
||||||
|
|
||||||
@ -459,8 +460,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
|||||||
|
|
||||||
grad_a_before_transpose = term1 + term2
|
grad_a_before_transpose = term1 + term2
|
||||||
|
|
||||||
|
if a.dtype.is_complex:
|
||||||
|
eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
|
||||||
|
l = eye * v_gv
|
||||||
|
term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
|
||||||
|
term3 = 1 / 2. * math_ops.matmul(
|
||||||
|
u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
|
||||||
|
|
||||||
|
grad_a_before_transpose += term3
|
||||||
|
|
||||||
if use_adjoint:
|
if use_adjoint:
|
||||||
grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
|
grad_a = array_ops.matrix_transpose(
|
||||||
|
grad_a_before_transpose, conjugate=True)
|
||||||
else:
|
else:
|
||||||
grad_a = grad_a_before_transpose
|
grad_a = grad_a_before_transpose
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user