From 124386cd8312c530097cf930be9cdb3544c1e361 Mon Sep 17 00:00:00 2001
From: refraction-ray
Date: Thu, 5 Sep 2019 12:39:58 +0800
Subject: [PATCH 1/2] add support for reverse-mode AD for complex SVD

---
 tensorflow/python/ops/linalg_grad.py | 41 +++++++++++++++++-----------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 21b09eb267f..ea8174e71b0 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -352,8 +352,12 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # Giles' paper (see reference at top of file). A derivation for
   # the full_matrices=False case is available at
   # https://j-towns.github.io/papers/svd-derivative.pdf
+  # The derivation for complex valued SVD can be found in 
+  # https://re-ra.xyz/misc/complexsvd.pdf or
+  # https://giggleliu.github.io/2019/04/02/einsumbp.html
   a = op.inputs[0]
   a_shape = a.get_shape().with_rank_at_least(2)
+  grad_s = math_ops.cast(grad_s, a.dtype)
   grad_s_mat = array_ops.matrix_diag(grad_s)
 
   if not op.get_attr("compute_uv"):
@@ -364,11 +368,6 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
 
   full_matrices = op.get_attr("full_matrices")
 
-  # TODO(rmlarsen): Make this work with complex types.
-  if a.dtype.is_complex:
-    raise NotImplementedError(
-        "SVD gradient is not implemented for complex types and "
-        "compute_uv=True.")
   grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
   grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
   m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
@@ -388,6 +387,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   s = op.outputs[0]
   u = op.outputs[1]
   v = op.outputs[2]
+  s = math_ops.cast(s, a.dtype)
 
   use_adjoint = False
   if m > n:
@@ -413,17 +413,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # only defined up a (k-dimensional) subspace. In practice, this can
   # lead to numerical instability when singular values are close but not
   # exactly equal.
-  # Also, even with distinct singular values, the diagonal of f can have Inf
-  # values before setting to zero, which hurt when differentiating through
-  # this op. To avoid that, we add eye to the matrix before taking
-  # the reciprocal.
+  # To avoid NaN in cases with degenerate singular values or zero singular values 
+  # in calculating f and s_inv_mat, we introduce a Lorentz broadening.
+
+  def safe_reciprocal(x, epsilon=1E-20):
+    return x * math_ops.reciprocal(x * x + epsilon)
+
   s_shape = array_ops.shape(s)
-  eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=s.dtype)
   f = array_ops.matrix_set_diag(
-      math_ops.reciprocal(
-          array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1) +
-          eye), array_ops.zeros_like(s))
-  s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
+      safe_reciprocal(
+          array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)
+      ), array_ops.zeros_like(s))
+  s_inv_mat = array_ops.matrix_diag(safe_reciprocal(s))
 
   v1 = v[..., :, :m]
   grad_v1 = grad_v[..., :, :m]
@@ -443,7 +444,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   if m == n:
     grad_a_before_transpose = term1
   else:
-    gv1t = array_ops.matrix_transpose(grad_v1)
+    gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
     gv1t_v1 = math_ops.matmul(gv1t, v1)
     term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
 
@@ -458,9 +459,17 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
     term2 = math_ops.matmul(u_s_inv, term2_nous)
 
     grad_a_before_transpose = term1 + term2
+
+  if a.dtype.is_complex:
+    eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
+    l = eye * v_gv
+    term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l)-l)
+    term3 = 1/2. * math_ops.matmul(u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
+
+    grad_a_before_transpose += term3
 
   if use_adjoint:
-    grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
+    grad_a = array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
   else:
     grad_a = grad_a_before_transpose
 

From 83a540a73a86588736027ddc589789f47f1806bd Mon Sep 17 00:00:00 2001
From: refraction-ray
Date: Thu, 5 Sep 2019 17:01:21 +0800
Subject: [PATCH 2/2] Format according to code style and add test for complex SVD backprop

---
 tensorflow/python/kernel_tests/svd_op_test.py |  2 +-
 tensorflow/python/ops/linalg_grad.py          | 26 ++++++++++---------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 278ec9d93b6..bbcab12a163 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -406,7 +406,7 @@ if __name__ == "__main__":
           _AddTest(SvdGradOpTest, "SvdGrad", name,
                    _GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
           # The results are too inaccurate for float32.
-          if dtype == np.float64:
+          if dtype in (np.float64, np.complex128):
             _AddTest(
                 SvdGradGradOpTest, "SvdGradGrad", name,
                 _GetSvdGradGradOpTest(dtype, shape, compute_uv,
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index ea8174e71b0..1d67b8ac942 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -352,7 +352,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # Giles' paper (see reference at top of file). A derivation for
   # the full_matrices=False case is available at
   # https://j-towns.github.io/papers/svd-derivative.pdf
-  # The derivation for complex valued SVD can be found in 
+  # The derivation for complex valued SVD can be found in
   # https://re-ra.xyz/misc/complexsvd.pdf or
   # https://giggleliu.github.io/2019/04/02/einsumbp.html
   a = op.inputs[0]
@@ -413,18 +413,18 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
   # only defined up a (k-dimensional) subspace. In practice, this can
   # lead to numerical instability when singular values are close but not
   # exactly equal.
-  # To avoid NaN in cases with degenerate singular values or zero singular values 
+  # To avoid NaN in cases with degenerate singular values or zero singular values
   # in calculating f and s_inv_mat, we introduce a Lorentz broadening.
-
-  def safe_reciprocal(x, epsilon=1E-20):
-    return x * math_ops.reciprocal(x * x + epsilon)
-
+
+  def _SafeReciprocal(x, epsilon=1E-20):
+    return x * math_ops.reciprocal(x * x + epsilon)
+
   s_shape = array_ops.shape(s)
   f = array_ops.matrix_set_diag(
-      safe_reciprocal(
+      _SafeReciprocal(
           array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)
       ), array_ops.zeros_like(s))
-  s_inv_mat = array_ops.matrix_diag(safe_reciprocal(s))
+  s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))
 
   v1 = v[..., :, :m]
   grad_v1 = grad_v[..., :, :m]
@@ -459,17 +459,19 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
     term2 = math_ops.matmul(u_s_inv, term2_nous)
 
     grad_a_before_transpose = term1 + term2
-
+
   if a.dtype.is_complex:
     eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
     l = eye * v_gv
     term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l)-l)
-    term3 = 1/2. * math_ops.matmul(u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
-
+    term3 = 1/2. * math_ops.matmul(
+        u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))
+
     grad_a_before_transpose += term3
 
   if use_adjoint:
-    grad_a = array_ops.matrix_transpose(grad_a_before_transpose, conjugate=True)
+    grad_a = array_ops.matrix_transpose(
+        grad_a_before_transpose, conjugate=True)
   else:
     grad_a = grad_a_before_transpose
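
Editorial note on the numerics: the removed code added `eye` before taking the reciprocal, which protects only the diagonal of f. Off-diagonal entries still became Inf when two singular values coincided, and 1/s still became Inf for zero singular values. The Lorentz broadening x / (x^2 + epsilon) is finite everywhere and agrees with 1/x away from zero. Below is a minimal standalone sketch of the trick, assuming the public TF 2.x eager API (`tf.math.reciprocal`, `tf.linalg.set_diag`) in place of the internal `math_ops`/`array_ops` modules the patch edits; the function and values here are illustrative, not part of the PR.

```python
import tensorflow as tf

def safe_reciprocal(x, epsilon=1e-20):
  # Lorentz broadening: x / (x^2 + eps) ~ 1/x away from zero, but stays
  # finite (and tends to 0) as x -> 0, so no Inf/NaN reaches the gradient.
  return x * tf.math.reciprocal(x * x + epsilon)

# A degenerate pair (1.0, 1.0) and an exact zero: both poison the old formula.
s = tf.constant([2.0, 1.0, 1.0, 0.0], dtype=tf.float64)
s2 = s * s

# f_ij ~ 1 / (s_j^2 - s_i^2) drives the off-diagonal SVD gradient terms.
# The naive reciprocal gives Inf for i = j and for the degenerate pair; the
# broadened one is finite, and the diagonal is masked exactly as in the patch.
diff = tf.expand_dims(s2, -2) - tf.expand_dims(s2, -1)
f = tf.linalg.set_diag(safe_reciprocal(diff), tf.zeros_like(s))
print(f)

# The same trick replaces 1/s in s_inv_mat, so an exactly singular input
# no longer produces Inf in the backward pass.
print(tf.linalg.diag(safe_reciprocal(s)))
```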
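A note on the complex-only correction (term3), transcribed from the added code. Writing $L = I \circ (V_1^{\dagger} \bar V_1)$ for the diagonal part of $V_1^{\dagger}\,\texttt{grad\_v1}$ (in the code, l = eye * v_gv; v_gv is formed earlier in _SvdGrad, outside this diff, as math_ops.matmul(v1, grad_v1, adjoint_a=True)), the patch adds

$$
\Delta A_3 \;=\; \tfrac{1}{2}\, U\, S^{-1}\bigl(L^{\dagger} - L\bigr)\, V_1^{\dagger}.
$$

Because $L$ is diagonal, $L^{\dagger} - L = \overline{L} - L = -2i\,\mathrm{Im}(L)$. It vanishes identically for real dtypes, which is why the code guards it with a.dtype.is_complex; for complex dtypes it carries the contribution of the phase (gauge) freedom $(u_k, v_k) \to (e^{i\theta} u_k, e^{i\theta} v_k)$ of each singular pair, a degree of freedom the real-valued Giles formula never sees. This matches the derivations linked in the new comment (https://re-ra.xyz/misc/complexsvd.pdf and https://giggleliu.github.io/2019/04/02/einsumbp.html).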
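Patch 2 reuses the existing float64 gradient tests for complex128. For a quick end-to-end check outside the TF test harness, here is a hedged sketch (it assumes a TF 2.x eager build with this patch applied; loss_fn and b are illustrative helpers, not part of the PR). It differentiates a gauge-invariant real scalar through tf.linalg.svd, namely Re tr(B† U V†) + Σ s_i, which is insensitive to the phase freedom of singular pairs so its gradient is well defined, and compares against central finite differences using TF's convention that for a real-valued f the returned gradient is ∂f/∂Re(z) + i ∂f/∂Im(z).

```python
import numpy as np
import tensorflow as tf

rng = np.random.RandomState(0)
a0 = (rng.randn(4, 3) + 1j * rng.randn(4, 3)).astype(np.complex128)
b = tf.constant(rng.randn(4, 3) + 1j * rng.randn(4, 3), dtype=tf.complex128)

def loss_fn(a):
  s, u, v = tf.linalg.svd(a, full_matrices=False)
  # Re tr(B^H U V^H) + sum(s): real-valued and gauge invariant, since
  # (u_k, v_k) -> (e^{i t} u_k, e^{i t} v_k) leaves U V^H unchanged.
  uvh = tf.matmul(u, v, adjoint_b=True)
  return (tf.math.real(tf.reduce_sum(tf.math.conj(b) * uvh)) +
          tf.reduce_sum(s))

a = tf.constant(a0)
with tf.GradientTape() as tape:
  tape.watch(a)
  loss = loss_fn(a)
g = tape.gradient(loss, a).numpy()

# Central finite differences over real and imaginary parts separately.
eps = 1e-7
num = np.zeros_like(a0)
for idx in np.ndindex(a0.shape):
  for part in (1.0, 1.0j):
    da = np.zeros_like(a0)
    da[idx] = eps * part
    d = (loss_fn(tf.constant(a0 + da)).numpy() -
         loss_fn(tf.constant(a0 - da)).numpy()) / (2 * eps)
    num[idx] += part * d  # grad = df/dRe(z) + i * df/dIm(z)

print(np.max(np.abs(g - num)))  # expect agreement to roughly 1e-8
```

With these settings the analytic and numerical gradients should agree to roughly 1e-8. Before patch 1, the tape.gradient call above raised NotImplementedError ("SVD gradient is not implemented for complex types and compute_uv=True.").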