Add complex tensors support to matrix_determinant
.
PiperOrigin-RevId: 161132422
This commit is contained in:
parent
335f1f14d3
commit
0cbd249e8b
@ -169,10 +169,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
|||||||
for use_placeholder in False, True:
|
for use_placeholder in False, True:
|
||||||
for shape in self._shapes_to_test:
|
for shape in self._shapes_to_test:
|
||||||
for dtype in self._dtypes_to_test:
|
for dtype in self._dtypes_to_test:
|
||||||
if dtype.is_complex:
|
|
||||||
self.skipTest(
|
|
||||||
"tf.matrix_determinant does not work with complex, so this "
|
|
||||||
"test is being skipped.")
|
|
||||||
with self.test_session(graph=ops.Graph()) as sess:
|
with self.test_session(graph=ops.Graph()) as sess:
|
||||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||||
@ -190,10 +186,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
|||||||
for use_placeholder in False, True:
|
for use_placeholder in False, True:
|
||||||
for shape in self._shapes_to_test:
|
for shape in self._shapes_to_test:
|
||||||
for dtype in self._dtypes_to_test:
|
for dtype in self._dtypes_to_test:
|
||||||
if dtype.is_complex:
|
|
||||||
self.skipTest(
|
|
||||||
"tf.matrix_determinant does not work with complex, so this "
|
|
||||||
"test is being skipped.")
|
|
||||||
with self.test_session(graph=ops.Graph()) as sess:
|
with self.test_session(graph=ops.Graph()) as sess:
|
||||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/Eigen/LU"
|
#include "third_party/eigen3/Eigen/LU"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||||
@ -62,7 +63,14 @@ class DeterminantOp : public LinearAlgebraOp<Scalar> {
|
|||||||
|
|
||||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
|
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
|
||||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
|
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
|
||||||
|
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
|
||||||
|
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
|
||||||
|
complex128);
|
||||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
|
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
|
||||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
|
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
|
||||||
|
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
|
||||||
|
complex64);
|
||||||
|
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
|
||||||
|
complex128);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -189,7 +189,7 @@ Status SvdShapeFn(InferenceContext* c) {
|
|||||||
REGISTER_OP("MatrixDeterminant")
|
REGISTER_OP("MatrixDeterminant")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.Attr("T: {float, double}")
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
ShapeHandle input;
|
ShapeHandle input;
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
|
||||||
@ -560,7 +560,7 @@ REGISTER_OP("BatchSelfAdjointEig")
|
|||||||
REGISTER_OP("BatchMatrixDeterminant")
|
REGISTER_OP("BatchMatrixDeterminant")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.Attr("T: {float, double}")
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
.Deprecated(13, "Use MatrixDeterminant instead.");
|
.Deprecated(13, "Use MatrixDeterminant instead.");
|
||||||
|
|
||||||
REGISTER_OP("BatchMatrixInverse")
|
REGISTER_OP("BatchMatrixInverse")
|
||||||
|
@ -66,6 +66,41 @@ class DeterminantOpTest(test.TestCase):
|
|||||||
# A multidimensional batch of 2x2 matrices
|
# A multidimensional batch of 2x2 matrices
|
||||||
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64))
|
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64))
|
||||||
|
|
||||||
|
def testBasicComplex64(self):
|
||||||
|
# 2x2 matrices
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[2., 3.], [3., 4.]]).astype(np.complex64))
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[0., 0.], [0., 0.]]).astype(np.complex64))
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
|
||||||
|
np.complex64))
|
||||||
|
# 5x5 matrices (Eigen forces LU decomposition)
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
|
||||||
|
2., 5., 8., 3., 8.
|
||||||
|
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex64))
|
||||||
|
# A multidimensional batch of 2x2 matrices
|
||||||
|
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.complex64))
|
||||||
|
|
||||||
|
def testBasicComplex128(self):
|
||||||
|
# 2x2 matrices
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[2., 3.], [3., 4.]]).astype(np.complex128))
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[0., 0.], [0., 0.]]).astype(np.complex128))
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
|
||||||
|
np.complex128))
|
||||||
|
# 5x5 matrices (Eigen forces LU decomposition)
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
|
||||||
|
2., 5., 8., 3., 8.
|
||||||
|
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex128))
|
||||||
|
# A multidimensional batch of 2x2 matrices
|
||||||
|
self._compareDeterminant(
|
||||||
|
np.random.rand(3, 4, 5, 2, 2).astype(np.complex128))
|
||||||
|
|
||||||
def testOverflow(self):
|
def testOverflow(self):
|
||||||
max_double = np.finfo("d").max
|
max_double = np.finfo("d").max
|
||||||
huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])
|
huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])
|
||||||
|
Loading…
Reference in New Issue
Block a user