Add complex tensors support to matrix_determinant
.
PiperOrigin-RevId: 161132422
This commit is contained in:
parent
335f1f14d3
commit
0cbd249e8b
@ -169,10 +169,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
if dtype.is_complex:
|
||||
self.skipTest(
|
||||
"tf.matrix_determinant does not work with complex, so this "
|
||||
"test is being skipped.")
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
@ -190,10 +186,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
for use_placeholder in False, True:
|
||||
for shape in self._shapes_to_test:
|
||||
for dtype in self._dtypes_to_test:
|
||||
if dtype.is_complex:
|
||||
self.skipTest(
|
||||
"tf.matrix_determinant does not work with complex, so this "
|
||||
"test is being skipped.")
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/Eigen/LU"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||
@ -62,7 +63,14 @@ class DeterminantOp : public LinearAlgebraOp<Scalar> {
|
||||
|
||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
|
||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
|
||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
|
||||
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
|
||||
complex128);
|
||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
|
||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
|
||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
|
||||
complex64);
|
||||
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
|
||||
complex128);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -189,7 +189,7 @@ Status SvdShapeFn(InferenceContext* c) {
|
||||
REGISTER_OP("MatrixDeterminant")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {float, double, complex64, complex128}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
|
||||
@ -560,7 +560,7 @@ REGISTER_OP("BatchSelfAdjointEig")
|
||||
REGISTER_OP("BatchMatrixDeterminant")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, double}")
|
||||
.Attr("T: {float, double, complex64, complex128}")
|
||||
.Deprecated(13, "Use MatrixDeterminant instead.");
|
||||
|
||||
REGISTER_OP("BatchMatrixInverse")
|
||||
|
@ -66,6 +66,41 @@ class DeterminantOpTest(test.TestCase):
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64))
|
||||
|
||||
def testBasicComplex64(self):
|
||||
# 2x2 matrices
|
||||
self._compareDeterminant(
|
||||
np.array([[2., 3.], [3., 4.]]).astype(np.complex64))
|
||||
self._compareDeterminant(
|
||||
np.array([[0., 0.], [0., 0.]]).astype(np.complex64))
|
||||
self._compareDeterminant(
|
||||
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
|
||||
np.complex64))
|
||||
# 5x5 matrices (Eigen forces LU decomposition)
|
||||
self._compareDeterminant(
|
||||
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
|
||||
2., 5., 8., 3., 8.
|
||||
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex64))
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.complex64))
|
||||
|
||||
def testBasicComplex128(self):
|
||||
# 2x2 matrices
|
||||
self._compareDeterminant(
|
||||
np.array([[2., 3.], [3., 4.]]).astype(np.complex128))
|
||||
self._compareDeterminant(
|
||||
np.array([[0., 0.], [0., 0.]]).astype(np.complex128))
|
||||
self._compareDeterminant(
|
||||
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
|
||||
np.complex128))
|
||||
# 5x5 matrices (Eigen forces LU decomposition)
|
||||
self._compareDeterminant(
|
||||
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
|
||||
2., 5., 8., 3., 8.
|
||||
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex128))
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._compareDeterminant(
|
||||
np.random.rand(3, 4, 5, 2, 2).astype(np.complex128))
|
||||
|
||||
def testOverflow(self):
|
||||
max_double = np.finfo("d").max
|
||||
huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])
|
||||
|
Loading…
Reference in New Issue
Block a user