Add support for complex in matrix_solve_ls_op.
Split into separate files for each data type to speed up build. PiperOrigin-RevId: 165744539
This commit is contained in:
parent
51441302d4
commit
109ecf823d
23
tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_complex128.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<std::complex<double>>),
|
||||||
|
complex128);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_complex64.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<std::complex<float>>),
|
||||||
|
complex64);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_double.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_double.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||||
|
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
23
tensorflow/core/kernels/matrix_solve_ls_op_float.cc
Normal file
23
tensorflow/core/kernels/matrix_solve_ls_op_float.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||||
|
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -158,9 +158,4 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
|
|||||||
bool fast_;
|
bool fast_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
|
||||||
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
|
||||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float>), float);
|
|
||||||
REGISTER_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double>), double);
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
@ -422,7 +422,7 @@ REGISTER_OP("MatrixSolveLs")
|
|||||||
.Input("rhs: T")
|
.Input("rhs: T")
|
||||||
.Input("l2_regularizer: double")
|
.Input("l2_regularizer: double")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.Attr("T: {double, float}")
|
.Attr("T: {double, float, complex64, complex128}")
|
||||||
.Attr("fast: bool = True")
|
.Attr("fast: bool = True")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
ShapeHandle l2_regularizer;
|
ShapeHandle l2_regularizer;
|
||||||
@ -433,28 +433,31 @@ REGISTER_OP("MatrixSolveLs")
|
|||||||
Solves one or more linear least-squares problems.
|
Solves one or more linear least-squares problems.
|
||||||
|
|
||||||
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
||||||
form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
|
form real or complex matrices of size `[M, N]`. `rhs` is a tensor of the same
|
||||||
|
type as `matrix` and shape `[..., M, K]`.
|
||||||
The output is a tensor of shape `[..., N, K]` where each output matrix solves
|
The output is a tensor of shape `[..., N, K]` where each output matrix solves
|
||||||
each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
|
each of the equations
|
||||||
|
`matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
|
||||||
in the least squares sense.
|
in the least squares sense.
|
||||||
|
|
||||||
matrix and right-hand sides in the batch:
|
We use the following notation for (complex) matrix and right-hand sides
|
||||||
|
in the batch:
|
||||||
|
|
||||||
`matrix`=\\(A \in \Re^{m \times n}\\),
|
`matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
|
||||||
`rhs`=\\(B \in \Re^{m \times k}\\),
|
`rhs`=\\(B \in \mathbb{C}^{m \times k}\\),
|
||||||
`output`=\\(X \in \Re^{n \times k}\\),
|
`output`=\\(X \in \mathbb{C}^{n \times k}\\),
|
||||||
`l2_regularizer`=\\(\lambda\\).
|
`l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
|
||||||
|
|
||||||
If `fast` is `True`, then the solution is computed by solving the normal
|
If `fast` is `True`, then the solution is computed by solving the normal
|
||||||
equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
|
equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
|
||||||
\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
|
\\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
|
||||||
problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 +
|
problem \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||A Z - B||_F^2 +
|
||||||
\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
|
\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
|
||||||
\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
|
\\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
|
||||||
minimum-norm solution to the under-determined linear system, i.e.
|
minimum-norm solution to the under-determined linear system, i.e.
|
||||||
\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||Z||_F^2 \\), subject to
|
\\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
|
||||||
\\(A Z = B\\). Notice that the fast path is only numerically stable when
|
subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
|
||||||
\\(A\\) is numerically full rank and has a condition number
|
when \\(A\\) is numerically full rank and has a condition number
|
||||||
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
|
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
|
||||||
sufficiently large.
|
sufficiently large.
|
||||||
|
|
||||||
|
@ -65,9 +65,12 @@ def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0):
|
|||||||
class MatrixSolveLsOpTest(test.TestCase):
|
class MatrixSolveLsOpTest(test.TestCase):
|
||||||
|
|
||||||
def _verifySolve(self, x, y):
|
def _verifySolve(self, x, y):
|
||||||
for np_type in [np.float32, np.float64]:
|
for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||||
a = x.astype(np_type)
|
a = x.astype(np_type)
|
||||||
b = y.astype(np_type)
|
b = y.astype(np_type)
|
||||||
|
if np_type in [np.complex64, np.complex128]:
|
||||||
|
a.imag = a.real
|
||||||
|
b.imag = b.real
|
||||||
np_ans, _, _, _ = np.linalg.lstsq(a, b)
|
np_ans, _, _, _ = np.linalg.lstsq(a, b)
|
||||||
for fast in [True, False]:
|
for fast in [True, False]:
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user