diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index a9d855a5a2b..7804aa7bf53 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -295,11 +295,10 @@ if __name__ == "__main__": _GetQrOpTest(dtype, shape, full_matrices, use_static_shape)) - # TODO(pfau): Get working with complex types. # TODO(pfau): Get working with full_matrices when rows > cols # TODO(pfau): Get working with shapeholders (dynamic shapes) for full_matrices in False, True: - for dtype in np.float32, np.float64: + for dtype in np.float32, np.float64, np.complex64, np.complex128: for rows in 1, 2, 5, 10: for cols in 1, 2, 5, 10: if rows <= cols or (not full_matrices and rows > cols): diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 847d144bde3..d02a67f652a 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -487,9 +487,11 @@ def _CholeskyGrad(op, grad): @ops.RegisterGradient("Qr") def _QrGrad(op, dq, dr): """Gradient for Qr.""" + + # The methodology is explained in detail in https://arxiv.org/abs/2009.10071 + # QR and LQ Decomposition Matrix Backpropagation Algorithms for + # Square, Wide, and Deep, Real and Complex, Matrices and Their Software Implementation q, r = op.outputs - if q.dtype.is_complex: - raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype) if (r.shape.ndims is None or r.shape.as_list()[-2] is None or r.shape.as_list()[-1] is None): raise NotImplementedError("QrGrad not implemented with dynamic shapes.") @@ -516,7 +518,17 @@ def _QrGrad(op, dq, dr): grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r)) grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r) - return grad_a + grad_b + ret = grad_a + grad_b + + if q.dtype.is_complex: + # need to add a correction to the gradient formula for complex case + m = rdr - _linalg.adjoint(qdq) + eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m)) + correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype) + ret = ret + _TriangularSolve( + math_ops.matmul(q, _linalg.adjoint(correction)), r) + + return ret num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1] @@ -524,7 +536,6 @@ def _QrGrad(op, dq, dr): return _QrGradSquareAndDeepMatrices(q, r, dq, dr) # Partition a = [x, y], r = [u, v] and reduce to the square case - # The methodology is explained in detail in https://arxiv.org/abs/2009.10071 a = op.inputs[0] y = a[..., :, num_rows:] u = r[..., :, :num_rows]