diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc b/tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
new file mode 100644
index 00000000000..22274cc3daf
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<complex128>),
+                   complex128);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc b/tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
new file mode 100644
index 00000000000..c8421a3efba
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<complex64>),
+                   complex64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_double.cc b/tensorflow/core/kernels/matrix_solve_ls_op_double.cc
new file mode 100644
index 00000000000..c7d03cb1052
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_double.cc
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
+REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_float.cc b/tensorflow/core/kernels/matrix_solve_ls_op_float.cc
new file mode 100644
index 00000000000..c98a84beded
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_float.cc
@@ -0,0 +1,23 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
+REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
similarity index 96%
rename from tensorflow/core/kernels/matrix_solve_ls_op.cc
rename to tensorflow/core/kernels/matrix_solve_ls_op_impl.h
index 381a5ec7b9d..0e09078365e 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
@@ -158,9 +158,4 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
   bool fast_;
 };
 
-REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
-REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
-REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
-REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 52f69f76a4f..5b75bda1f1b 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -422,7 +422,7 @@ REGISTER_OP("MatrixSolveLs")
     .Input("rhs: T")
     .Input("l2_regularizer: double")
     .Output("output: T")
-    .Attr("T: {double, float}")
+
.Attr("T: {double, float, complex64, complex128}") .Attr("fast: bool = True") .SetShapeFn([](InferenceContext* c) { ShapeHandle l2_regularizer; @@ -433,28 +433,31 @@ REGISTER_OP("MatrixSolveLs") Solves one or more linear least-squares problems. `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`. +form real or complex matrices of size `[M, N]`. `rhs` is a tensor of the same +type as `matrix` and shape `[..., M, K]`. The output is a tensor of shape `[..., N, K]` where each output matrix solves -each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] +each of the equations +`matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` in the least squares sense. -matrix and right-hand sides in the batch: +We use the following notation for (complex) matrix and right-hand sides +in the batch: -`matrix`=\\(A \in \Re^{m \times n}\\), -`rhs`=\\(B \in \Re^{m \times k}\\), -`output`=\\(X \in \Re^{n \times k}\\), -`l2_regularizer`=\\(\lambda\\). +`matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +`rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +`output`=\\(X \in \mathbb{C}^{n \times k}\\), +`l2_regularizer`=\\(\lambda \in \mathbb{R}\\). If `fast` is `True`, then the solution is computed by solving the normal equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +\\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares problem \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as -\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +\\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the minimum-norm solution to the under-determined linear system, i.e.
-\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to -\\(A Z = B\\). Notice that the fast path is only numerically stable when -\\(A\\) is numerically full rank and has a condition number +\\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +when \\(A\\) is numerically full rank and has a condition number \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is sufficiently large. diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py index 9a7645ff616..ece222fefc8 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py @@ -65,9 +65,12 @@ def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0): class MatrixSolveLsOpTest(test.TestCase): def _verifySolve(self, x, y): - for np_type in [np.float32, np.float64]: + for np_type in [np.float32, np.float64, np.complex64, np.complex128]: a = x.astype(np_type) b = y.astype(np_type) + if np_type in [np.complex64, np.complex128]: + a.imag = a.real + b.imag = b.real np_ans, _, _, _ = np.linalg.lstsq(a, b) for fast in [True, False]: with self.test_session():