From 0bd27b7fd4c0318a43150392d21c10a5e463fcec Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 24 Jun 2019 06:36:10 -0700 Subject: [PATCH] Removed double-wrapping of tf.linalg ops PiperOrigin-RevId: 254746740 --- tensorflow/python/ops/array_ops.py | 1 + tensorflow/python/ops/linalg/linalg_impl.py | 12 ++++++------ tensorflow/python/ops/math_ops.py | 1 + 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ebcf492b2ae..fe884ae08f3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1956,6 +1956,7 @@ def matrix_diag(diagonal, @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"]) @deprecation.deprecated_endpoints("matrix_diag_part") +@dispatch.add_dispatch_support def matrix_diag_part( input, # pylint:disable=redefined-builtin name="diag_part", diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 527a8d365aa..9e29ba934cf 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -33,18 +33,18 @@ from tensorflow.python.util.tf_export import tf_export # Linear algebra ops. band_part = array_ops.matrix_band_part -cholesky = dispatch.add_dispatch_support(linalg_ops.cholesky) +cholesky = linalg_ops.cholesky cholesky_solve = linalg_ops.cholesky_solve -det = dispatch.add_dispatch_support(linalg_ops.matrix_determinant) +det = linalg_ops.matrix_determinant slogdet = gen_linalg_ops.log_matrix_determinant tf_export('linalg.slogdet')(slogdet) diag = array_ops.matrix_diag -diag_part = dispatch.add_dispatch_support(array_ops.matrix_diag_part) +diag_part = array_ops.matrix_diag_part eigh = linalg_ops.self_adjoint_eig eigvalsh = linalg_ops.self_adjoint_eigvals einsum = special_math_ops.einsum eye = linalg_ops.eye -inv = dispatch.add_dispatch_support(linalg_ops.matrix_inverse) +inv = linalg_ops.matrix_inverse logm = gen_linalg_ops.matrix_logarithm lu = gen_linalg_ops.lu tf_export('linalg.logm')(logm) @@ -52,11 +52,11 @@ lstsq = linalg_ops.matrix_solve_ls norm = linalg_ops.norm qr = linalg_ops.qr set_diag = array_ops.matrix_set_diag -solve = dispatch.add_dispatch_support(linalg_ops.matrix_solve) +solve = linalg_ops.matrix_solve sqrtm = linalg_ops.matrix_square_root svd = linalg_ops.svd tensordot = math_ops.tensordot -trace = dispatch.add_dispatch_support(math_ops.trace) +trace = math_ops.trace transpose = array_ops.matrix_transpose triangular_solve = linalg_ops.matrix_triangular_solve diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 57d863a6cae..cbba9b2f7f9 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2428,6 +2428,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): @tf_export("linalg.trace", v1=["linalg.trace", "trace"]) @deprecation.deprecated_endpoints("trace") +@dispatch.add_dispatch_support def trace(x, name=None): """Compute the trace of a tensor `x`.