diff --git a/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt b/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
new file mode 100644
index 00000000000..b85082c0cc8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
@@ -0,0 +1,45 @@
+op {
+  graph_op_name: "Eig"
+  endpoint {
+    name: "Eig"
+  }
+  in_arg {
+    name: "input"
+    description: <<END
+`Tensor` input of shape `[N, N]`.
+END
+  }
+  out_arg {
+    name: "e"
+    description: <<END
+Eigenvalues. Shape is `[N]`.
+END
+  }
+  out_arg {
+    name: "v"
+    description: <<END
+Eigenvectors. Shape is `[N, N]`.
+END
+  }
+  attr {
+    name: "compute_v"
+    description: <<END
+If `True` then eigenvectors will be computed and returned in `v`.
+Otherwise, only the eigenvalues will be computed.
+END
+  }
+  summary: "Computes the eigen decomposition of one or more square matrices."
+  description: <<END
+Computes the eigenvalues and (optionally) right eigenvectors of each inner matrix in
+`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
+are not guaranteed to be in any particular order.
+
+```python
+# a is a tensor.
+# e is a tensor of eigenvalues.
+# v is a tensor of eigenvectors.
+e, v = eig(a)
+e = eig(a, compute_v=False)
+```
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt b/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
new file mode 100644
index 00000000000..08a413a9941
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Eig"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d5607f641af..e0625f3a330 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3361,6 +3361,7 @@ cc_library(
         ":cholesky_grad",
         ":cholesky_op",
         ":determinant_op",
+        ":eig_op",
        ":einsum_op",
        ":lu_op",
        ":matrix_exponential_op",
@@ -3473,6 +3474,15 @@ tf_kernel_library(
     ]),
 )
 
+tf_kernel_library(
+    name = "eig_op",
+    prefix = "eig_op",
+    deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
+        ":cast_op",
+        ":cwise_op",
+    ]),
+)
+
 tf_kernel_library(
     name = "matrix_inverse_op",
     prefix = "matrix_inverse_op",
diff --git a/tensorflow/core/kernels/eig_op_complex128.cc b/tensorflow/core/kernels/eig_op_complex128.cc
new file mode 100644
index 00000000000..988cc2f98d9
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_complex128.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<complex128, complex128>), complex128);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_complex64.cc b/tensorflow/core/kernels/eig_op_complex64.cc
new file mode 100644
index 00000000000..6a3f7928715
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_complex64.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<complex64, complex64>), complex64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_double.cc b/tensorflow/core/kernels/eig_op_double.cc
new file mode 100644
index 00000000000..2cd931cc135
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_double.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<double, complex128>), double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_float.cc b/tensorflow/core/kernels/eig_op_float.cc
new file mode 100644
index 00000000000..a06f76e935f
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_float.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<float, complex64>), float);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_impl.h b/tensorflow/core/kernels/eig_op_impl.h
new file mode 100644
index 00000000000..4ebb6bde08b
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_impl.h
@@ -0,0 +1,101 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
+
+// See docs in ../ops/linalg_ops.cc.
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/Eigen/Eigenvalues"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+template <class InputScalar, class OutputScalar>
+class EigOp : public LinearAlgebraOp<InputScalar, OutputScalar> {
+ public:
+  typedef LinearAlgebraOp<InputScalar, OutputScalar> Base;
+
+  explicit EigOp(OpKernelConstruction* context) : Base(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
+  }
+
+  using TensorShapes = typename Base::TensorShapes;
+  using InputMatrix = typename Base::InputMatrix;
+  using InputMatrixMaps = typename Base::InputMatrixMaps;
+  using InputConstMatrixMap = typename Base::InputConstMatrixMap;
+  using InputConstMatrixMaps = typename Base::InputConstMatrixMaps;
+
+  using OutputMatrix = typename Base::OutputMatrix;
+  using OutputMatrixMaps = typename Base::OutputMatrixMaps;
+  using OutputConstMatrixMap = typename Base::OutputConstMatrixMap;
+  using OutputConstMatrixMaps = typename Base::OutputConstMatrixMaps;
+
+  TensorShapes GetOutputMatrixShapes(
+      const TensorShapes& input_matrix_shapes) const final {
+    int64 n = input_matrix_shapes[0].dim_size(0);
+    if (compute_v_) {
+      return TensorShapes({TensorShape({n}), TensorShape({n, n})});
+    } else {
+      return TensorShapes({TensorShape({n})});
+    }
+  }
+
+  void ComputeMatrix(OpKernelContext* context,
+                     const InputConstMatrixMaps& inputs,
+                     OutputMatrixMaps* outputs) final {
+    const int64 rows = inputs[0].rows();
+    if (rows == 0) {
+      // For an empty (0x0) input matrix the eigenvalue and eigenvector
+      // outputs are empty as well, so there is nothing to compute.
+      return;
+    }
+
+    // This algorithm relies on denormals, so switch them back on locally.
+    port::ScopedDontFlushDenormal dont_flush_denormals;
+
+    // ComplexEigenSolver casts real input matrices to complex internally.
+    // Note: its second argument is a bool, not Eigen's DecompositionOptions;
+    // Eigen::ComputeEigenvectors and Eigen::EigenvaluesOnly are both nonzero
+    // enums and would both convert to true.
+    Eigen::ComplexEigenSolver<OutputMatrix> eig(
+        inputs[0], /*computeEigenvectors=*/compute_v_);
+    // TODO(rmlarsen): Output more detailed error info on failure.
+    OP_REQUIRES(
+        context, eig.info() == Eigen::Success,
+        errors::InvalidArgument("Eigen decomposition was not "
+                                "successful. The input might not be valid."));
+
+    outputs->at(0) = eig.eigenvalues().template cast<OutputScalar>();
+    if (compute_v_) {
+      outputs->at(1) = eig.eigenvectors();
+    }
+  }
+
+ private:
+  bool compute_v_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index b58bcf58348..3836ff796eb 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -29,8 +29,8 @@ limitations under the License.
 namespace tensorflow {
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
               errors::InvalidArgument("Expected a single input matrix, got %d.",
@@ -40,8 +40,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
               errors::InvalidArgument("Expected a single input matrix, got %d.",
@@ -51,8 +51,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSolver(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
               errors::InvalidArgument("Expected two input matrices, got %d.",
@@ -68,8 +68,8 @@ void LinearAlgebraOp<Scalar>::ValidateSolver(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
               errors::InvalidArgument("Expected two input matrices, got %d.",
@@ -85,8 +85,9 @@ void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
               errors::InvalidArgument("Input matrix and rhs are incompatible."));
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
+    OpKernelContext* context) {
   TensorInputs inputs;
   TensorShapes input_matrix_shapes;
   TensorShape batch_shape;
@@ -110,11 +111,10 @@ void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
       batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
-                                            TensorInputs* inputs,
-                                            TensorShapes* input_matrix_shapes,
-                                            TensorShape* batch_shape) {
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
+    OpKernelContext* context, TensorInputs* inputs,
+    TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
   int input_rank = -1;
   for (int i = 0; i < NumMatrixInputs(context); ++i) {
     const Tensor& in = context->input(i);
@@ -155,8 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
   ValidateInputMatrixShapes(context, *input_matrix_shapes);
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::PrepareOutputs(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes,
     const TensorShape& batch_shape, TensorOutputs* outputs,
     TensorShapes* output_matrix_shapes) {
@@ -214,22 +214,22 @@ void LinearAlgebraOp<Scalar>::PrepareOutputs(
   }
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
     OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
     const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
     const TensorShapes& output_matrix_shapes) {
-  ConstMatrixMaps matrix_inputs;
+  InputConstMatrixMaps matrix_inputs;
   for (size_t i = 0; i < inputs.size(); ++i) {
     // TODO(kalakris): Handle alignment if possible. Eigen::Map is
     // unaligned by default.
     matrix_inputs.emplace_back(
-        inputs[i]->flat<Scalar>().data() +
+        inputs[i]->flat<InputScalar>().data() +
             matrix_index * input_matrix_shapes[i].num_elements(),
         input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
   }
 
-  MatrixMaps matrix_outputs;
+  OutputMatrixMaps matrix_outputs;
   for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
     // The output matrix shape may not be a matrix.
     int num_output_rows = output_matrix_shapes[i].dims() >= 1
                               ? output_matrix_shapes[i].dim_size(0)
                               : 1;
     int num_output_cols = output_matrix_shapes[i].dims() == 2
@@ -239,7 +239,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
                               ? output_matrix_shapes[i].dim_size(1)
                               : 1;
     matrix_outputs.emplace_back(
-        outputs[i]->flat<Scalar>().data() +
+        outputs[i]->flat<OutputScalar>().data() +
            matrix_index * output_matrix_shapes[i].num_elements(),
        num_output_rows, num_output_cols);
  }
@@ -251,5 +251,7 @@ template class LinearAlgebraOp<float>;
 template class LinearAlgebraOp<double>;
 template class LinearAlgebraOp<complex64>;
 template class LinearAlgebraOp<complex128>;
+template class LinearAlgebraOp<float, complex64>;
+template class LinearAlgebraOp<double, complex128>;
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index 11ecf7d676e..65c2fb90f0e 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {
 
 // Base class for linear algebra operators.
-template <typename Scalar>
+template <class InputScalar, class OutputScalar = InputScalar>
 class LinearAlgebraOp : public OpKernel {
  public:
   explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -109,6 +109,28 @@ class LinearAlgebraOp : public OpKernel {
   // and expect the kernel to perform the computation inplace.
   virtual bool EnableInputForwarding() const { return true; }
 
+  using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
+                                    Eigen::RowMajor>;
+  using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
+  using InputMatrixMap = Eigen::Map<InputMatrix>;
+  using InputConstVectorMap =
+      Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
+  using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
+  using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
+  using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
+
+  using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
+                                     Eigen::Dynamic, Eigen::RowMajor>;
+  using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
+  using OutputMatrixMap = Eigen::Map<OutputMatrix>;
+  using OutputConstVectorMap =
+      Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
+  using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
+  using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
+  using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
+
+  // Aliases kept for backward compatibility; they refer to the output types.
+  using Scalar = OutputScalar;
   using Matrix =
       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   using ConstMatrixMap = Eigen::Map<const Matrix>;
@@ -126,8 +148,8 @@ class LinearAlgebraOp : public OpKernel {
   // parallelized. The number of threads used is determined by a cost model from
   // the value returned by GetCostPerUnit().
   virtual void ComputeMatrix(OpKernelContext* context,
-                             const ConstMatrixMaps& inputs,
-                             MatrixMaps* outputs) = 0;
+                             const InputConstMatrixMaps& inputs,
+                             OutputMatrixMaps* outputs) = 0;
 
  private:
   using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index f037d38ef81..4572df279b7 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -383,6 +383,15 @@ REGISTER_OP("SelfAdjointEig")
       return Status::OK();
     });
 
+REGISTER_OP("Eig")
+    .Input("input: T")
+    .Output("e: Tout")
+    .Output("v: Tout")
+    .Attr("compute_v: bool = True")
+    .Attr("T: {float, double, complex64, complex128}")
+    .Attr("Tout: {complex64, complex128}")
+    .SetShapeFn(SelfAdjointEigV2ShapeFn);
+
 REGISTER_OP("SelfAdjointEigV2")
     .Input("input: T")
     .Output("e: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 617634980dd..1da4cef1557 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -12206,6 +12206,50 @@ op {
     type: "type"
   }
 }
+op {
+  name: "Eig"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "Tout"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "Tout"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
 op {
   name: "Elu"
   input_arg {
"//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + ], + data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"], + shard_count = 20, + tags = [ + "no_rocm", # flaky test + "no_windows", + ], + # b/127344411: xla_enable_strict_auto_jit = True, +) + cuda_py_test( name = "self_adjoint_eig_op_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py new file mode 100644 index 00000000000..ffc61b7bcfe --- /dev/null +++ b/tensorflow/python/kernel_tests/eig_op_test.py @@ -0,0 +1,198 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.linalg_ops.eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +def _AddTest(test_class, op_name, testcase_name, fn): + test_name = "_".join(["test", op_name, testcase_name]) + if hasattr(test_class, test_name): + raise RuntimeError("Test %s defined more than once" % test_name) + setattr(test_class, test_name, fn) + + +class EigTest(test.TestCase): + + @test_util.run_deprecated_v1 + def testWrongDimensions(self): + # The input to self_adjoint_eig should be a tensor of + # at least rank 2. + scalar = constant_op.constant(1.) + with self.assertRaises(ValueError): + linalg_ops.eig(scalar) + vector = constant_op.constant([1., 2.]) + with self.assertRaises(ValueError): + linalg_ops.eig(vector) + + @test_util.run_deprecated_v1 + def testConcurrentExecutesWithoutError(self): + all_ops = [] + with self.session(use_gpu=True) as sess: + for compute_v_ in True, False: + matrix1 = random_ops.random_normal([5, 5], seed=42) + matrix2 = random_ops.random_normal([5, 5], seed=42) + if compute_v_: + e1, v1 = linalg_ops.eig(matrix1) + e2, v2 = linalg_ops.eig(matrix2) + all_ops += [e1, v1, e2, v2] + else: + e1 = linalg_ops.eigvals(matrix1) + e2 = linalg_ops.eigvals(matrix2) + all_ops += [e1, e2] + val = self.evaluate(all_ops) + self.assertAllEqual(val[0], val[2]) + # The algorithm is slightly different for compute_v being True and False, + # so require approximate equality only here. 
+
+  @test_util.run_deprecated_v1
+  def testConcurrentExecutesWithoutError(self):
+    all_ops = []
+    with self.session(use_gpu=True) as sess:
+      for compute_v_ in True, False:
+        matrix1 = random_ops.random_normal([5, 5], seed=42)
+        matrix2 = random_ops.random_normal([5, 5], seed=42)
+        if compute_v_:
+          e1, v1 = linalg_ops.eig(matrix1)
+          e2, v2 = linalg_ops.eig(matrix2)
+          all_ops += [e1, v1, e2, v2]
+        else:
+          e1 = linalg_ops.eigvals(matrix1)
+          e2 = linalg_ops.eigvals(matrix2)
+          all_ops += [e1, e2]
+      val = self.evaluate(all_ops)
+      self.assertAllEqual(val[0], val[2])
+      # The algorithm is slightly different for compute_v being True and False,
+      # so require approximate equality only here.
+      self.assertAllClose(val[2], val[4])
+      self.assertAllEqual(val[4], val[5])
+      self.assertAllEqual(val[1], val[3])
+
+  def testMatrixThatFailsWhenFlushingDenormsToZero(self):
+    # Test a 32x32 matrix which is known to fail if denorm floats are flushed
+    # to zero.
+    matrix = np.genfromtxt(
+        test.test_src_dir_path(
+            "python/kernel_tests/testdata/"
+            "self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
+    self.assertEqual(matrix.shape, (32, 32))
+    matrix_tensor = constant_op.constant(matrix)
+    with self.session(use_gpu=True) as sess:
+      (e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
+      self.assertEqual(e.size, 32)
+      self.assertAllClose(
+          np.matmul(v, v.transpose()), np.eye(32, dtype=np.float32), atol=2e-3)
+      self.assertAllClose(matrix,
+                          np.matmul(np.matmul(v, np.diag(e)), v.transpose()))
+
+
+def SortEigenValues(e):
+  perm = np.argsort(e.real + e.imag, -1)
+  return np.take(e, perm, -1)
+
+
+def SortEigenDecomposition(e, v):
+  if v.ndim < 2:
+    return e, v
+  else:
+    perm = np.argsort(e.real + e.imag, -1)
+    return np.take(e, perm, -1), np.take(v, perm, -1)
+
+
+def EquilibrateEigenVectorPhases(x, y):
+  """Equilibrate the phase of the eigenvectors in the columns of `x` and `y`.
+
+  Eigenvectors are only unique up to an arbitrary phase. This function rotates
+  x such that it matches y. Precondition: the columns of x and y differ by a
+  multiplicative complex phase factor only.
+
+  Args:
+    x: `np.ndarray` with eigenvectors
+    y: `np.ndarray` with eigenvectors
+
+  Returns:
+    `np.ndarray` containing an equilibrated version of x.
+  """
+  phases = np.sum(np.conj(x) * y, -2, keepdims=True)
+  phases /= np.abs(phases)
+  return phases * x
+
+
+def _GetEigTest(dtype_, shape_, compute_v_):
+
+  def CompareEigenVectors(self, x, y, tol):
+    x = EquilibrateEigenVectorPhases(x, y)
+    self.assertAllClose(x, y, atol=tol)
+
+  def CompareEigenDecompositions(self, x_e, x_v, y_e, y_v, tol):
+    num_batches = int(np.prod(x_e.shape[:-1]))
+    n = x_e.shape[-1]
+    x_e = np.reshape(x_e, [num_batches] + [n])
+    x_v = np.reshape(x_v, [num_batches] + [n, n])
+    y_e = np.reshape(y_e, [num_batches] + [n])
+    y_v = np.reshape(y_v, [num_batches] + [n, n])
+    for i in range(num_batches):
+      x_ei, x_vi = SortEigenDecomposition(x_e[i, :], x_v[i, :, :])
+      y_ei, y_vi = SortEigenDecomposition(y_e[i, :], y_v[i, :, :])
+      self.assertAllClose(x_ei, y_ei, atol=tol, rtol=tol)
+      CompareEigenVectors(self, x_vi, y_vi, tol)
+
+  def Test(self):
+    np.random.seed(1)
+    n = shape_[-1]
+    batch_shape = shape_[:-2]
+    np_dtype = dtype_.as_numpy_dtype
+    # Almost all random matrices are diagonalizable.  # TODO
+    a = np.random.uniform(
+        low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+    if dtype_.is_complex:
+      a += 1j * np.random.uniform(
+          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+    a = np.tile(a, batch_shape + (1, 1))
+    if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
+      atol = 1e-4
+    else:
+      atol = 1e-12
+    np_e, np_v = np.linalg.eig(a)
+    with self.session(use_gpu=True):
+      if compute_v_:
+        tf_e, tf_v = linalg_ops.eig(constant_op.constant(a))
+
+        # Check that V*diag(E)*V^(-1) is close to A.
+        a_ev = math_ops.matmul(
+            math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)),
+            linalg_ops.matrix_inverse(tf_v))
+        self.assertAllClose(self.evaluate(a_ev), a, atol=atol)
+
+        # Compare to numpy.linalg.eig.
+        CompareEigenDecompositions(self, np_e, np_v, self.evaluate(tf_e),
+                                   self.evaluate(tf_v), atol)
+      else:
+        tf_e = linalg_ops.eigvals(constant_op.constant(a))
+        self.assertAllClose(
+            SortEigenValues(np_e),
+            SortEigenValues(self.evaluate(tf_e)),
+            atol=atol)
+
+  return Test
+
+
+if __name__ == "__main__":
+  dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
+  if not test.is_built_with_rocm():
+    # ROCm does not support BLAS operations for complex types
+    dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
+  for compute_v in True, False:
+    for dtype in dtypes_to_test:
+      for size in 1, 2, 5, 10:
+        for batch_dims in [(), (3,)] + [(3, 2)] * (size < 10):
+          shape = batch_dims + (size, size)
+          name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
+          _AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
+  # No gradient yet
+  test.main()
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index a42d7922bfb..0ada446e84b 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+"""Tests for tensorflow.ops.linalg_ops.self_adjoint_eig."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 914e5748534..e49434ffd4e 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -306,6 +306,62 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
       matrix, rhs, l2_regularizer, fast=fast, name=name)
 
 
+@tf_export('eig', 'linalg.eig', v1=[])
+def eig(tensor, name=None):
+  """Computes the eigen decomposition of a batch of matrices.
+
+  For a non-Hermitian matrix the eigenvalues and eigenvectors are in general
+  complex, and the eigenvectors are not guaranteed to be linearly
+  independent.
+
+  Computes the eigenvalues and right eigenvectors of the innermost
+  N-by-N matrices in `tensor` such that
+  `tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i]`, for i=0...N-1.
+
+  Args:
+    tensor: `Tensor` of shape `[..., N, N]`. Unlike `eigh`, the whole of each
+      inner matrix is referenced, not just its lower triangular part.
+    name: string, optional name of the operation.
+
+  Returns:
+    e: Eigenvalues. Shape is `[..., N]`. Not guaranteed to be in any order.
+    v: Eigenvectors. Shape is `[..., N, N]`. The columns of the innermost
+      matrices contain the eigenvectors of the corresponding matrices in `tensor`.
+  """
+  if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
+    out_dtype = dtypes.complex64
+  elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
+    out_dtype = dtypes.complex128
+  e, v = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=True, name=name)
+  return e, v
+
+
+@tf_export('eigvals', 'linalg.eigvals', v1=[])
+def eigvals(tensor, name=None):
+  """Computes the eigenvalues of one or more matrices.
+
+  Note: If your program backpropagates through this function, you should replace
+  it with a call to tf.linalg.eig (possibly ignoring the second output) to
+  avoid computing the eigen decomposition twice. This is because the
+  eigenvectors are used to compute the gradient w.r.t. the eigenvalues (see
+  _SelfAdjointEigV2Grad in linalg_grad.py for the self-adjoint case).
+
+  Args:
+    tensor: `Tensor` of shape `[..., N, N]`.
+    name: string, optional name of the operation.
+
+  Returns:
+    e: Eigenvalues. Shape is `[..., N]`. The vector `e[..., :]` contains the `N`
+      eigenvalues of `tensor[..., :, :]`.
+  """
+  if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
+    out_dtype = dtypes.complex64
+  elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
+    out_dtype = dtypes.complex128
+  e, _ = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=False, name=name)
+  return e
+
+
 @tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
 @deprecation.deprecated_endpoints('self_adjoint_eig')
 def self_adjoint_eig(tensor, name=None):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index e7d5f1aec78..8ae11431a08 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1132,6 +1132,10 @@ tf_module {
     name: "EditDistance"
     argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Eig"
+    argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
   member_method {
     name: "Einsum"
     argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index 3150ea14464..a25583d7fdd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -108,10 +108,18 @@ tf_module {
     name: "diag_part"
     argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
   }
+  member_method {
+    name: "eig"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "eigh"
     argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "eigvalsh"
     argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index f3d5aec9215..d67870a92b8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -604,6 +604,14 @@ tf_module {
     name: "edit_distance"
     argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], "
   }
+  member_method {
+    name: "eig"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "einsum"
     argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index e7d5f1aec78..8ae11431a08 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1132,6 +1132,10 @@ tf_module {
     name: "EditDistance"
    argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Eig"
+    argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
   member_method {
     name: "Einsum"
     argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
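A quick end-to-end sketch of the endpoints this patch adds (illustrative only, not part of the patch; assumes a TF 2.x build that includes these changes):

```python
# Illustrative usage of tf.linalg.eig / tf.linalg.eigvals (TF2-only endpoints;
# the raw Eig op itself is hidden).
import numpy as np
import tensorflow as tf

a = np.random.uniform(-1.0, 1.0, size=(4, 4)).astype(np.float32)

# Real inputs promote to complex outputs (float32 -> complex64,
# float64 -> complex128), mirroring the `Tout` attr on the `Eig` op.
e, v = tf.linalg.eig(a)
assert e.dtype == tf.complex64 and v.dtype == tf.complex64

# Right-eigenvector property: a @ v = v @ diag(e). When v is invertible,
# a can be reconstructed as v @ diag(e) @ inv(v), which is exactly what
# the kernel test above checks.
a_ev = v @ tf.linalg.diag(e) @ tf.linalg.inv(v)
np.testing.assert_allclose(a_ev.numpy(), a.astype(np.complex64), atol=1e-4)

# eigvals computes eigenvalues only (compute_v=False on the underlying op).
# No eigenvalue ordering is guaranteed, so sort before comparing.
e_only = tf.linalg.eigvals(a)
np.testing.assert_allclose(
    np.sort_complex(e.numpy()), np.sort_complex(e_only.numpy()), atol=1e-4)
```

The `atol=1e-4` here matches the float32 tolerance used in `eig_op_test.py`.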
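Why the test sorts eigenvalues and equilibrates eigenvector phases before comparing: an eigendecomposition is only unique up to eigenvalue ordering and a per-column unit-modulus phase factor. A numpy-only sketch of that ambiguity (helper names here are ours, purely illustrative):

```python
import numpy as np

rng = np.random.RandomState(0)
a = rng.uniform(-1.0, 1.0, size=(4, 4))
e, v = np.linalg.eig(a)

# Multiplying any eigenvector column by a unit-modulus phase leaves it a
# valid eigenvector, so two correct implementations may disagree this way.
phases = np.exp(1j * rng.uniform(0.0, 2.0 * np.pi, size=(1, 4)))
v_rot = v.astype(np.complex128) * phases
for i in range(4):
  # Rotated columns still satisfy a @ x = e[i] * x.
  np.testing.assert_allclose(a @ v_rot[:, i], e[i] * v_rot[:, i], atol=1e-8)

# Equilibrating phases (as EquilibrateEigenVectorPhases does in the test)
# rotates one set of eigenvectors onto the other; np.linalg.eig returns
# unit-norm columns, so the per-column inner product recovers the phase.
phase = np.sum(np.conj(v_rot) * v, axis=-2, keepdims=True)
phase /= np.abs(phase)
np.testing.assert_allclose(v_rot * phase, v, atol=1e-8)
```

Sorting by `e.real + e.imag` (as `SortEigenValues` does) then removes the remaining ordering ambiguity, since the op promises no particular eigenvalue order.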