Removed double-wrapping of tf.linalg ops

PiperOrigin-RevId: 254746740
This commit is contained in:
Sergei Lebedev 2019-06-24 06:36:10 -07:00 committed by TensorFlower Gardener
parent 0e989df426
commit 0bd27b7fd4
3 changed files with 8 additions and 6 deletions

View File

@ -1956,6 +1956,7 @@ def matrix_diag(diagonal,
@tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"]) @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"])
@deprecation.deprecated_endpoints("matrix_diag_part") @deprecation.deprecated_endpoints("matrix_diag_part")
@dispatch.add_dispatch_support
def matrix_diag_part( def matrix_diag_part(
input, # pylint:disable=redefined-builtin input, # pylint:disable=redefined-builtin
name="diag_part", name="diag_part",

View File

@ -33,18 +33,18 @@ from tensorflow.python.util.tf_export import tf_export
# Linear algebra ops. # Linear algebra ops.
band_part = array_ops.matrix_band_part band_part = array_ops.matrix_band_part
cholesky = dispatch.add_dispatch_support(linalg_ops.cholesky) cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve cholesky_solve = linalg_ops.cholesky_solve
det = dispatch.add_dispatch_support(linalg_ops.matrix_determinant) det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet) tf_export('linalg.slogdet')(slogdet)
diag = array_ops.matrix_diag diag = array_ops.matrix_diag
diag_part = dispatch.add_dispatch_support(array_ops.matrix_diag_part) diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum einsum = special_math_ops.einsum
eye = linalg_ops.eye eye = linalg_ops.eye
inv = dispatch.add_dispatch_support(linalg_ops.matrix_inverse) inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm) tf_export('linalg.logm')(logm)
@ -52,11 +52,11 @@ lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm norm = linalg_ops.norm
qr = linalg_ops.qr qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag set_diag = array_ops.matrix_set_diag
solve = dispatch.add_dispatch_support(linalg_ops.matrix_solve) solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd svd = linalg_ops.svd
tensordot = math_ops.tensordot tensordot = math_ops.tensordot
trace = dispatch.add_dispatch_support(math_ops.trace) trace = math_ops.trace
transpose = array_ops.matrix_transpose transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve triangular_solve = linalg_ops.matrix_triangular_solve

View File

@ -2428,6 +2428,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
@tf_export("linalg.trace", v1=["linalg.trace", "trace"]) @tf_export("linalg.trace", v1=["linalg.trace", "trace"])
@deprecation.deprecated_endpoints("trace") @deprecation.deprecated_endpoints("trace")
@dispatch.add_dispatch_support
def trace(x, name=None): def trace(x, name=None):
"""Compute the trace of a tensor `x`. """Compute the trace of a tensor `x`.