Don't crash on empty RHS in matrix_triangular_solve on GPU.

PiperOrigin-RevId: 279424548
Change-Id: I59b10866bd99f9e92e8714fc1b1a05674b301ce8
This commit is contained in:
A. Unique TensorFlower 2019-11-08 17:32:50 -08:00 committed by TensorFlower Gardener
parent c82df39b0c
commit 755aec33ef
2 changed files with 4 additions and 3 deletions

View File

@ -175,7 +175,7 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
const ConstMatrixMap& rhs = inputs[1];
MatrixMap& output = outputs->at(0);
if (matrix.rows() == 0 || rhs.cols() == 0) {
if (matrix.rows() == 0 || rhs.rows() == 0 || rhs.cols() == 0) {
// To be consistent with the MatrixInverse op, we define the solution for
// an empty set of equation as the empty matrix.
return;

View File

@ -153,7 +153,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, matrix)
with self.assertRaises(ValueError):
@ -165,7 +165,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# right-hand sides.
matrix = np.array([[1., 0.], [0., 1.]])
rhs = np.array([[1., 0.]])
with self.cached_session():
with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
@ -176,6 +176,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
def testNotInvertible(self):
# The input should be invertible.
# The matrix is singular because it has a zero on the diagonal.
# FIXME(rmlarsen): The GPU kernel does not check for singularity.
singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
with self.cached_session():
with self.assertRaisesOpError("Input matrix is not invertible."):