diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index b0c78469118..b8f04f7d791 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -300,12 +300,14 @@ bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const {
 bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const {
   // b/128001705: SelfAdjointEigV2 and Svd performance issues.
   // b/135640736: MatrixInverse performance issues.
+  // b/111271662: MatrixSolve performance issues.
   // https://github.com/tensorflow/tensorflow/pull/31012:
   //    ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes
   //    create convolutions too large for CuDNN to handle.
   return node.type_string() == "SelfAdjointEigV2" ||
          node.type_string() == "Svd" || node.type_string() == "Qr" ||
          node.type_string() == "MatrixInverse" ||
+         node.type_string() == "MatrixSolve" ||
          node.type_string() == "ResizeNearestNeighbor" ||
          node.type_string() == "ResizeBilinear" ||
          node.type_string() == "ResizeBilinearGrad";
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 66cc89f10eb..dd9cf615e4d 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -311,6 +311,20 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "matrix_solve_op_test",
+    size = "small",
+    timeout = "moderate",
+    srcs = ["matrix_solve_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:random_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 tf_xla_py_test(
     name = "matrix_triangular_solve_op_test",
     size = "small",
diff --git a/tensorflow/compiler/tests/matrix_solve_op_test.py b/tensorflow/compiler/tests/matrix_solve_op_test.py
new file mode 100644
index 00000000000..fb12c6c453e
--- /dev/null
+++ b/tensorflow/compiler/tests/matrix_solve_op_test.py
@@ -0,0 +1,91 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA implementation of tf.linalg.solve."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import googletest
+
+
+class MatrixSolveOpTest(xla_test.XLATestCase, parameterized.TestCase):
+  """Compares the XLA MatrixSolve kernel against np.linalg.solve."""
+
+  def _verifySolve(self, x, y, adjoint):
+    """Checks matrix_solve against NumPy for float32 and float64.
+
+    Args:
+      x: NumPy array holding the (optionally batched) coefficient matrices.
+      y: NumPy array holding the right-hand sides.
+      adjoint: If True, the reference solution is computed against `x` with
+        its last two axes swapped.
+    """
+    for np_type in self.float_types & {np.float32, np.float64}:
+      tol = 1e-4 if np_type == np.float32 else 1e-12
+      a = x.astype(np_type)
+      b = y.astype(np_type)
+      np_ans = np.linalg.solve(np.swapaxes(a, -2, -1) if adjoint else a, b)
+      with self.session() as sess:
+        with self.test_scope():
+          tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
+        out = sess.run(tf_ans)
+        self.assertEqual(tf_ans.shape, out.shape)
+        self.assertEqual(np_ans.shape, out.shape)
+        self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
+
+  @parameterized.named_parameters(
+      ("Scalar", 1, 1, [], [], False),
+      ("Vector", 5, 1, [], [], False),
+      ("MultipleRHS", 5, 4, [], [], False),
+      ("Adjoint", 5, 4, [], [], True),
+      ("BatchedScalar", 1, 4, [2], [2], False),
+      ("BatchedVector", 5, 4, [2], [2], False),
+      ("BatchedRank2", 5, 4, [7, 4], [7, 4], False),
+      ("BatchedAdjoint", 5, 4, [7, 4], [7, 4], True),
+  )
+  def testSolve(self, n, nrhs, batch_dims, rhs_batch_dims, adjoint):
+    """Solves randomly generated [n, n] systems with nrhs right-hand sides."""
+    # Fixed seed so that a failure can be reproduced with the same inputs.
+    rng = np.random.RandomState(0)
+    matrix = rng.normal(-5.0, 5.0, batch_dims + [n, n])
+    rhs = rng.normal(-5.0, 5.0, rhs_batch_dims + [n, nrhs])
+    self._verifySolve(matrix, rhs, adjoint=adjoint)
+
+  @parameterized.named_parameters(
+      ("Simple", False),
+      ("Adjoint", True),
+  )
+  def testConcurrent(self, adjoint):
+    """Runs two identically seeded solves in one session.run call."""
+    with self.session() as sess:
+      lhs1 = random_ops.random_normal([3, 3], seed=42)
+      lhs2 = random_ops.random_normal([3, 3], seed=42)
+      rhs1 = random_ops.random_normal([3, 3], seed=42)
+      rhs2 = random_ops.random_normal([3, 3], seed=42)
+      with self.test_scope():
+        s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint)
+        s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint)
+      self.assertAllEqual(*sess.run([s1, s2]))
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index a3ae36ce0d5..b6e89c6540b 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -57,6 +57,7 @@ tf_kernel_library(
         "matrix_band_part_op.cc",
         "matrix_diag_ops.cc",
         "matrix_inverse_op.cc",
+        "matrix_solve_op.cc",
         "matrix_triangular_solve_op.cc",
         "mirror_pad_op.cc",
         "next_after_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc
new file mode 100644
index 00000000000..8a4e71068b8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace tensorflow {
+namespace {
+
+// XLA kernel for the MatrixSolve op. Input 0 holds one or more square
+// matrices (leading dimensions are batch dimensions) and input 1 holds the
+// matching right-hand sides. The "adjoint" attribute selects whether the
+// system is solved against each matrix or against its adjoint.
+class MatrixSolveOp : public XlaOpKernel {
+ public:
+  explicit MatrixSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape matrix_shape = ctx->InputShape(0);
+    const int64 matrix_ndims = matrix_shape.dims();
+    OP_REQUIRES(ctx, matrix_ndims >= 2,
+                errors::InvalidArgument(
+                    "Input matrix must have rank >= 2, got ", matrix_ndims));
+    OP_REQUIRES(ctx,
+                matrix_shape.dim_size(matrix_ndims - 2) ==
+                    matrix_shape.dim_size(matrix_ndims - 1),
+                errors::InvalidArgument(
+                    "Input matrices must be square, got ",
+                    matrix_shape.dim_size(matrix_ndims - 2),
+                    " != ", matrix_shape.dim_size(matrix_ndims - 1)));
+
+    // Check the right-hand side against the matrix up front so that a
+    // mismatch is reported as InvalidArgument with both shapes, rather than
+    // surfacing later from the XLA ops built below.
+    const TensorShape rhs_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, rhs_shape.dims() == matrix_ndims,
+                errors::InvalidArgument(
+                    "Input matrix and rhs must have the same rank, got ",
+                    matrix_ndims, " != ", rhs_shape.dims()));
+    // Dimensions [0, matrix_ndims - 2) are batch dimensions; dimension
+    // matrix_ndims - 2 is the number of rows, which must also agree.
+    for (int64 dim = 0; dim < matrix_ndims - 1; ++dim) {
+      OP_REQUIRES(ctx, matrix_shape.dim_size(dim) == rhs_shape.dim_size(dim),
+                  errors::InvalidArgument(
+                      "Input matrix and rhs are incompatible, got shapes ",
+                      matrix_shape.DebugString(), " and ",
+                      rhs_shape.DebugString()));
+    }
+
+    xla::XlaOp matrix = ctx->Input(0);
+    xla::XlaOp rhs = ctx->Input(1);
+
+    // TODO(b/111271662): Using LU decomposition instead of QR should be faster.
+    auto qr = xla::QRDecomposition(matrix, /*full_matrices=*/false);
+    OP_REQUIRES_OK(ctx, qr.status());
+    const auto& factors = qr.ValueOrDie();
+
+    // With matrix = q * r, solving r * inv = transpose(q) makes inv the
+    // inverse of matrix.
+    xla::XlaOp inv = xla::TriangularSolve(
+        factors.r, xla::TransposeInMinorDims(factors.q),
+        /*left_side=*/true,
+        /*lower=*/false, /*unit_diagonal=*/false,
+        /*transpose_a=*/
+        xla::TriangularSolveOptions::NO_TRANSPOSE);
+
+    // The adjoint case is handled by transposing inv inside the product. The
+    // op is registered for kFloatTypes only (see the TODO below), so this
+    // presumably relies on transpose and adjoint coinciding for those types.
+    xla::XlaOp output =
+        xla::BatchDot(inv, /*transpose_x=*/adjoint_, rhs,
+                      /*transpose_y=*/false, xla::PrecisionConfig::HIGHEST);
+    ctx->SetOutput(0, output);
+  }
+
+ private:
+  // Value of the op's "adjoint" attribute.
+  bool adjoint_ = false;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
+};
+
+// TODO(b/111271662): Support integer and complex types.
+REGISTER_XLA_OP(Name("MatrixSolve").TypeConstraint("T", kFloatTypes),
+                MatrixSolveOp);
+
+}  // namespace
+}  // namespace tensorflow