Support matrix inverse and solves in pfor/vectorized_map.

PiperOrigin-RevId: 294370231
Change-Id: I975c4e4c12a4891106118162e07e7e67a079d8cf
This commit is contained in:
A. Unique TensorFlower 2020-02-10 21:36:48 -08:00 committed by TensorFlower Gardener
parent 1c53dd246b
commit 91b16c867d
2 changed files with 52 additions and 0 deletions

View File

@ -648,6 +648,40 @@ class LinalgTest(PForTestCase):
self._test_loop_fn(loop_fn, 3)
def test_matrix_inverse(self):
x = (random_ops.random_uniform([3, 4, 2, 2]) +
10 * linalg_ops.eye(2)) # Ensure well-conditioned.
for adjoint in (True, False):
# pylint: disable=cell-var-from-loop
def loop_fn(i):
return linalg_ops.matrix_inverse(array_ops.gather(x, i),
adjoint=adjoint)
# pylint: enable=cell-var-from-loop
self._test_loop_fn(loop_fn, 2)
def test_matrix_solve(self):
for adjoint in (True, False):
for stack_a in (True, False):
for stack_b in (True, False):
shape_a = (2, 4, 3, 3) if stack_a else (4, 3, 3)
shape_b = (2, 4, 3, 5) if stack_b else (4, 3, 5)
x = (random_ops.random_uniform(shape_a) +
10 * linalg_ops.eye(3)) # Ensure well-conditioned.
y = random_ops.random_uniform(shape_b)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
a = array_ops.gather(x, i) if stack_a else x
b = array_ops.gather(y, i) if stack_b else y
return linalg_ops.matrix_solve(a, b, adjoint=adjoint)
# pylint: enable=cell-var-from-loop
self._test_loop_fn(loop_fn, 2)
def test_matrix_triangular_solve(self):
for lower in (True, False):
for adjoint in (True, False):

View File

@ -2914,6 +2914,24 @@ def _convert_log_matrix_determinant(pfor_input):
return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]
@RegisterPFor("MatrixInverse")
def _convert_matrix_inverse(pfor_input):
t = pfor_input.stacked_input(0)
adjoint = pfor_input.get_attr("adjoint")
return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True)
@RegisterPFor("MatrixSolve")
def _convert_matrix_solve(pfor_input):
pfor_input.stack_inputs()
matrix = pfor_input.stacked_input(0)
rhs = pfor_input.stacked_input(1)
adjoint = pfor_input.get_attr("adjoint")
output = gen_linalg_ops.matrix_solve(
matrix, rhs, adjoint=adjoint)
return wrap(output, True)
@RegisterPFor("MatrixTriangularSolve")
def _convert_matrix_triangular_solve(pfor_input):
pfor_input.expanddim_inputs_for_broadcast()