Add gradient for complex qr.
This commit is contained in:
parent
b1815fc713
commit
6623dc607d
tensorflow/python
@ -295,11 +295,10 @@ if __name__ == "__main__":
|
|||||||
_GetQrOpTest(dtype, shape, full_matrices,
|
_GetQrOpTest(dtype, shape, full_matrices,
|
||||||
use_static_shape))
|
use_static_shape))
|
||||||
|
|
||||||
# TODO(pfau): Get working with complex types.
|
|
||||||
# TODO(pfau): Get working with full_matrices when rows > cols
|
# TODO(pfau): Get working with full_matrices when rows > cols
|
||||||
# TODO(pfau): Get working with shapeholders (dynamic shapes)
|
# TODO(pfau): Get working with shapeholders (dynamic shapes)
|
||||||
for full_matrices in False, True:
|
for full_matrices in False, True:
|
||||||
for dtype in np.float32, np.float64:
|
for dtype in np.float32, np.float64, np.complex64, np.complex128:
|
||||||
for rows in 1, 2, 5, 10:
|
for rows in 1, 2, 5, 10:
|
||||||
for cols in 1, 2, 5, 10:
|
for cols in 1, 2, 5, 10:
|
||||||
if rows <= cols or (not full_matrices and rows > cols):
|
if rows <= cols or (not full_matrices and rows > cols):
|
||||||
|
@ -487,9 +487,11 @@ def _CholeskyGrad(op, grad):
|
|||||||
@ops.RegisterGradient("Qr")
|
@ops.RegisterGradient("Qr")
|
||||||
def _QrGrad(op, dq, dr):
|
def _QrGrad(op, dq, dr):
|
||||||
"""Gradient for Qr."""
|
"""Gradient for Qr."""
|
||||||
|
|
||||||
|
# The methodology is explained in detail in https://arxiv.org/abs/2009.10071
|
||||||
|
# QR and LQ Decomposition Matrix Backpropagation Algorithms for
|
||||||
|
# Square, Wide, and Deep, Real and Complex, Matrices and Their Software Implementation
|
||||||
q, r = op.outputs
|
q, r = op.outputs
|
||||||
if q.dtype.is_complex:
|
|
||||||
raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
|
|
||||||
if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
|
if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
|
||||||
r.shape.as_list()[-1] is None):
|
r.shape.as_list()[-1] is None):
|
||||||
raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
|
raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
|
||||||
@ -516,7 +518,17 @@ def _QrGrad(op, dq, dr):
|
|||||||
|
|
||||||
grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
|
grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
|
||||||
grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
|
grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
|
||||||
return grad_a + grad_b
|
ret = grad_a + grad_b
|
||||||
|
|
||||||
|
if q.dtype.is_complex:
|
||||||
|
# need to add a correction to the gradient formula for complex case
|
||||||
|
m = rdr - _linalg.adjoint(qdq)
|
||||||
|
eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
|
||||||
|
correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
|
||||||
|
ret = ret + _TriangularSolve(
|
||||||
|
math_ops.matmul(q, _linalg.adjoint(correction)), r)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]
|
num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]
|
||||||
|
|
||||||
@ -524,7 +536,6 @@ def _QrGrad(op, dq, dr):
|
|||||||
return _QrGradSquareAndDeepMatrices(q, r, dq, dr)
|
return _QrGradSquareAndDeepMatrices(q, r, dq, dr)
|
||||||
|
|
||||||
# Partition a = [x, y], r = [u, v] and reduce to the square case
|
# Partition a = [x, y], r = [u, v] and reduce to the square case
|
||||||
# The methodology is explained in detail in https://arxiv.org/abs/2009.10071
|
|
||||||
a = op.inputs[0]
|
a = op.inputs[0]
|
||||||
y = a[..., :, num_rows:]
|
y = a[..., :, num_rows:]
|
||||||
u = r[..., :, :num_rows]
|
u = r[..., :, :num_rows]
|
||||||
|
Loading…
Reference in New Issue
Block a user