[TF:XLA] Update TF:XLA tests for matrix_triangular_solve to test V1 and V2.
TF:V1 raises an error on non-square coefficient matrices TF:V2 allows non-square coefficient matrices. PiperOrigin-RevId: 316839892 Change-Id: I34c2567ba3579c8f0fd4bc6da57abe14bc6471b2
This commit is contained in:
parent
806e85998a
commit
af92698487
@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
|
|||||||
self._VerifyTriangularSolve(
|
self._VerifyTriangularSolve(
|
||||||
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
|
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
def testNonSquareCoefficientMatrix(self):
|
||||||
def testNonSquareCoefficientMatrixV1(self):
|
|
||||||
rng = np.random.RandomState(0)
|
rng = np.random.RandomState(0)
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types:
|
||||||
a = rng.randn(3, 4).astype(dtype)
|
a = rng.randn(3, 4).astype(dtype)
|
||||||
b = rng.randn(4, 4).astype(dtype)
|
b = rng.randn(4, 4).astype(dtype)
|
||||||
with self.assertRaises(ValueError):
|
with self.test_scope():
|
||||||
linalg_ops.matrix_triangular_solve(a, b)
|
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
|
||||||
with self.assertRaises(ValueError):
|
linalg_ops.matrix_triangular_solve(a, b)
|
||||||
linalg_ops.matrix_triangular_solve(a, b)
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only # Different error types
|
||||||
def testWrongDimensionsV2(self):
|
def testWrongDimensionsV2(self):
|
||||||
randn = np.random.RandomState(0).randn
|
randn = np.random.RandomState(0).randn
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types:
|
||||||
|
@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto lhs_size = lhs_shape.dims();
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
|
||||||
|
errors::InvalidArgument("The coefficient matrix must be square in "
|
||||||
|
"the inner-most two dimensions: ",
|
||||||
|
lhs_shape.DebugString()));
|
||||||
|
|
||||||
xla::XlaOp a = ctx->Input(0);
|
xla::XlaOp a = ctx->Input(0);
|
||||||
xla::XlaOp b = ctx->Input(1);
|
xla::XlaOp b = ctx->Input(1);
|
||||||
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
|
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
|
||||||
|
Loading…
Reference in New Issue
Block a user