diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py index c95621537db..acb567569c2 100644 --- a/tensorflow/python/ops/parallel_for/math_test.py +++ b/tensorflow/python/ops/parallel_for/math_test.py @@ -648,6 +648,40 @@ class LinalgTest(PForTestCase): self._test_loop_fn(loop_fn, 3) + def test_matrix_inverse(self): + x = (random_ops.random_uniform([3, 4, 2, 2]) + + 10 * linalg_ops.eye(2)) # Ensure well-conditioned. + + for adjoint in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + return linalg_ops.matrix_inverse(array_ops.gather(x, i), + adjoint=adjoint) + + # pylint: enable=cell-var-from-loop + self._test_loop_fn(loop_fn, 2) + + def test_matrix_solve(self): + for adjoint in (True, False): + for stack_a in (True, False): + for stack_b in (True, False): + shape_a = (2, 4, 3, 3) if stack_a else (4, 3, 3) + shape_b = (2, 4, 3, 5) if stack_b else (4, 3, 5) + x = (random_ops.random_uniform(shape_a) + + 10 * linalg_ops.eye(3)) # Ensure well-conditioned. + y = random_ops.random_uniform(shape_b) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) if stack_a else x + b = array_ops.gather(y, i) if stack_b else y + return linalg_ops.matrix_solve(a, b, adjoint=adjoint) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + def test_matrix_triangular_solve(self): for lower in (True, False): for adjoint in (True, False): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 109669d9345..43632c8e062 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2914,6 +2914,24 @@ def _convert_log_matrix_determinant(pfor_input): return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] +@RegisterPFor("MatrixInverse") +def _convert_matrix_inverse(pfor_input): + t = pfor_input.stacked_input(0) + adjoint = pfor_input.get_attr("adjoint") + return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) + + +@RegisterPFor("MatrixSolve") +def _convert_matrix_solve(pfor_input): + pfor_input.stack_inputs() + matrix = pfor_input.stacked_input(0) + rhs = pfor_input.stacked_input(1) + adjoint = pfor_input.get_attr("adjoint") + output = gen_linalg_ops.matrix_solve( + matrix, rhs, adjoint=adjoint) + return wrap(output, True) + + @RegisterPFor("MatrixTriangularSolve") def _convert_matrix_triangular_solve(pfor_input): pfor_input.expanddim_inputs_for_broadcast()