Add gradient for complex qr.

This commit is contained in:
Denisa Roberts 2020-12-07 17:16:32 -05:00
parent b1815fc713
commit 6623dc607d
2 changed files with 16 additions and 6 deletions

View File

@ -295,11 +295,10 @@ if __name__ == "__main__":
_GetQrOpTest(dtype, shape, full_matrices,
use_static_shape))
# TODO(pfau): Get working with complex types.
# TODO(pfau): Get working with full_matrices when rows > cols
# TODO(pfau): Get working with shapeholders (dynamic shapes)
for full_matrices in False, True:
for dtype in np.float32, np.float64:
for dtype in np.float32, np.float64, np.complex64, np.complex128:
for rows in 1, 2, 5, 10:
for cols in 1, 2, 5, 10:
if rows <= cols or (not full_matrices and rows > cols):

View File

@ -487,9 +487,11 @@ def _CholeskyGrad(op, grad):
@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
"""Gradient for Qr."""
# The methodology is explained in detail in https://arxiv.org/abs/2009.10071
# QR and LQ Decomposition Matrix Backpropagation Algorithms for
# Square, Wide, and Deep, Real and Complex, Matrices and Their Software Implementation
q, r = op.outputs
if q.dtype.is_complex:
raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
r.shape.as_list()[-1] is None):
raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
@ -516,7 +518,17 @@ def _QrGrad(op, dq, dr):
grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
return grad_a + grad_b
ret = grad_a + grad_b
if q.dtype.is_complex:
# need to add a correction to the gradient formula for complex case
m = rdr - _linalg.adjoint(qdq)
eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
ret = ret + _TriangularSolve(
math_ops.matmul(q, _linalg.adjoint(correction)), r)
return ret
num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]
@ -524,7 +536,6 @@ def _QrGrad(op, dq, dr):
return _QrGradSquareAndDeepMatrices(q, r, dq, dr)
# Partition a = [x, y], r = [u, v] and reduce to the square case
# The methodology is explained in detail in https://arxiv.org/abs/2009.10071
a = op.inputs[0]
y = a[..., :, num_rows:]
u = r[..., :, :num_rows]