Support matrix inverse and solves in pfor/vectorized_map.
PiperOrigin-RevId: 294370231 Change-Id: I975c4e4c12a4891106118162e07e7e67a079d8cf
This commit is contained in:
parent
1c53dd246b
commit
91b16c867d
@ -648,6 +648,40 @@ class LinalgTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_matrix_inverse(self):
|
||||
x = (random_ops.random_uniform([3, 4, 2, 2]) +
|
||||
10 * linalg_ops.eye(2)) # Ensure well-conditioned.
|
||||
|
||||
for adjoint in (True, False):
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def loop_fn(i):
|
||||
return linalg_ops.matrix_inverse(array_ops.gather(x, i),
|
||||
adjoint=adjoint)
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_matrix_solve(self):
|
||||
for adjoint in (True, False):
|
||||
for stack_a in (True, False):
|
||||
for stack_b in (True, False):
|
||||
shape_a = (2, 4, 3, 3) if stack_a else (4, 3, 3)
|
||||
shape_b = (2, 4, 3, 5) if stack_b else (4, 3, 5)
|
||||
x = (random_ops.random_uniform(shape_a) +
|
||||
10 * linalg_ops.eye(3)) # Ensure well-conditioned.
|
||||
y = random_ops.random_uniform(shape_b)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def loop_fn(i):
|
||||
a = array_ops.gather(x, i) if stack_a else x
|
||||
b = array_ops.gather(y, i) if stack_b else y
|
||||
return linalg_ops.matrix_solve(a, b, adjoint=adjoint)
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_matrix_triangular_solve(self):
|
||||
for lower in (True, False):
|
||||
for adjoint in (True, False):
|
||||
|
@ -2914,6 +2914,24 @@ def _convert_log_matrix_determinant(pfor_input):
|
||||
return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
|
||||
|
||||
|
||||
@RegisterPFor("MatrixInverse")
|
||||
def _convert_matrix_inverse(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
adjoint = pfor_input.get_attr("adjoint")
|
||||
return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
|
||||
|
||||
|
||||
@RegisterPFor("MatrixSolve")
|
||||
def _convert_matrix_solve(pfor_input):
|
||||
pfor_input.stack_inputs()
|
||||
matrix = pfor_input.stacked_input(0)
|
||||
rhs = pfor_input.stacked_input(1)
|
||||
adjoint = pfor_input.get_attr("adjoint")
|
||||
output = gen_linalg_ops.matrix_solve(
|
||||
matrix, rhs, adjoint=adjoint)
|
||||
return wrap(output, True)
|
||||
|
||||
|
||||
@RegisterPFor("MatrixTriangularSolve")
|
||||
def _convert_matrix_triangular_solve(pfor_input):
|
||||
pfor_input.expanddim_inputs_for_broadcast()
|
||||
|
Loading…
Reference in New Issue
Block a user