diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index a9d855a5a2b..7804aa7bf53 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -295,11 +295,10 @@ if __name__ == "__main__":
                        _GetQrOpTest(dtype, shape, full_matrices,
                                     use_static_shape))
 
-  # TODO(pfau): Get working with complex types.
   # TODO(pfau): Get working with full_matrices when rows > cols
   # TODO(pfau): Get working with shapeholders (dynamic shapes)
   for full_matrices in False, True:
-    for dtype in np.float32, np.float64:
+    for dtype in np.float32, np.float64, np.complex64, np.complex128:
       for rows in 1, 2, 5, 10:
         for cols in 1, 2, 5, 10:
           if rows <= cols or (not full_matrices and rows > cols):
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 847d144bde3..d02a67f652a 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -487,9 +487,11 @@ def _CholeskyGrad(op, grad):
 @ops.RegisterGradient("Qr")
 def _QrGrad(op, dq, dr):
   """Gradient for Qr."""
+
+  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
+  # QR and LQ Decomposition Matrix Backpropagation Algorithms for
+  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software Implementation
   q, r = op.outputs
-  if q.dtype.is_complex:
-    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
   if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
       r.shape.as_list()[-1] is None):
     raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
@@ -516,7 +518,17 @@ def _QrGrad(op, dq, dr):
 
     grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
     grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
-    return grad_a + grad_b
+    ret = grad_a + grad_b
+
+    if q.dtype.is_complex:
+      # need to add a correction to the gradient formula for complex case
+      m = rdr - _linalg.adjoint(qdq)
+      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
+      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
+      ret = ret + _TriangularSolve(
+          math_ops.matmul(q, _linalg.adjoint(correction)), r)
+
+    return ret
 
   num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]
 
@@ -524,7 +536,6 @@ def _QrGrad(op, dq, dr):
     return _QrGradSquareAndDeepMatrices(q, r, dq, dr)
 
   # Partition a = [x, y], r = [u, v] and reduce to the square case
-  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
   a = op.inputs[0]
   y = a[..., :, num_rows:]
   u = r[..., :, :num_rows]