Merge pull request #32226 from refraction-ray:complex_svd
PiperOrigin-RevId: 267420351
This commit is contained in:
		
						commit
						eee9afd6db
					
				@ -406,7 +406,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
            _AddTest(SvdGradOpTest, "SvdGrad", name,
 | 
					            _AddTest(SvdGradOpTest, "SvdGrad", name,
 | 
				
			||||||
                     _GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
 | 
					                     _GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
 | 
				
			||||||
            # The results are too inaccurate for float32.
 | 
					            # The results are too inaccurate for float32.
 | 
				
			||||||
            if dtype == np.float64:
 | 
					            if dtype in (np.float64, np.complex128):
 | 
				
			||||||
              _AddTest(
 | 
					              _AddTest(
 | 
				
			||||||
                  SvdGradGradOpTest, "SvdGradGrad", name,
 | 
					                  SvdGradGradOpTest, "SvdGradGrad", name,
 | 
				
			||||||
                  _GetSvdGradGradOpTest(dtype, shape, compute_uv,
 | 
					                  _GetSvdGradGradOpTest(dtype, shape, compute_uv,
 | 
				
			||||||
 | 
				
			|||||||
@ -352,8 +352,12 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
  # Giles' paper (see reference at top of file).  A derivation for
 | 
					  # Giles' paper (see reference at top of file).  A derivation for
 | 
				
			||||||
  # the full_matrices=False case is available at
 | 
					  # the full_matrices=False case is available at
 | 
				
			||||||
  # https://j-towns.github.io/papers/svd-derivative.pdf
 | 
					  # https://j-towns.github.io/papers/svd-derivative.pdf
 | 
				
			||||||
 | 
					  # The derivation for complex valued SVD can be found in
 | 
				
			||||||
 | 
					  # https://re-ra.xyz/misc/complexsvd.pdf or
 | 
				
			||||||
 | 
					  # https://giggleliu.github.io/2019/04/02/einsumbp.html
 | 
				
			||||||
  a = op.inputs[0]
 | 
					  a = op.inputs[0]
 | 
				
			||||||
  a_shape = a.get_shape().with_rank_at_least(2)
 | 
					  a_shape = a.get_shape().with_rank_at_least(2)
 | 
				
			||||||
 | 
					  grad_s = math_ops.cast(grad_s, a.dtype)
 | 
				
			||||||
  grad_s_mat = array_ops.matrix_diag(grad_s)
 | 
					  grad_s_mat = array_ops.matrix_diag(grad_s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if not op.get_attr("compute_uv"):
 | 
					  if not op.get_attr("compute_uv"):
 | 
				
			||||||
@ -364,11 +368,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  full_matrices = op.get_attr("full_matrices")
 | 
					  full_matrices = op.get_attr("full_matrices")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  # TODO(rmlarsen): Make this work with complex types.
 | 
					 | 
				
			||||||
  if a.dtype.is_complex:
 | 
					 | 
				
			||||||
    raise NotImplementedError(
 | 
					 | 
				
			||||||
        "SVD gradient is not implemented for complex types and "
 | 
					 | 
				
			||||||
        "compute_uv=True.")
 | 
					 | 
				
			||||||
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
 | 
					  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
 | 
				
			||||||
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
 | 
					  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
 | 
				
			||||||
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
 | 
					  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
 | 
				
			||||||
@ -388,6 +387,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
  s = op.outputs[0]
 | 
					  s = op.outputs[0]
 | 
				
			||||||
  u = op.outputs[1]
 | 
					  u = op.outputs[1]
 | 
				
			||||||
  v = op.outputs[2]
 | 
					  v = op.outputs[2]
 | 
				
			||||||
 | 
					  s = math_ops.cast(s, a.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  use_adjoint = False
 | 
					  use_adjoint = False
 | 
				
			||||||
  if m > n:
 | 
					  if m > n:
 | 
				
			||||||
@ -413,17 +413,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
    # only defined up to a (k-dimensional) subspace. In practice, this can
 | 
					    # only defined up to a (k-dimensional) subspace. In practice, this can
 | 
				
			||||||
    # lead to numerical instability when singular values are close but not
 | 
					    # lead to numerical instability when singular values are close but not
 | 
				
			||||||
    # exactly equal.
 | 
					    # exactly equal.
 | 
				
			||||||
    # Also, even with distinct singular values, the diagonal of f can have Inf
 | 
					    # To avoid nan in cases with degenerate singular values or zero singular values
 | 
				
			||||||
    # values before setting to zero, which hurt when differentiating through
 | 
					    # in calculating f and s_inv_mat, we introduce a Lorentz broadening.
 | 
				
			||||||
    # this op. To avoid that, we add eye to the matrix before taking
 | 
					
 | 
				
			||||||
    # the reciprocal.
 | 
					    def _SafeReciprocal(x, epsilon=1E-20):
 | 
				
			||||||
 | 
					      return x * math_ops.reciprocal(x * x + epsilon)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    s_shape = array_ops.shape(s)
 | 
					    s_shape = array_ops.shape(s)
 | 
				
			||||||
    eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=s.dtype)
 | 
					 | 
				
			||||||
    f = array_ops.matrix_set_diag(
 | 
					    f = array_ops.matrix_set_diag(
 | 
				
			||||||
        math_ops.reciprocal(
 | 
					        _SafeReciprocal(
 | 
				
			||||||
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1) +
 | 
					            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
 | 
				
			||||||
            eye), array_ops.zeros_like(s))
 | 
					        array_ops.zeros_like(s))
 | 
				
			||||||
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
 | 
					    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    v1 = v[..., :, :m]
 | 
					    v1 = v[..., :, :m]
 | 
				
			||||||
    grad_v1 = grad_v[..., :, :m]
 | 
					    grad_v1 = grad_v[..., :, :m]
 | 
				
			||||||
@ -443,7 +444,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
    if m == n:
 | 
					    if m == n:
 | 
				
			||||||
      grad_a_before_transpose = term1
 | 
					      grad_a_before_transpose = term1
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
      gv1t = array_ops.matrix_transpose(grad_v1)
 | 
					      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
 | 
				
			||||||
      gv1t_v1 = math_ops.matmul(gv1t, v1)
 | 
					      gv1t_v1 = math_ops.matmul(gv1t, v1)
 | 
				
			||||||
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
 | 
					      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -459,8 +460,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
      grad_a_before_transpose = term1 + term2
 | 
					      grad_a_before_transpose = term1 + term2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if a.dtype.is_complex:
 | 
				
			||||||
 | 
					      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
 | 
				
			||||||
 | 
					      l = eye * v_gv
 | 
				
			||||||
 | 
					      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
 | 
				
			||||||
 | 
					      term3 = 1 / 2. * math_ops.matmul(
 | 
				
			||||||
 | 
					          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      grad_a_before_transpose += term3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if use_adjoint:
 | 
					    if use_adjoint:
 | 
				
			||||||
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
 | 
					      grad_a = array_ops.matrix_transpose(
 | 
				
			||||||
 | 
					          grad_a_before_transpose, conjugate=True)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
      grad_a = grad_a_before_transpose
 | 
					      grad_a = grad_a_before_transpose
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user