Format according to code styles and test for complex SVD backprop
This commit is contained in:
parent
c9b193d1a9
commit
83a540a73a
@ -406,7 +406,7 @@ if __name__ == "__main__":
|
||||
_AddTest(SvdGradOpTest, "SvdGrad", name,
|
||||
_GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
|
||||
# The results are too inacurate for float32.
|
||||
if dtype == np.float64:
|
||||
if dtype in (np.float64, np.complex128):
|
||||
_AddTest(
|
||||
SvdGradGradOpTest, "SvdGradGrad", name,
|
||||
_GetSvdGradGradOpTest(dtype, shape, compute_uv,
|
||||
|
||||
@ -416,15 +416,15 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
||||
# To avoid nan in cases with degenrate sigular values or zero sigular values
|
||||
# in calculating f and s_inv_mat, we introduce a Lorentz brodening.
|
||||
|
||||
def safe_reciprocal(x, epsilon=1E-20):
|
||||
return x * math_ops.reciprocal(x * x + epsilon)
|
||||
def _SafeReciprocal(x, epsilon=1E-20):
|
||||
return x * math_ops.reciprocal(x * x + epsilon)
|
||||
|
||||
s_shape = array_ops.shape(s)
|
||||
f = array_ops.matrix_set_diag(
|
||||
safe_reciprocal(
|
||||
_SafeReciprocal(
|
||||
array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)
|
||||
), array_ops.zeros_like(s))
|
||||
s_inv_mat = array_ops.matrix_diag(safe_reciprocal(s))
|
||||
s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
|
||||
|
||||
v1 = v[..., :, :m]
|
||||
grad_v1 = grad_v[..., :, :m]
|
||||
@ -464,12 +464,14 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
|
||||
eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
|
||||
l = eye * v_gv
|
||||
term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l)-l)
|
||||
term3 = 1/2. * math_ops.matmul(u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
|
||||
term3 = 1/2. * math_ops.matmul(
|
||||
u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
|
||||
|
||||
grad_a_before_transpose += term3
|
||||
|
||||
if use_adjoint:
|
||||
grad_a = array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
|
||||
grad_a = array_ops.matrix_transpose(
|
||||
grad_a_before_transpose, conjugate=True)
|
||||
else:
|
||||
grad_a = grad_a_before_transpose
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user