Add support for complex in matrix_solve_ls_op.
Split into separate files for each data type to speed up build. PiperOrigin-RevId: 165744539
This commit is contained in:
parent
51441302d4
commit
109ecf823d
23
tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||
|
||||
// Kernel registration for the complex128 data type. Per the commit message,
// registrations are split into one source file per data type to speed up the
// build.
namespace tensorflow {
|
||||
|
||||
// Registers the MatrixSolveLsOp<std::complex<double>> kernel for the
// "MatrixSolveLs" op with the complex128 type tag.
// NOTE(review): REGISTER_LINALG_OP and MatrixSolveLsOp are not defined in this
// file; presumably they come in via matrix_solve_ls_op_impl.h -- confirm.
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<std::complex<double>>),
|
||||
complex128);
|
||||
|
||||
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||
|
||||
// Kernel registration for the complex64 data type. Per the commit message,
// registrations are split into one source file per data type to speed up the
// build.
namespace tensorflow {
|
||||
|
||||
// Registers the MatrixSolveLsOp<std::complex<float>> kernel for the
// "MatrixSolveLs" op with the complex64 type tag.
// NOTE(review): REGISTER_LINALG_OP and MatrixSolveLsOp are not defined in this
// file; presumably they come in via matrix_solve_ls_op_impl.h -- confirm.
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<std::complex<float>>),
|
||||
complex64);
|
||||
|
||||
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_double.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_double.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||
|
||||
// Kernel registrations for the double data type. Per the commit message,
// registrations are split into one source file per data type to speed up the
// build.
namespace tensorflow {
|
||||
|
||||
// The same MatrixSolveLsOp<double> kernel class is registered under two op
// names: "MatrixSolveLs" and "BatchMatrixSolveLs".
// NOTE(review): REGISTER_LINALG_OP and MatrixSolveLsOp are not defined in this
// file; presumably they come in via matrix_solve_ls_op_impl.h -- confirm.
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||
|
||||
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_float.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_float.cc
Normal file
@ -0,0 +1,23 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||
|
||||
// Kernel registrations for the float data type. Per the commit message,
// registrations are split into one source file per data type to speed up the
// build.
namespace tensorflow {
|
||||
|
||||
// The same MatrixSolveLsOp<float> kernel class is registered under two op
// names: "MatrixSolveLs" and "BatchMatrixSolveLs".
// NOTE(review): REGISTER_LINALG_OP and MatrixSolveLsOp are not defined in this
// file; presumably they come in via matrix_solve_ls_op_impl.h -- confirm.
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||
|
||||
} // namespace tensorflow
|
@ -158,9 +158,4 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
|
||||
bool fast_;
|
||||
};
|
||||
|
||||
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||
|
||||
} // namespace tensorflow
|
@ -422,7 +422,7 @@ REGISTER_OP("MatrixSolveLs")
|
||||
.Input("rhs: T")
|
||||
.Input("l2_regularizer: double")
|
||||
.Output("output: T")
|
||||
.Attr("T: {double, float}")
|
||||
.Attr("T: {double, float, complex64, complex128}")
|
||||
.Attr("fast: bool = True")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle l2_regularizer;
|
||||
@ -433,28 +433,31 @@ REGISTER_OP("MatrixSolveLs")
|
||||
Solves one or more linear least-squares problems.
|
||||
|
||||
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
||||
form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
|
||||
form real or complex matrices of size `[M, N]`. `rhs` is a tensor of the same
|
||||
type as `matrix` and shape `[..., M, K]`.
|
||||
The output is a tensor of shape `[..., N, K]` where each output matrix solves
|
||||
each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
|
||||
each of the equations
|
||||
`matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
|
||||
in the least squares sense.
|
||||
|
||||
matrix and right-hand sides in the batch:
|
||||
We use the following notation for (complex) matrix and right-hand sides
|
||||
in the batch:
|
||||
|
||||
`matrix`=\\(A \in \Re^{m \times n}\\),
|
||||
`rhs`=\\(B \in \Re^{m \times k}\\),
|
||||
`output`=\\(X \in \Re^{n \times k}\\),
|
||||
`l2_regularizer`=\\(\lambda\\).
|
||||
`matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
|
||||
`rhs`=\\(B \in \mathbb{C}^{m \times k}\\),
|
||||
`output`=\\(X \in \mathbb{C}^{n \times k}\\),
|
||||
`l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
|
||||
|
||||
If `fast` is `True`, then the solution is computed by solving the normal
|
||||
equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
|
||||
\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
|
||||
\\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
|
||||
problem \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||A Z - B||_F^2 +
|
||||
\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
|
||||
\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
|
||||
\\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
|
||||
minimum-norm solution to the under-determined linear system, i.e.
|
||||
\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to
|
||||
\\(A Z = B\\). Notice that the fast path is only numerically stable when
|
||||
\\(A\\) is numerically full rank and has a condition number
|
||||
\\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
|
||||
subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
|
||||
when \\(A\\) is numerically full rank and has a condition number
|
||||
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
|
||||
sufficiently large.
|
||||
|
||||
|
@ -65,9 +65,12 @@ def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0):
|
||||
class MatrixSolveLsOpTest(test.TestCase):
|
||||
|
||||
def _verifySolve(self, x, y):
|
||||
for np_type in [np.float32, np.float64]:
|
||||
for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
a = x.astype(np_type)
|
||||
b = y.astype(np_type)
|
||||
if np_type in [np.complex64, np.complex128]:
|
||||
a.imag = a.real
|
||||
b.imag = b.real
|
||||
np_ans, _, _, _ = np.linalg.lstsq(a, b)
|
||||
for fast in [True, False]:
|
||||
with self.test_session():
|
||||
|
Loading…
x
Reference in New Issue
Block a user