diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ebcf492b2ae..fe884ae08f3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1956,6 +1956,7 @@ def matrix_diag(diagonal,
 
 @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
 @deprecation.deprecated_endpoints("matrix_diag_part")
+@dispatch.add_dispatch_support
 def matrix_diag_part(
     input,  # pylint:disable=redefined-builtin
     name="diag_part",
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 527a8d365aa..9e29ba934cf 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -33,18 +33,18 @@ from tensorflow.python.util.tf_export import tf_export
 
 # Linear algebra ops.
 band_part = array_ops.matrix_band_part
-cholesky = dispatch.add_dispatch_support(linalg_ops.cholesky)
+cholesky = linalg_ops.cholesky
 cholesky_solve = linalg_ops.cholesky_solve
-det = dispatch.add_dispatch_support(linalg_ops.matrix_determinant)
+det = linalg_ops.matrix_determinant
 slogdet = gen_linalg_ops.log_matrix_determinant
 tf_export('linalg.slogdet')(slogdet)
 diag = array_ops.matrix_diag
-diag_part = dispatch.add_dispatch_support(array_ops.matrix_diag_part)
+diag_part = array_ops.matrix_diag_part
 eigh = linalg_ops.self_adjoint_eig
 eigvalsh = linalg_ops.self_adjoint_eigvals
 einsum = special_math_ops.einsum
 eye = linalg_ops.eye
-inv = dispatch.add_dispatch_support(linalg_ops.matrix_inverse)
+inv = linalg_ops.matrix_inverse
 logm = gen_linalg_ops.matrix_logarithm
 lu = gen_linalg_ops.lu
 tf_export('linalg.logm')(logm)
@@ -52,11 +52,11 @@ lstsq = linalg_ops.matrix_solve_ls
 norm = linalg_ops.norm
 qr = linalg_ops.qr
 set_diag = array_ops.matrix_set_diag
-solve = dispatch.add_dispatch_support(linalg_ops.matrix_solve)
+solve = linalg_ops.matrix_solve
 sqrtm = linalg_ops.matrix_square_root
 svd = linalg_ops.svd
 tensordot = math_ops.tensordot
-trace = dispatch.add_dispatch_support(math_ops.trace)
+trace = math_ops.trace
 transpose = array_ops.matrix_transpose
 triangular_solve = linalg_ops.matrix_triangular_solve
 
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 57d863a6cae..cbba9b2f7f9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2428,6 +2428,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
 
 @tf_export("linalg.trace", v1=["linalg.trace", "trace"])
 @deprecation.deprecated_endpoints("trace")
+@dispatch.add_dispatch_support
 def trace(x, name=None):
   """Compute the trace of a tensor `x`.