Move linear algebra ops 'lu_solve', 'lu_inverse', and 'lu_reconstruct' from TensorFlow Probability to TensorFlow core.
Slightly refactor linear_operator_util.py to avoid a circular dependence with linalg_impl.py. PiperOrigin-RevId: 264692711
This commit is contained in:
parent
4f6a4c080a
commit
02f4686aee
@ -407,5 +407,155 @@ class PinvTestStatic64CustomRcond(test.TestCase, _PinvTest):
|
||||
use_default_rcond = False
|
||||
|
||||
|
||||
def make_tensor_hiding_attributes(value, hide_shape, hide_value=True):
|
||||
if not hide_value:
|
||||
return ops.convert_to_tensor(value)
|
||||
|
||||
shape = None if hide_shape else getattr(value, "shape", None)
|
||||
return array_ops.placeholder_with_default(value, shape=shape)
|
||||
|
||||
|
||||
class _LUReconstruct(object):
|
||||
dtype = np.float32
|
||||
use_static_shape = True
|
||||
|
||||
def test_non_batch(self):
|
||||
x_ = np.array([[3, 4], [1, 2]], dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
|
||||
y = linalg.lu_reconstruct(*linalg.lu(x), validate_args=True)
|
||||
y_ = self.evaluate(y)
|
||||
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(x_.shape, y.shape)
|
||||
self.assertAllClose(x_, y_, atol=0., rtol=1e-3)
|
||||
|
||||
def test_batch(self):
|
||||
x_ = np.array([
|
||||
[[3, 4], [1, 2]],
|
||||
[[7, 8], [3, 4]],
|
||||
], dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
|
||||
y = linalg.lu_reconstruct(*linalg.lu(x), validate_args=True)
|
||||
y_ = self.evaluate(y)
|
||||
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(x_.shape, y.shape)
|
||||
self.assertAllClose(x_, y_, atol=0., rtol=1e-3)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUReconstructStatic(test.TestCase, _LUReconstruct):
|
||||
use_static_shape = True
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUReconstructDynamic(test.TestCase, _LUReconstruct):
|
||||
use_static_shape = False
|
||||
|
||||
|
||||
class _LUMatrixInverse(object):
|
||||
dtype = np.float32
|
||||
use_static_shape = True
|
||||
|
||||
def test_non_batch(self):
|
||||
x_ = np.array([[1, 2], [3, 4]], dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
|
||||
y = linalg.lu_matrix_inverse(*linalg.lu(x), validate_args=True)
|
||||
y_ = self.evaluate(y)
|
||||
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(x_.shape, y.shape)
|
||||
self.assertAllClose(np.linalg.inv(x_), y_, atol=0., rtol=1e-3)
|
||||
|
||||
def test_batch(self):
|
||||
x_ = np.array([
|
||||
[[1, 2], [3, 4]],
|
||||
[[7, 8], [3, 4]],
|
||||
[[0.25, 0.5], [0.75, -2.]],
|
||||
],
|
||||
dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
|
||||
y = linalg.lu_matrix_inverse(*linalg.lu(x), validate_args=True)
|
||||
y_ = self.evaluate(y)
|
||||
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(x_.shape, y.shape)
|
||||
self.assertAllClose(np.linalg.inv(x_), y_, atol=0., rtol=1e-3)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUMatrixInverseStatic(test.TestCase, _LUMatrixInverse):
|
||||
use_static_shape = True
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUMatrixInverseDynamic(test.TestCase, _LUMatrixInverse):
|
||||
use_static_shape = False
|
||||
|
||||
|
||||
class _LUSolve(object):
|
||||
dtype = np.float32
|
||||
use_static_shape = True
|
||||
|
||||
def test_non_batch(self):
|
||||
x_ = np.array([[1, 2], [3, 4]], dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
rhs_ = np.array([[1, 1]], dtype=self.dtype).T
|
||||
rhs = array_ops.placeholder_with_default(
|
||||
rhs_, shape=rhs_.shape if self.use_static_shape else None)
|
||||
|
||||
lower_upper, perm = linalg.lu(x)
|
||||
y = linalg.lu_solve(lower_upper, perm, rhs, validate_args=True)
|
||||
y_, perm_ = self.evaluate([y, perm])
|
||||
|
||||
self.assertAllEqual([1, 0], perm_)
|
||||
expected_ = np.linalg.solve(x_, rhs_)
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(expected_.shape, y.shape)
|
||||
self.assertAllClose(expected_, y_, atol=0., rtol=1e-3)
|
||||
|
||||
def test_batch_broadcast(self):
|
||||
x_ = np.array([
|
||||
[[1, 2], [3, 4]],
|
||||
[[7, 8], [3, 4]],
|
||||
[[0.25, 0.5], [0.75, -2.]],
|
||||
],
|
||||
dtype=self.dtype)
|
||||
x = array_ops.placeholder_with_default(
|
||||
x_, shape=x_.shape if self.use_static_shape else None)
|
||||
rhs_ = np.array([[1, 1]], dtype=self.dtype).T
|
||||
rhs = array_ops.placeholder_with_default(
|
||||
rhs_, shape=rhs_.shape if self.use_static_shape else None)
|
||||
|
||||
lower_upper, perm = linalg.lu(x)
|
||||
y = linalg.lu_solve(lower_upper, perm, rhs, validate_args=True)
|
||||
y_, perm_ = self.evaluate([y, perm])
|
||||
|
||||
self.assertAllEqual([[1, 0], [0, 1], [1, 0]], perm_)
|
||||
expected_ = np.linalg.solve(x_, rhs_[np.newaxis])
|
||||
if self.use_static_shape:
|
||||
self.assertAllEqual(expected_.shape, y.shape)
|
||||
self.assertAllClose(expected_, y_, atol=0., rtol=1e-3)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUSolveStatic(test.TestCase, _LUSolve):
|
||||
use_static_shape = True
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LUSolveDynamic(test.TestCase, _LUSolve):
|
||||
use_static_shape = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -28,6 +28,7 @@ py_library(
|
||||
srcs = ["linalg_impl.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":linear_operator_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
@ -35,3 +36,18 @@ py_library(
|
||||
"//tensorflow/python:special_math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "linear_operator_util",
|
||||
srcs = ["linear_operator_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:linalg_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/module",
|
||||
],
|
||||
)
|
||||
|
@ -29,8 +29,10 @@ from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import special_math_ops
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -821,3 +823,296 @@ def pinv(a, rcond=None, validate_args=False, name=None):
|
||||
a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))
|
||||
|
||||
return a_pinv
|
||||
|
||||
|
||||
@tf_export('linalg.lu_solve')
|
||||
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
|
||||
"""Solves systems of linear eqns `A X = RHS`, given LU factorizations.
|
||||
|
||||
Note: this function does not verify the implied matrix is actually invertible
|
||||
nor is this condition checked even when `validate_args=True`.
|
||||
|
||||
Args:
|
||||
lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
|
||||
matmul(L, U)) = X` then `lower_upper = L + U - eye`.
|
||||
perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) =
|
||||
X` then `perm = argmax(P)`.
|
||||
rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
|
||||
`A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
|
||||
tf.newaxis])[..., 0]`.
|
||||
validate_args: Python `bool` indicating whether arguments should be checked
|
||||
for correctness. Note: this function does not verify the implied matrix is
|
||||
actually invertible, even when `validate_args=True`.
|
||||
Default value: `False` (i.e., don't validate arguments).
|
||||
name: Python `str` name given to ops managed by this object.
|
||||
Default value: `None` (i.e., 'lu_solve').
|
||||
|
||||
Returns:
|
||||
x: The `X` in `A @ X = RHS`.
|
||||
|
||||
#### Examples
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
x = [[[1., 2],
|
||||
[3, 4]],
|
||||
[[7, 8],
|
||||
[3, 4]]]
|
||||
inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
|
||||
tf.assert_near(tf.matrix_inverse(x), inv_x)
|
||||
# ==> True
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
with ops.name_scope(name or 'lu_solve'):
|
||||
lower_upper = ops.convert_to_tensor(
|
||||
lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
|
||||
perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
|
||||
rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')
|
||||
|
||||
assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
|
||||
if assertions:
|
||||
with ops.control_dependencies(assertions):
|
||||
lower_upper = array_ops.identity(lower_upper)
|
||||
perm = array_ops.identity(perm)
|
||||
rhs = array_ops.identity(rhs)
|
||||
|
||||
if (rhs.shape.rank == 2 and perm.shape.rank == 1):
|
||||
# Both rhs and perm have scalar batch_shape.
|
||||
permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
|
||||
else:
|
||||
# Either rhs or perm have non-scalar batch_shape or we can't determine
|
||||
# this information statically.
|
||||
rhs_shape = array_ops.shape(rhs)
|
||||
broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
|
||||
rhs_shape[:-2],
|
||||
array_ops.shape(perm)[:-1])
|
||||
d, m = rhs_shape[-2], rhs_shape[-1]
|
||||
rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
|
||||
axis=0)
|
||||
|
||||
# Tile out rhs.
|
||||
broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
|
||||
broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])
|
||||
|
||||
# Tile out perm and add batch indices.
|
||||
broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
|
||||
broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
|
||||
broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
|
||||
broadcast_batch_indices = array_ops.broadcast_to(
|
||||
math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
|
||||
[broadcast_batch_size, d])
|
||||
broadcast_perm = array_ops.stack(
|
||||
[broadcast_batch_indices, broadcast_perm], axis=-1)
|
||||
|
||||
permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
|
||||
permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)
|
||||
|
||||
lower = set_diag(
|
||||
band_part(lower_upper, num_lower=-1, num_upper=0),
|
||||
array_ops.ones(
|
||||
array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
|
||||
return linear_operator_util.matrix_triangular_solve_with_broadcast(
|
||||
lower_upper, # Only upper is accessed.
|
||||
linear_operator_util.matrix_triangular_solve_with_broadcast(
|
||||
lower, permuted_rhs),
|
||||
lower=False)
|
||||
|
||||
|
||||
@tf_export('linalg.lu_matrix_inverse')
|
||||
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
|
||||
"""Computes the inverse given the LU decomposition(s) of one or more matrices.
|
||||
|
||||
This op is conceptually identical to,
|
||||
|
||||
```python
|
||||
inv_X = tf.lu_matrix_inverse(*tf.linalg.lu(X))
|
||||
tf.assert_near(tf.matrix_inverse(X), inv_X)
|
||||
# ==> True
|
||||
```
|
||||
|
||||
Note: this function does not verify the implied matrix is actually invertible
|
||||
nor is this condition checked even when `validate_args=True`.
|
||||
|
||||
Args:
|
||||
lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
|
||||
matmul(L, U)) = X` then `lower_upper = L + U - eye`.
|
||||
perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) =
|
||||
X` then `perm = argmax(P)`.
|
||||
validate_args: Python `bool` indicating whether arguments should be checked
|
||||
for correctness. Note: this function does not verify the implied matrix is
|
||||
actually invertible, even when `validate_args=True`.
|
||||
Default value: `False` (i.e., don't validate arguments).
|
||||
name: Python `str` name given to ops managed by this object.
|
||||
Default value: `None` (i.e., 'lu_matrix_inverse').
|
||||
|
||||
Returns:
|
||||
inv_x: The matrix_inv, i.e.,
|
||||
`tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`.
|
||||
|
||||
#### Examples
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
x = [[[3., 4], [1, 2]],
|
||||
[[7., 8], [3, 4]]]
|
||||
inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
|
||||
tf.assert_near(tf.matrix_inverse(x), inv_x)
|
||||
# ==> True
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
with ops.name_scope(name or 'lu_matrix_inverse'):
|
||||
lower_upper = ops.convert_to_tensor(
|
||||
lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
|
||||
perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
|
||||
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
|
||||
if assertions:
|
||||
with ops.control_dependencies(assertions):
|
||||
lower_upper = array_ops.identity(lower_upper)
|
||||
perm = array_ops.identity(perm)
|
||||
shape = array_ops.shape(lower_upper)
|
||||
return lu_solve(
|
||||
lower_upper,
|
||||
perm,
|
||||
rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
|
||||
validate_args=False)
|
||||
|
||||
|
||||
@tf_export('linalg.lu_reconstruct')
|
||||
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
|
||||
"""The reconstruct one or more matrices from their LU decomposition(s).
|
||||
|
||||
Args:
|
||||
lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
|
||||
matmul(L, U)) = X` then `lower_upper = L + U - eye`.
|
||||
perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) =
|
||||
X` then `perm = argmax(P)`.
|
||||
validate_args: Python `bool` indicating whether arguments should be checked
|
||||
for correctness.
|
||||
Default value: `False` (i.e., don't validate arguments).
|
||||
name: Python `str` name given to ops managed by this object.
|
||||
Default value: `None` (i.e., 'lu_reconstruct').
|
||||
|
||||
Returns:
|
||||
x: The original input to `tf.linalg.lu`, i.e., `x` as in,
|
||||
`lu_reconstruct(*tf.linalg.lu(x))`.
|
||||
|
||||
#### Examples
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
x = [[[3., 4], [1, 2]],
|
||||
[[7., 8], [3, 4]]]
|
||||
x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
|
||||
tf.assert_near(x, x_reconstructed)
|
||||
# ==> True
|
||||
```
|
||||
|
||||
"""
|
||||
with ops.name_scope(name or 'lu_reconstruct'):
|
||||
lower_upper = ops.convert_to_tensor(
|
||||
lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
|
||||
perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
|
||||
|
||||
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
|
||||
if assertions:
|
||||
with ops.control_dependencies(assertions):
|
||||
lower_upper = array_ops.identity(lower_upper)
|
||||
perm = array_ops.identity(perm)
|
||||
|
||||
shape = array_ops.shape(lower_upper)
|
||||
|
||||
lower = set_diag(
|
||||
band_part(lower_upper, num_lower=-1, num_upper=0),
|
||||
array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
|
||||
upper = band_part(lower_upper, num_lower=0, num_upper=-1)
|
||||
x = math_ops.matmul(lower, upper)
|
||||
|
||||
if (lower_upper.shape is None or lower_upper.shape.rank is None or
|
||||
lower_upper.shape.rank != 2):
|
||||
# We either don't know the batch rank or there are >0 batch dims.
|
||||
batch_size = math_ops.reduce_prod(shape[:-2])
|
||||
d = shape[-1]
|
||||
x = array_ops.reshape(x, [batch_size, d, d])
|
||||
perm = array_ops.reshape(perm, [batch_size, d])
|
||||
perm = map_fn.map_fn(array_ops.invert_permutation, perm)
|
||||
batch_indices = array_ops.broadcast_to(
|
||||
math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
|
||||
x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
|
||||
axis=-1))
|
||||
x = array_ops.reshape(x, shape)
|
||||
else:
|
||||
x = array_ops.gather(x, array_ops.invert_permutation(perm))
|
||||
|
||||
x.set_shape(lower_upper.shape)
|
||||
return x
|
||||
|
||||
|
||||
def lu_reconstruct_assertions(lower_upper, perm, validate_args):
|
||||
"""Returns list of assertions related to `lu_reconstruct` assumptions."""
|
||||
assertions = []
|
||||
|
||||
message = 'Input `lower_upper` must have at least 2 dimensions.'
|
||||
if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
|
||||
raise ValueError(message)
|
||||
elif validate_args:
|
||||
assertions.append(
|
||||
check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))
|
||||
|
||||
message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
|
||||
if lower_upper.shape.rank is not None and perm.shape.rank is not None:
|
||||
if lower_upper.shape.rank != perm.shape.rank + 1:
|
||||
raise ValueError(message)
|
||||
elif validate_args:
|
||||
assertions.append(
|
||||
check_ops.assert_rank(
|
||||
lower_upper, rank=array_ops.rank(perm) + 1, message=message))
|
||||
|
||||
message = '`lower_upper` must be square.'
|
||||
if lower_upper.shape[:-2].is_fully_defined():
|
||||
if lower_upper.shape[-2] != lower_upper.shape[-1]:
|
||||
raise ValueError(message)
|
||||
elif validate_args:
|
||||
m, n = array_ops.split(
|
||||
array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
|
||||
assertions.append(check_ops.assert_equal(m, n, message=message))
|
||||
|
||||
return assertions
|
||||
|
||||
|
||||
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
|
||||
"""Returns list of assertions related to `lu_solve` assumptions."""
|
||||
assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
|
||||
|
||||
message = 'Input `rhs` must have at least 2 dimensions.'
|
||||
if rhs.shape.ndims is not None:
|
||||
if rhs.shape.ndims < 2:
|
||||
raise ValueError(message)
|
||||
elif validate_args:
|
||||
assertions.append(
|
||||
check_ops.assert_rank_at_least(rhs, rank=2, message=message))
|
||||
|
||||
message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.'
|
||||
if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
|
||||
if lower_upper.shape[-1] != rhs.shape[-2]:
|
||||
raise ValueError(message)
|
||||
elif validate_args:
|
||||
assertions.append(
|
||||
check_ops.assert_equal(
|
||||
array_ops.shape(lower_upper)[-1],
|
||||
array_ops.shape(rhs)[-2],
|
||||
message=message))
|
||||
|
||||
return assertions
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
||||
|
||||
|
||||
################################################################################
|
||||
@ -489,13 +488,13 @@ def _reshape_for_efficiency(a,
|
||||
# Any transposes/adjoints will happen here explicitly, rather than in calling
|
||||
# code. Why? To avoid having to write separate complex code for each case.
|
||||
if adjoint_a:
|
||||
a = linalg.adjoint(a)
|
||||
a = array_ops.matrix_transpose(a, conjugate=True)
|
||||
elif transpose_a:
|
||||
a = linalg.transpose(a)
|
||||
a = array_ops.matrix_transpose(a, conjugate=False)
|
||||
if adjoint_b:
|
||||
b = linalg.adjoint(b)
|
||||
elif transpose_b:
|
||||
b = linalg.transpose(b)
|
||||
b = array_ops.matrix_transpose(b, conjugate=True)
|
||||
elif transpose_a:
|
||||
b = array_ops.matrix_transpose(b, conjugate=False)
|
||||
still_need_to_transpose = False
|
||||
|
||||
# Recompute shapes, since the transpose/adjoint may have changed them.
|
||||
|
@ -152,6 +152,18 @@ tf_module {
|
||||
name: "lu"
|
||||
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_matrix_inverse"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_reconstruct"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_solve"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'rhs\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "matmul"
|
||||
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
||||
|
@ -152,6 +152,18 @@ tf_module {
|
||||
name: "lu"
|
||||
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_matrix_inverse"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_reconstruct"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lu_solve"
|
||||
argspec: "args=[\'lower_upper\', \'perm\', \'rhs\', \'validate_args\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "matmul"
|
||||
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user