[TF:XLA] Update TF:XLA tests for matrix_triangular_solve to test V1 and V2.
TF:V1 raises an error on non-square coefficient matrices TF:V2 allows non-square coefficient matrices. PiperOrigin-RevId: 316839892 Change-Id: I34c2567ba3579c8f0fd4bc6da57abe14bc6471b2
This commit is contained in:
parent
806e85998a
commit
af92698487
@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
|
||||
self._VerifyTriangularSolve(
|
||||
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonSquareCoefficientMatrixV1(self):
|
||||
def testNonSquareCoefficientMatrix(self):
|
||||
rng = np.random.RandomState(0)
|
||||
for dtype in self.float_types:
|
||||
a = rng.randn(3, 4).astype(dtype)
|
||||
b = rng.randn(4, 4).astype(dtype)
|
||||
with self.assertRaises(ValueError):
|
||||
linalg_ops.matrix_triangular_solve(a, b)
|
||||
with self.assertRaises(ValueError):
|
||||
linalg_ops.matrix_triangular_solve(a, b)
|
||||
with self.test_scope():
|
||||
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||
linalg_ops.matrix_triangular_solve(a, b)
|
||||
|
||||
@test_util.run_v2_only
|
||||
@test_util.run_v2_only # Different error types
|
||||
def testWrongDimensionsV2(self):
|
||||
randn = np.random.RandomState(0).randn
|
||||
for dtype in self.float_types:
|
||||
|
@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
auto lhs_size = lhs_shape.dims();
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
|
||||
errors::InvalidArgument("The coefficient matrix must be square in "
|
||||
"the inner-most two dimensions: ",
|
||||
lhs_shape.DebugString()));
|
||||
|
||||
xla::XlaOp a = ctx->Input(0);
|
||||
xla::XlaOp b = ctx->Input(1);
|
||||
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
|
||||
|
Loading…
Reference in New Issue
Block a user