Format according to code styles and test for complex SVD backprop
This commit is contained in:
		
							parent
							
								
									c9b193d1a9
								
							
						
					
					
						commit
						83a540a73a
					
				@ -406,7 +406,7 @@ if __name__ == "__main__":
 | 
			
		||||
            _AddTest(SvdGradOpTest, "SvdGrad", name,
 | 
			
		||||
                     _GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
 | 
			
		||||
            # The results are too inaccurate for float32.
 | 
			
		||||
            if dtype == np.float64:
 | 
			
		||||
            if dtype in (np.float64, np.complex128):
 | 
			
		||||
              _AddTest(
 | 
			
		||||
                  SvdGradGradOpTest, "SvdGradGrad", name,
 | 
			
		||||
                  _GetSvdGradGradOpTest(dtype, shape, compute_uv,
 | 
			
		||||
 | 
			
		||||
@ -352,7 +352,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
			
		||||
  # Giles' paper (see reference at top of file).  A derivation for
 | 
			
		||||
  # the full_matrices=False case is available at
 | 
			
		||||
  # https://j-towns.github.io/papers/svd-derivative.pdf
 | 
			
		||||
  # The derivation for complex valued SVD can be found in 
 | 
			
		||||
  # The derivation for complex valued SVD can be found in
 | 
			
		||||
  # https://re-ra.xyz/misc/complexsvd.pdf or
 | 
			
		||||
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
 | 
			
		||||
  a = op.inputs[0]
 | 
			
		||||
@ -413,18 +413,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
			
		||||
    # only defined up to a (k-dimensional) subspace. In practice, this can
 | 
			
		||||
    # lead to numerical instability when singular values are close but not
 | 
			
		||||
    # exactly equal.
 | 
			
		||||
    # To avoid nan in cases with degenrate sigular values or zero sigular values 
 | 
			
		||||
    # To avoid nan in cases with degenerate or zero singular values
 | 
			
		||||
    # in calculating f and s_inv_mat, we introduce a Lorentz broadening.
 | 
			
		||||
    
 | 
			
		||||
    def safe_reciprocal(x, epsilon=1E-20):
 | 
			
		||||
        return x * math_ops.reciprocal(x * x + epsilon)
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _SafeReciprocal(x, epsilon=1E-20):
 | 
			
		||||
      return x * math_ops.reciprocal(x * x + epsilon)
 | 
			
		||||
 | 
			
		||||
    s_shape = array_ops.shape(s)
 | 
			
		||||
    f = array_ops.matrix_set_diag(
 | 
			
		||||
        safe_reciprocal(
 | 
			
		||||
        _SafeReciprocal(
 | 
			
		||||
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)
 | 
			
		||||
        ), array_ops.zeros_like(s))
 | 
			
		||||
    s_inv_mat = array_ops.matrix_diag(safe_reciprocal(s))
 | 
			
		||||
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
 | 
			
		||||
 | 
			
		||||
    v1 = v[..., :, :m]
 | 
			
		||||
    grad_v1 = grad_v[..., :, :m]
 | 
			
		||||
@ -459,17 +459,19 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
			
		||||
      term2 = math_ops.matmul(u_s_inv, term2_nous)
 | 
			
		||||
 | 
			
		||||
      grad_a_before_transpose = term1 + term2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    if a.dtype.is_complex:
 | 
			
		||||
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
 | 
			
		||||
      l = eye * v_gv
 | 
			
		||||
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l)-l)
 | 
			
		||||
      term3 = 1/2. * math_ops.matmul(u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
 | 
			
		||||
        
 | 
			
		||||
      term3 = 1/2. * math_ops.matmul(
 | 
			
		||||
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
 | 
			
		||||
 | 
			
		||||
      grad_a_before_transpose += term3
 | 
			
		||||
 | 
			
		||||
    if use_adjoint:
 | 
			
		||||
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
 | 
			
		||||
      grad_a = array_ops.matrix_transpose(
 | 
			
		||||
          grad_a_before_transpose, conjugate=True)
 | 
			
		||||
    else:
 | 
			
		||||
      grad_a = grad_a_before_transpose
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user