Add complex tensors support to matrix_determinant.

PiperOrigin-RevId: 161132422
This commit is contained in:
A. Unique TensorFlower 2017-07-06 14:37:39 -07:00 committed by TensorFlower Gardener
parent 335f1f14d3
commit 0cbd249e8b
4 changed files with 45 additions and 10 deletions

View File

@ -169,10 +169,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in False, True: for use_placeholder in False, True:
for shape in self._shapes_to_test: for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test: for dtype in self._dtypes_to_test:
if dtype.is_complex:
self.skipTest(
"tf.matrix_determinant does not work with complex, so this "
"test is being skipped.")
with self.test_session(graph=ops.Graph()) as sess: with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
@ -190,10 +186,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in False, True: for use_placeholder in False, True:
for shape in self._shapes_to_test: for shape in self._shapes_to_test:
for dtype in self._dtypes_to_test: for dtype in self._dtypes_to_test:
if dtype.is_complex:
self.skipTest(
"tf.matrix_determinant does not work with complex, so this "
"test is being skipped.")
with self.test_session(graph=ops.Graph()) as sess: with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/LU" #include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h" #include "tensorflow/core/kernels/linalg_ops_common.h"
@ -62,7 +63,14 @@ class DeterminantOp : public LinearAlgebraOp<Scalar> {
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float); REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double); REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
complex128);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float); REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double); REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
complex64);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
complex128);
} // namespace tensorflow } // namespace tensorflow

View File

@ -189,7 +189,7 @@ Status SvdShapeFn(InferenceContext* c) {
REGISTER_OP("MatrixDeterminant") REGISTER_OP("MatrixDeterminant")
.Input("input: T") .Input("input: T")
.Output("output: T") .Output("output: T")
.Attr("T: {float, double}") .Attr("T: {float, double, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
ShapeHandle input; ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
@ -560,7 +560,7 @@ REGISTER_OP("BatchSelfAdjointEig")
REGISTER_OP("BatchMatrixDeterminant") REGISTER_OP("BatchMatrixDeterminant")
.Input("input: T") .Input("input: T")
.Output("output: T") .Output("output: T")
.Attr("T: {float, double}") .Attr("T: {float, double, complex64, complex128}")
.Deprecated(13, "Use MatrixDeterminant instead."); .Deprecated(13, "Use MatrixDeterminant instead.");
REGISTER_OP("BatchMatrixInverse") REGISTER_OP("BatchMatrixInverse")

View File

@ -66,6 +66,41 @@ class DeterminantOpTest(test.TestCase):
# A multidimensional batch of 2x2 matrices # A multidimensional batch of 2x2 matrices
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64)) self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.float64))
def testBasicComplex64(self):
# 2x2 matrices
self._compareDeterminant(
np.array([[2., 3.], [3., 4.]]).astype(np.complex64))
self._compareDeterminant(
np.array([[0., 0.], [0., 0.]]).astype(np.complex64))
self._compareDeterminant(
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
np.complex64))
# 5x5 matrices (Eigen forces LU decomposition)
self._compareDeterminant(
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
2., 5., 8., 3., 8.
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex64))
# A multidimensional batch of 2x2 matrices
self._compareDeterminant(np.random.rand(3, 4, 5, 2, 2).astype(np.complex64))
def testBasicComplex128(self):
# 2x2 matrices
self._compareDeterminant(
np.array([[2., 3.], [3., 4.]]).astype(np.complex128))
self._compareDeterminant(
np.array([[0., 0.], [0., 0.]]).astype(np.complex128))
self._compareDeterminant(
np.array([[1. + 1.j, 1. - 1.j], [-1. + 1.j, -1. - 1.j]]).astype(
np.complex128))
# 5x5 matrices (Eigen forces LU decomposition)
self._compareDeterminant(
np.array([[2., 3., 4., 5., 6.], [3., 4., 9., 2., 0.], [
2., 5., 8., 3., 8.
], [1., 6., 7., 4., 7.], [2., 3., 4., 5., 6.]]).astype(np.complex128))
# A multidimensional batch of 2x2 matrices
self._compareDeterminant(
np.random.rand(3, 4, 5, 2, 2).astype(np.complex128))
def testOverflow(self): def testOverflow(self):
max_double = np.finfo("d").max max_double = np.finfo("d").max
huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]]) huge_matrix = np.array([[max_double, 0.0], [0.0, max_double]])