diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index 1924ea0f662..af14f34600e 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -169,10 +169,6 @@ class LinearOperatorDerivedClassTest(test.TestCase): for use_placeholder in False, True: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - if dtype.is_complex: - self.skipTest( - "tf.matrix_determinant does not work with complex, so this " - "test is being skipped.") with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( @@ -190,10 +186,6 @@ class LinearOperatorDerivedClassTest(test.TestCase): for use_placeholder in False, True: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - if dtype.is_complex: - self.skipTest( - "tf.matrix_determinant does not work with complex, so this " - "test is being skipped.") with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( diff --git a/tensorflow/core/kernels/determinant_op.cc b/tensorflow/core/kernels/determinant_op.cc index d51563580b0..d7e55a8ba24 100644 --- a/tensorflow/core/kernels/determinant_op.cc +++ b/tensorflow/core/kernels/determinant_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/LU" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/linalg_ops_common.h" @@ -62,7 +63,14 @@ class DeterminantOp : public LinearAlgebraOp { REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), float); REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), double); +REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), complex64); +REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp), + complex128); REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), float); REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), double); +REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), + complex64); +REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp), + complex128); } // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 6e1f2dc0529..b0f95c91fdf 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -189,7 +189,7 @@ Status SvdShapeFn(InferenceContext* c) { REGISTER_OP("MatrixDeterminant") .Input("input: T") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {float, double, complex64, complex128}") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); @@ -560,7 +560,7 @@ REGISTER_OP("BatchSelfAdjointEig") REGISTER_OP("BatchMatrixDeterminant") .Input("input: T") .Output("output: T") - .Attr("T: {float, double}") + .Attr("T: {float, double, complex64, complex128}") .Deprecated(13, "Use MatrixDeterminant instead."); REGISTER_OP("BatchMatrixInverse") diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py index 2d05ab61390..089ec0de795 100644 --- a/tensorflow/python/kernel_tests/determinant_op_test.py +++ b/tensorflow/python/kernel_tests/determinant_op_test.py @@ -66,6 +66,41 @@ class DeterminantOpTest(test.TestCase): # A multidimensional batch of 2x2 matrices self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64)) + def testBasicComplex64(self): + # 2x2 matrices + self._compareDeterminant( + np.array([[2., 3.], [3., 4.]]).astype(np.complex64)) + self._compareDeterminant( + np.array([[0., 0.], [0., 0.]]).astype(np.complex64)) + self._compareDeterminant( + np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype( + np.complex64)) + # 5x5 matrices (Eigen forces LU decomposition) + self._compareDeterminant( + np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [ + 2., 5., 8., 3., 8. + ], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex64)) + # A multidimensional batch of 2x2 matrices + self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.complex64)) + + def testBasicComplex128(self): + # 2x2 matrices + self._compareDeterminant( + np.array([[2., 3.], [3., 4.]]).astype(np.complex128)) + self._compareDeterminant( + np.array([[0., 0.], [0., 0.]]).astype(np.complex128)) + self._compareDeterminant( + np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype( + np.complex128)) + # 5x5 matrices (Eigen forces LU decomposition) + self._compareDeterminant( + np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [ + 2., 5., 8., 3., 8. + ], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex128)) + # A multidimensional batch of 2x2 matrices + self._compareDeterminant( + np.random.rand(3, 4, 5, 2, 2).astype(np.complex128)) + def testOverflow(self): max_double = np.finfo("d").max huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])