[TF:XLA] Update TF:XLA tests for matrix_triangular_solve to test V1 and V2.

TF:V1 raises an error on non-square coefficient matrices
TF:V2 allows non-square coefficient matrices.
PiperOrigin-RevId: 316839892
Change-Id: I34c2567ba3579c8f0fd4bc6da57abe14bc6471b2
This commit is contained in:
Tres Popp 2020-06-17 00:44:26 -07:00 committed by TensorFlower Gardener
parent 806e85998a
commit af92698487
2 changed files with 13 additions and 7 deletions

View File

@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
self._VerifyTriangularSolve(
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
@test_util.run_deprecated_v1
def testNonSquareCoefficientMatrixV1(self):
def testNonSquareCoefficientMatrix(self):
rng = np.random.RandomState(0)
for dtype in self.float_types:
a = rng.randn(3, 4).astype(dtype)
b = rng.randn(4, 4).astype(dtype)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.test_scope():
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
linalg_ops.matrix_triangular_solve(a, b)
@test_util.run_v2_only
@test_util.run_v2_only # Different error types
def testWrongDimensionsV2(self):
randn = np.random.RandomState(0).randn
for dtype in self.float_types:

View File

@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
return;
}
auto lhs_size = lhs_shape.dims();
OP_REQUIRES(
ctx,
lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
errors::InvalidArgument("The coefficient matrix must be square in "
"the inner-most two dimensions: ",
lhs_shape.DebugString()));
xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1);
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);