Removed double-wrapping of tf.linalg ops

PiperOrigin-RevId: 254746740
This commit is contained in:
Sergei Lebedev 2019-06-24 06:36:10 -07:00 committed by TensorFlower Gardener
parent 0e989df426
commit 0bd27b7fd4
3 changed files with 8 additions and 6 deletions

View File

@ -1956,6 +1956,7 @@ def matrix_diag(diagonal,
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@deprecation.deprecated_endpoints("matrix_diag_part")
@dispatch.add_dispatch_support
def matrix_diag_part(
input, # pylint:disable=redefined-builtin
name="diag_part",

View File

@ -33,18 +33,18 @@ from tensorflow.python.util.tf_export import tf_export
# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = dispatch.add_dispatch_support(linalg_ops.cholesky)
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = dispatch.add_dispatch_support(linalg_ops.matrix_determinant)
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet)
diag = array_ops.matrix_diag
diag_part = dispatch.add_dispatch_support(array_ops.matrix_diag_part)
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = dispatch.add_dispatch_support(linalg_ops.matrix_inverse)
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm)
@ -52,11 +52,11 @@ lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = dispatch.add_dispatch_support(linalg_ops.matrix_solve)
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = dispatch.add_dispatch_support(math_ops.trace)
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve

View File

@ -2428,6 +2428,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("linalg.trace", v1=["linalg.trace", "trace"])
@deprecation.deprecated_endpoints("trace")
@dispatch.add_dispatch_support
def trace(x, name=None):
"""Compute the trace of a tensor `x`.