commit
8ae730be26
45
tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
Normal file
45
tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Eig"
|
||||||
|
endpoint {
|
||||||
|
name: "Eig"
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "input"
|
||||||
|
description: <<END
|
||||||
|
`Tensor` input of shape `[N, N]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "e"
|
||||||
|
description: <<END
|
||||||
|
Eigenvalues. Shape is `[N]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "v"
|
||||||
|
description: <<END
|
||||||
|
Eigenvectors. Shape is `[N, N]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "compute_v"
|
||||||
|
description: <<END
|
||||||
|
If `True` then eigenvectors will be computed and returned in `v`.
|
||||||
|
Otherwise, only the eigenvalues will be computed.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Computes the eigen decomposition of one or more square matrices."
|
||||||
|
description: <<END
|
||||||
|
Computes the eigenvalues and (optionally) right eigenvectors of each inner matrix in
|
||||||
|
`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
|
||||||
|
are sorted in non-decreasing order.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# a is a tensor.
|
||||||
|
# e is a tensor of eigenvalues.
|
||||||
|
# v is a tensor of eigenvectors.
|
||||||
|
e, v = eig(a)
|
||||||
|
e = eig(a, compute_v=False)
|
||||||
|
```
|
||||||
|
END
|
||||||
|
}
|
4
tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Eig"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -3361,6 +3361,7 @@ cc_library(
|
|||||||
":cholesky_grad",
|
":cholesky_grad",
|
||||||
":cholesky_op",
|
":cholesky_op",
|
||||||
":determinant_op",
|
":determinant_op",
|
||||||
|
":eig_op",
|
||||||
":einsum_op",
|
":einsum_op",
|
||||||
":lu_op",
|
":lu_op",
|
||||||
":matrix_exponential_op",
|
":matrix_exponential_op",
|
||||||
@ -3473,6 +3474,15 @@ tf_kernel_library(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "eig_op",
|
||||||
|
prefix = "eig_op",
|
||||||
|
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
|
||||||
|
":cast_op",
|
||||||
|
":cwise_op",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "matrix_inverse_op",
|
name = "matrix_inverse_op",
|
||||||
prefix = "matrix_inverse_op",
|
prefix = "matrix_inverse_op",
|
||||||
|
22
tensorflow/core/kernels/eig_op_complex128.cc
Normal file
22
tensorflow/core/kernels/eig_op_complex128.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/eig_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("Eig", (EigOp<complex128, complex128>), complex128);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
22
tensorflow/core/kernels/eig_op_complex64.cc
Normal file
22
tensorflow/core/kernels/eig_op_complex64.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/eig_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("Eig", (EigOp<complex64, complex64>), complex64);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
22
tensorflow/core/kernels/eig_op_double.cc
Normal file
22
tensorflow/core/kernels/eig_op_double.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/eig_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("Eig", (EigOp<double, complex128>), double);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
22
tensorflow/core/kernels/eig_op_float.cc
Normal file
22
tensorflow/core/kernels/eig_op_float.cc
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/eig_op_impl.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("Eig", (EigOp<float, complex64>), float);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
98
tensorflow/core/kernels/eig_op_impl.h
Normal file
98
tensorflow/core/kernels/eig_op_impl.h
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
||||||
|
|
||||||
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "third_party/eigen3/Eigen/Eigenvalues"
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/denormal.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
template <class InputScalar, class OutputScalar>
|
||||||
|
class EigOp : public LinearAlgebraOp<InputScalar, OutputScalar> {
|
||||||
|
public:
|
||||||
|
typedef LinearAlgebraOp<InputScalar, OutputScalar> Base;
|
||||||
|
|
||||||
|
explicit EigOp(OpKernelConstruction* context) : Base(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
|
||||||
|
}
|
||||||
|
|
||||||
|
using TensorShapes = typename Base::TensorShapes;
|
||||||
|
using InputMatrix = typename Base::InputMatrix;
|
||||||
|
using InputMatrixMaps = typename Base::InputMatrixMaps;
|
||||||
|
using InputConstMatrixMap = typename Base::InputConstMatrixMap;
|
||||||
|
using InputConstMatrixMaps = typename Base::InputConstMatrixMaps;
|
||||||
|
|
||||||
|
using OutputMatrix = typename Base::OutputMatrix;
|
||||||
|
using OutputMatrixMaps = typename Base::OutputMatrixMaps;
|
||||||
|
using OutputConstMatrixMap = typename Base::OutputConstMatrixMap;
|
||||||
|
using OutputConstMatrixMaps = typename Base::OutputConstMatrixMaps;
|
||||||
|
|
||||||
|
TensorShapes GetOutputMatrixShapes(
|
||||||
|
const TensorShapes& input_matrix_shapes) const final {
|
||||||
|
int64 n = input_matrix_shapes[0].dim_size(0);
|
||||||
|
if (compute_v_) {
|
||||||
|
return TensorShapes({TensorShape({n}), TensorShape({n, n})});
|
||||||
|
} else {
|
||||||
|
return TensorShapes({TensorShape({n})});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputeMatrix(OpKernelContext* context,
|
||||||
|
const InputConstMatrixMaps& inputs,
|
||||||
|
OutputMatrixMaps* outputs) final {
|
||||||
|
const int64 rows = inputs[0].rows();
|
||||||
|
if (rows == 0) {
|
||||||
|
// If X is an empty matrix (0 rows, 0 col), X * X' == X.
|
||||||
|
// Therefore, we return X.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This algorithm relies on denormals, so switch them back on locally.
|
||||||
|
port::ScopedDontFlushDenormal dont_flush_denormals;
|
||||||
|
|
||||||
|
Eigen::ComplexEigenSolver<OutputMatrix> eig(
|
||||||
|
inputs[0],
|
||||||
|
compute_v_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly);
|
||||||
|
// TODO(rmlarsen): Output more detailed error info on failure.
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, eig.info() == Eigen::Success,
|
||||||
|
errors::InvalidArgument("Eigen decomposition was not "
|
||||||
|
"successful. The input might not be valid."));
|
||||||
|
|
||||||
|
outputs->at(0) = eig.eigenvalues().template cast<OutputScalar>();
|
||||||
|
if (compute_v_) {
|
||||||
|
outputs->at(1) = eig.eigenvectors();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool compute_v_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
@ -29,8 +29,8 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// static
|
// static
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
||||||
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
|
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
|
||||||
errors::InvalidArgument("Expected a single input matrix, got %d.",
|
errors::InvalidArgument("Expected a single input matrix, got %d.",
|
||||||
@ -40,8 +40,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
||||||
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
|
OP_REQUIRES(context, input_matrix_shapes.size() == 1,
|
||||||
errors::InvalidArgument("Expected a single input matrix, got %d.",
|
errors::InvalidArgument("Expected a single input matrix, got %d.",
|
||||||
@ -51,8 +51,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::ValidateSolver(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
||||||
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
|
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
|
||||||
errors::InvalidArgument("Expected two input matrices, got %d.",
|
errors::InvalidArgument("Expected two input matrices, got %d.",
|
||||||
@ -68,8 +68,8 @@ void LinearAlgebraOp<Scalar>::ValidateSolver(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
|
||||||
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
|
OP_REQUIRES(context, input_matrix_shapes.size() == 2,
|
||||||
errors::InvalidArgument("Expected two input matrices, got %d.",
|
errors::InvalidArgument("Expected two input matrices, got %d.",
|
||||||
@ -85,8 +85,9 @@ void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
|
|||||||
errors::InvalidArgument("Input matrix and rhs are incompatible."));
|
errors::InvalidArgument("Input matrix and rhs are incompatible."));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
|
void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
|
||||||
|
OpKernelContext* context) {
|
||||||
TensorInputs inputs;
|
TensorInputs inputs;
|
||||||
TensorShapes input_matrix_shapes;
|
TensorShapes input_matrix_shapes;
|
||||||
TensorShape batch_shape;
|
TensorShape batch_shape;
|
||||||
@ -110,11 +111,10 @@ void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
|
|||||||
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
|
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
|
void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
|
||||||
TensorInputs* inputs,
|
OpKernelContext* context, TensorInputs* inputs,
|
||||||
TensorShapes* input_matrix_shapes,
|
TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
|
||||||
TensorShape* batch_shape) {
|
|
||||||
int input_rank = -1;
|
int input_rank = -1;
|
||||||
for (int i = 0; i < NumMatrixInputs(context); ++i) {
|
for (int i = 0; i < NumMatrixInputs(context); ++i) {
|
||||||
const Tensor& in = context->input(i);
|
const Tensor& in = context->input(i);
|
||||||
@ -155,8 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
|
|||||||
ValidateInputMatrixShapes(context, *input_matrix_shapes);
|
ValidateInputMatrixShapes(context, *input_matrix_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::PrepareOutputs(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes,
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes,
|
||||||
const TensorShape& batch_shape, TensorOutputs* outputs,
|
const TensorShape& batch_shape, TensorOutputs* outputs,
|
||||||
TensorShapes* output_matrix_shapes) {
|
TensorShapes* output_matrix_shapes) {
|
||||||
@ -214,22 +214,22 @@ void LinearAlgebraOp<Scalar>::PrepareOutputs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar>
|
||||||
void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
|
void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
|
||||||
OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
|
OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
|
||||||
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
|
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
|
||||||
const TensorShapes& output_matrix_shapes) {
|
const TensorShapes& output_matrix_shapes) {
|
||||||
ConstMatrixMaps matrix_inputs;
|
InputConstMatrixMaps matrix_inputs;
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
// TODO(kalakris): Handle alignment if possible. Eigen::Map is
|
// TODO(kalakris): Handle alignment if possible. Eigen::Map is
|
||||||
// unaligned by default.
|
// unaligned by default.
|
||||||
matrix_inputs.emplace_back(
|
matrix_inputs.emplace_back(
|
||||||
inputs[i]->flat<Scalar>().data() +
|
inputs[i]->flat<InputScalar>().data() +
|
||||||
matrix_index * input_matrix_shapes[i].num_elements(),
|
matrix_index * input_matrix_shapes[i].num_elements(),
|
||||||
input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
|
input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
MatrixMaps matrix_outputs;
|
OutputMatrixMaps matrix_outputs;
|
||||||
for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
|
for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
|
||||||
// The output matrix shape may not be a matrix.
|
// The output matrix shape may not be a matrix.
|
||||||
int num_output_rows = output_matrix_shapes[i].dims() >= 1
|
int num_output_rows = output_matrix_shapes[i].dims() >= 1
|
||||||
@ -239,7 +239,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
|
|||||||
? output_matrix_shapes[i].dim_size(1)
|
? output_matrix_shapes[i].dim_size(1)
|
||||||
: 1;
|
: 1;
|
||||||
matrix_outputs.emplace_back(
|
matrix_outputs.emplace_back(
|
||||||
outputs[i]->flat<Scalar>().data() +
|
outputs[i]->flat<OutputScalar>().data() +
|
||||||
matrix_index * output_matrix_shapes[i].num_elements(),
|
matrix_index * output_matrix_shapes[i].num_elements(),
|
||||||
num_output_rows, num_output_cols);
|
num_output_rows, num_output_cols);
|
||||||
}
|
}
|
||||||
@ -251,5 +251,7 @@ template class LinearAlgebraOp<float>;
|
|||||||
template class LinearAlgebraOp<double>;
|
template class LinearAlgebraOp<double>;
|
||||||
template class LinearAlgebraOp<complex64>;
|
template class LinearAlgebraOp<complex64>;
|
||||||
template class LinearAlgebraOp<complex128>;
|
template class LinearAlgebraOp<complex128>;
|
||||||
|
template class LinearAlgebraOp<float, complex64>;
|
||||||
|
template class LinearAlgebraOp<double, complex128>;
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -36,7 +36,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Base class for linear algebra operators.
|
// Base class for linear algebra operators.
|
||||||
template <typename Scalar>
|
template <class InputScalar, class OutputScalar = InputScalar>
|
||||||
class LinearAlgebraOp : public OpKernel {
|
class LinearAlgebraOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
@ -109,6 +109,28 @@ class LinearAlgebraOp : public OpKernel {
|
|||||||
// and expect the kernel to perform the computation inplace.
|
// and expect the kernel to perform the computation inplace.
|
||||||
virtual bool EnableInputForwarding() const { return true; }
|
virtual bool EnableInputForwarding() const { return true; }
|
||||||
|
|
||||||
|
using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
|
||||||
|
Eigen::RowMajor>;
|
||||||
|
using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
|
||||||
|
using InputMatrixMap = Eigen::Map<InputMatrix>;
|
||||||
|
using InputConstVectorMap =
|
||||||
|
Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
|
||||||
|
using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
|
||||||
|
using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
|
||||||
|
using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
|
||||||
|
|
||||||
|
using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
|
||||||
|
Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
|
||||||
|
using OutputMatrixMap = Eigen::Map<OutputMatrix>;
|
||||||
|
using OutputConstVectorMap =
|
||||||
|
Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
|
||||||
|
using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
|
||||||
|
using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
|
||||||
|
using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
|
||||||
|
|
||||||
|
// backward compatibility
|
||||||
|
using Scalar = OutputScalar;
|
||||||
using Matrix =
|
using Matrix =
|
||||||
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||||
@ -126,8 +148,8 @@ class LinearAlgebraOp : public OpKernel {
|
|||||||
// parallelized. The number of threads used is determined by a cost model from
|
// parallelized. The number of threads used is determined by a cost model from
|
||||||
// the value returned by GetCostPerUnit().
|
// the value returned by GetCostPerUnit().
|
||||||
virtual void ComputeMatrix(OpKernelContext* context,
|
virtual void ComputeMatrix(OpKernelContext* context,
|
||||||
const ConstMatrixMaps& inputs,
|
const InputConstMatrixMaps& inputs,
|
||||||
MatrixMaps* outputs) = 0;
|
OutputMatrixMaps* outputs) = 0;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
|
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
|
||||||
|
@ -383,6 +383,15 @@ REGISTER_OP("SelfAdjointEig")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("Eig")
|
||||||
|
.Input("input: T")
|
||||||
|
.Output("e: Tout")
|
||||||
|
.Output("v: Tout")
|
||||||
|
.Attr("compute_v: bool = True")
|
||||||
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
|
.Attr("Tout: {complex64, complex128}")
|
||||||
|
.SetShapeFn(SelfAdjointEigV2ShapeFn);
|
||||||
|
|
||||||
REGISTER_OP("SelfAdjointEigV2")
|
REGISTER_OP("SelfAdjointEigV2")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Output("e: T")
|
.Output("e: T")
|
||||||
|
@ -12206,6 +12206,50 @@ op {
|
|||||||
type: "type"
|
type: "type"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "Eig"
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "e"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "v"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "compute_v"
|
||||||
|
type: "bool"
|
||||||
|
default_value {
|
||||||
|
b: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tout"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "Elu"
|
name: "Elu"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -3355,6 +3355,27 @@ cuda_py_test(
|
|||||||
shard_count = 20,
|
shard_count = 20,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "eig_op_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["eig_op_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:linalg_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
],
|
||||||
|
data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
|
||||||
|
shard_count = 20,
|
||||||
|
tags = [
|
||||||
|
"no_rocm", # flaky test
|
||||||
|
"no_windows",
|
||||||
|
],
|
||||||
|
# b/127344411: xla_enable_strict_auto_jit = True,
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "self_adjoint_eig_op_test",
|
name = "self_adjoint_eig_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
198
tensorflow/python/kernel_tests/eig_op_test.py
Normal file
198
tensorflow/python/kernel_tests/eig_op_test.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.ops.linalg_ops.eig."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
def _AddTest(test_class, op_name, testcase_name, fn):
|
||||||
|
test_name = "_".join(["test", op_name, testcase_name])
|
||||||
|
if hasattr(test_class, test_name):
|
||||||
|
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||||
|
setattr(test_class, test_name, fn)
|
||||||
|
|
||||||
|
|
||||||
|
class EigTest(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testWrongDimensions(self):
|
||||||
|
# The input to self_adjoint_eig should be a tensor of
|
||||||
|
# at least rank 2.
|
||||||
|
scalar = constant_op.constant(1.)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
linalg_ops.eig(scalar)
|
||||||
|
vector = constant_op.constant([1., 2.])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
linalg_ops.eig(vector)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testConcurrentExecutesWithoutError(self):
|
||||||
|
all_ops = []
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
for compute_v_ in True, False:
|
||||||
|
matrix1 = random_ops.random_normal([5, 5], seed=42)
|
||||||
|
matrix2 = random_ops.random_normal([5, 5], seed=42)
|
||||||
|
if compute_v_:
|
||||||
|
e1, v1 = linalg_ops.eig(matrix1)
|
||||||
|
e2, v2 = linalg_ops.eig(matrix2)
|
||||||
|
all_ops += [e1, v1, e2, v2]
|
||||||
|
else:
|
||||||
|
e1 = linalg_ops.eigvals(matrix1)
|
||||||
|
e2 = linalg_ops.eigvals(matrix2)
|
||||||
|
all_ops += [e1, e2]
|
||||||
|
val = self.evaluate(all_ops)
|
||||||
|
self.assertAllEqual(val[0], val[2])
|
||||||
|
# The algorithm is slightly different for compute_v being True and False,
|
||||||
|
# so require approximate equality only here.
|
||||||
|
self.assertAllClose(val[2], val[4])
|
||||||
|
self.assertAllEqual(val[4], val[5])
|
||||||
|
self.assertAllEqual(val[1], val[3])
|
||||||
|
|
||||||
|
def testMatrixThatFailsWhenFlushingDenormsToZero(self):
|
||||||
|
# Test a 32x32 matrix which is known to fail if denorm floats are flushed to
|
||||||
|
# zero.
|
||||||
|
matrix = np.genfromtxt(
|
||||||
|
test.test_src_dir_path(
|
||||||
|
"python/kernel_tests/testdata/"
|
||||||
|
"self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
|
||||||
|
self.assertEqual(matrix.shape, (32, 32))
|
||||||
|
matrix_tensor = constant_op.constant(matrix)
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
(e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
|
||||||
|
self.assertEqual(e.size, 32)
|
||||||
|
self.assertAllClose(
|
||||||
|
np.matmul(v, v.transpose()), np.eye(32, dtype=np.float32), atol=2e-3)
|
||||||
|
self.assertAllClose(matrix,
|
||||||
|
np.matmul(np.matmul(v, np.diag(e)), v.transpose()))
|
||||||
|
|
||||||
|
|
||||||
|
def SortEigenValues(e):
|
||||||
|
perm = np.argsort(e.real + e.imag, -1)
|
||||||
|
return np.take(e, perm, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def SortEigenDecomposition(e, v):
|
||||||
|
if v.ndim < 2:
|
||||||
|
return e, v
|
||||||
|
else:
|
||||||
|
perm = np.argsort(e.real + e.imag, -1)
|
||||||
|
return np.take(e, perm, -1), np.take(v, perm, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def EquilibrateEigenVectorPhases(x, y):
|
||||||
|
"""Equilibrate the phase of the Eigenvectors in the columns of `x` and `y`.
|
||||||
|
|
||||||
|
Eigenvectors are only unique up to an arbitrary phase. This function rotates x
|
||||||
|
such that it matches y. Precondition: The coluns of x and y differ by a
|
||||||
|
multiplicative complex phase factor only.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: `np.ndarray` with Eigenvectors
|
||||||
|
y: `np.ndarray` with Eigenvectors
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`np.ndarray` containing an equilibrated version of x.
|
||||||
|
"""
|
||||||
|
phases = np.sum(np.conj(x) * y, -2, keepdims=True)
|
||||||
|
phases /= np.abs(phases)
|
||||||
|
return phases * x
|
||||||
|
|
||||||
|
|
||||||
|
def _GetEigTest(dtype_, shape_, compute_v_):
|
||||||
|
|
||||||
|
def CompareEigenVectors(self, x, y, tol):
|
||||||
|
x = EquilibrateEigenVectorPhases(x, y)
|
||||||
|
self.assertAllClose(x, y, atol=tol)
|
||||||
|
|
||||||
|
def CompareEigenDecompositions(self, x_e, x_v, y_e, y_v, tol):
|
||||||
|
num_batches = int(np.prod(x_e.shape[:-1]))
|
||||||
|
n = x_e.shape[-1]
|
||||||
|
x_e = np.reshape(x_e, [num_batches] + [n])
|
||||||
|
x_v = np.reshape(x_v, [num_batches] + [n, n])
|
||||||
|
y_e = np.reshape(y_e, [num_batches] + [n])
|
||||||
|
y_v = np.reshape(y_v, [num_batches] + [n, n])
|
||||||
|
for i in range(num_batches):
|
||||||
|
x_ei, x_vi = SortEigenDecomposition(x_e[i, :], x_v[i, :, :])
|
||||||
|
y_ei, y_vi = SortEigenDecomposition(y_e[i, :], y_v[i, :, :])
|
||||||
|
self.assertAllClose(x_ei, y_ei, atol=tol, rtol=tol)
|
||||||
|
CompareEigenVectors(self, x_vi, y_vi, tol)
|
||||||
|
|
||||||
|
def Test(self):
|
||||||
|
np.random.seed(1)
|
||||||
|
n = shape_[-1]
|
||||||
|
batch_shape = shape_[:-2]
|
||||||
|
np_dtype = dtype_.as_numpy_dtype
|
||||||
|
# most of matrices are diagonalizable # TODO
|
||||||
|
a = np.random.uniform(
|
||||||
|
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
|
||||||
|
if dtype_.is_complex:
|
||||||
|
a += 1j * np.random.uniform(
|
||||||
|
low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
|
||||||
|
a = np.tile(a, batch_shape + (1, 1))
|
||||||
|
if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
|
||||||
|
atol = 1e-4
|
||||||
|
else:
|
||||||
|
atol = 1e-12
|
||||||
|
np_e, np_v = np.linalg.eig(a)
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
if compute_v_:
|
||||||
|
tf_e, tf_v = linalg_ops.eig(constant_op.constant(a))
|
||||||
|
|
||||||
|
# Check that V*diag(E)*V^(-1) is close to A.
|
||||||
|
a_ev = math_ops.matmul(
|
||||||
|
math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)),
|
||||||
|
linalg_ops.matrix_inverse(tf_v))
|
||||||
|
self.assertAllClose(self.evaluate(a_ev), a, atol=atol)
|
||||||
|
|
||||||
|
# Compare to numpy.linalg.eig.
|
||||||
|
CompareEigenDecompositions(self, np_e, np_v, self.evaluate(tf_e),
|
||||||
|
self.evaluate(tf_v), atol)
|
||||||
|
else:
|
||||||
|
tf_e = linalg_ops.eigvals(constant_op.constant(a))
|
||||||
|
self.assertAllClose(
|
||||||
|
SortEigenValues(np_e),
|
||||||
|
SortEigenValues(self.evaluate(tf_e)),
|
||||||
|
atol=atol)
|
||||||
|
|
||||||
|
return Test
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
|
||||||
|
if not test.is_built_with_rocm():
|
||||||
|
# ROCm does not support BLAS operations for complex types
|
||||||
|
dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
|
||||||
|
for compute_v in True, False:
|
||||||
|
for dtype in dtypes_to_test:
|
||||||
|
for size in 1, 2, 5, 10:
|
||||||
|
for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
|
||||||
|
shape = batch_dims + (size, size)
|
||||||
|
name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
|
||||||
|
_AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
|
||||||
|
# No gradient yet
|
||||||
|
test.main()
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
|
"""Tests for tensorflow.ops.linalg_ops.self_adjoint_eig."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
@ -306,6 +306,62 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
|
|||||||
matrix, rhs, l2_regularizer, fast=fast, name=name)
|
matrix, rhs, l2_regularizer, fast=fast, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('eig', 'linalg.eig', v1=[])
|
||||||
|
def eig(tensor, name=None):
|
||||||
|
"""Computes the eigen decomposition of a batch of matrices.
|
||||||
|
|
||||||
|
The eigenvalues
|
||||||
|
and eigenvectors for a non-Hermitian matrix in general are complex. The
|
||||||
|
eigenvectors are not guaranteed to be linearly independent.
|
||||||
|
|
||||||
|
Computes the eigenvalues and right eigenvectors of the innermost
|
||||||
|
N-by-N matrices in `tensor` such that
|
||||||
|
`tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i]`, for i=0...N-1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: `Tensor` of shape `[..., N, N]`. Only the lower triangular part of
|
||||||
|
each inner inner matrix is referenced.
|
||||||
|
name: string, optional name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
e: Eigenvalues. Shape is `[..., N]`. Sorted in non-decreasing order.
|
||||||
|
v: Eigenvectors. Shape is `[..., N, N]`. The columns of the inner most
|
||||||
|
matrices contain eigenvectors of the corresponding matrices in `tensor`
|
||||||
|
"""
|
||||||
|
if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
|
||||||
|
out_dtype = dtypes.complex64
|
||||||
|
elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
|
||||||
|
out_dtype = dtypes.complex128
|
||||||
|
e, v = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=True, name=name)
|
||||||
|
return e, v
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('eigvals', 'linalg.eigvals', v1=[])
|
||||||
|
def eigvals(tensor, name=None):
|
||||||
|
"""Computes the eigenvalues of one or more matrices.
|
||||||
|
|
||||||
|
Note: If your program backpropagates through this function, you should replace
|
||||||
|
it with a call to tf.linalg.eig (possibly ignoring the second output) to
|
||||||
|
avoid computing the eigen decomposition twice. This is because the
|
||||||
|
eigenvectors are used to compute the gradient w.r.t. the eigenvalues. See
|
||||||
|
_SelfAdjointEigV2Grad in linalg_grad.py.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: `Tensor` of shape `[..., N, N]`.
|
||||||
|
name: string, optional name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
e: Eigenvalues. Shape is `[..., N]`. The vector `e[..., :]` contains the `N`
|
||||||
|
eigenvalues of `tensor[..., :, :]`.
|
||||||
|
"""
|
||||||
|
if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
|
||||||
|
out_dtype = dtypes.complex64
|
||||||
|
elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
|
||||||
|
out_dtype = dtypes.complex128
|
||||||
|
e, _ = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=False, name=name)
|
||||||
|
return e
|
||||||
|
|
||||||
|
|
||||||
@tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
|
@tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
|
||||||
@deprecation.deprecated_endpoints('self_adjoint_eig')
|
@deprecation.deprecated_endpoints('self_adjoint_eig')
|
||||||
def self_adjoint_eig(tensor, name=None):
|
def self_adjoint_eig(tensor, name=None):
|
||||||
|
@ -1132,6 +1132,10 @@ tf_module {
|
|||||||
name: "EditDistance"
|
name: "EditDistance"
|
||||||
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "Eig"
|
||||||
|
argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Einsum"
|
name: "Einsum"
|
||||||
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -108,10 +108,18 @@ tf_module {
|
|||||||
name: "diag_part"
|
name: "diag_part"
|
||||||
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
|
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "eig"
|
||||||
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "eigh"
|
name: "eigh"
|
||||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "eigvals"
|
||||||
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "eigvalsh"
|
name: "eigvalsh"
|
||||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -604,6 +604,14 @@ tf_module {
|
|||||||
name: "edit_distance"
|
name: "edit_distance"
|
||||||
argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], "
|
argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "eig"
|
||||||
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "eigvals"
|
||||||
|
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "einsum"
|
name: "einsum"
|
||||||
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
|
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
|
||||||
|
@ -1132,6 +1132,10 @@ tf_module {
|
|||||||
name: "EditDistance"
|
name: "EditDistance"
|
||||||
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "Eig"
|
||||||
|
argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Einsum"
|
name: "Einsum"
|
||||||
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user