From af926984871a130eec2816815cfc98a362d4f5b6 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Wed, 17 Jun 2020 00:44:26 -0700 Subject: [PATCH] [TF:XLA] Update TF:XLA tests for matrix_triangular_solve to test V1 and V2. TF:V1 raises an error on non-square coefficient matrices TF:V2 allows non-square coefficient matrices. PiperOrigin-RevId: 316839892 Change-Id: I34c2567ba3579c8f0fd4bc6da57abe14bc6471b2 --- .../tests/matrix_triangular_solve_op_test.py | 12 +++++------- .../tf2xla/kernels/matrix_triangular_solve_op.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index b07b254c600..0202c582ef3 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): self._VerifyTriangularSolve( a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) - @test_util.run_deprecated_v1 - def testNonSquareCoefficientMatrixV1(self): + def testNonSquareCoefficientMatrix(self): rng = np.random.RandomState(0) for dtype in self.float_types: a = rng.randn(3, 4).astype(dtype) b = rng.randn(4, 4).astype(dtype) - with self.assertRaises(ValueError): - linalg_ops.matrix_triangular_solve(a, b) - with self.assertRaises(ValueError): - linalg_ops.matrix_triangular_solve(a, b) + with self.test_scope(): + with self.assertRaises((ValueError, errors.InvalidArgumentError)): + linalg_ops.matrix_triangular_solve(a, b) - @test_util.run_v2_only + @test_util.run_v2_only # Different error types def testWrongDimensionsV2(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 5a719484e05..8d222d947c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel { return; } + auto lhs_size = lhs_shape.dims(); + OP_REQUIRES( + ctx, + lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2), + errors::InvalidArgument("The coefficient matrix must be square in " + "the inner-most two dimensions: ", + lhs_shape.DebugString())); + xla::XlaOp a = ctx->Input(0); xla::XlaOp b = ctx->Input(1); std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);