Add an op for singular value decomposition (SVD) of a dense matrix or batches of dense matrices. This calls Eigen::JacobiSVD<Matrix, Eigen::HouseholderQRPreconditioner> which is known to be rather slow. This change is primarily intended to get the TensorFlow interfaces and functionality in place. We intend to swap out the "backend" with a higher performance algorithm implementation in the future.
This CL also contains a small refactoring of the LinearAlgebraOp base class: 1. I moved the initial processing of inputs and outputs into separate helper functions so Compute() is not so long. 2. The derived classes are now allowed to return fewer output matrix shapes (n) than the number of op outputs (m) in which case empty (shape[0]) tensors are returned for the last m-n outputs. Fixed a few Python linter errors that were blocking presubmit. Change: 128990912
This commit is contained in:
parent
48e869f0e3
commit
c0944a38a4
@ -1023,6 +1023,7 @@ tf_kernel_libraries(
|
|||||||
"matrix_solve_ls_op",
|
"matrix_solve_ls_op",
|
||||||
"matrix_solve_op",
|
"matrix_solve_op",
|
||||||
"matrix_triangular_solve_op",
|
"matrix_triangular_solve_op",
|
||||||
|
"svd_op",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":linalg_ops_common",
|
":linalg_ops_common",
|
||||||
|
@ -90,19 +90,35 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
|
|||||||
TensorInputs inputs;
|
TensorInputs inputs;
|
||||||
TensorShapes input_matrix_shapes;
|
TensorShapes input_matrix_shapes;
|
||||||
TensorShape batch_shape;
|
TensorShape batch_shape;
|
||||||
|
AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
|
||||||
|
|
||||||
|
TensorShapes output_matrix_shapes;
|
||||||
|
TensorOutputs outputs;
|
||||||
|
PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
|
||||||
|
&output_matrix_shapes);
|
||||||
|
|
||||||
|
// Process the individual matrix problems in parallel using a threadpool.
|
||||||
|
auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
|
||||||
|
&output_matrix_shapes, context](int64 begin, int64 end) {
|
||||||
|
for (int64 i = begin; i < end; ++i) {
|
||||||
|
ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
|
||||||
|
output_matrix_shapes);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers,
|
||||||
|
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, bool SupportsBatchOperation>
|
||||||
|
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::AnalyzeInputs(
|
||||||
|
OpKernelContext* context, TensorInputs* inputs,
|
||||||
|
TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
|
||||||
int input_rank = -1;
|
int input_rank = -1;
|
||||||
int num_batch_matrices = 1;
|
|
||||||
for (int i = 0; i < NumMatrixInputs(context); ++i) {
|
for (int i = 0; i < NumMatrixInputs(context); ++i) {
|
||||||
const Tensor& in = context->input(i);
|
const Tensor& in = context->input(i);
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
// If the tensor rank is greater than 2, we consider the inner-most
|
|
||||||
// dimensions as matrices, and loop over all the other outer ("batch")
|
|
||||||
// dimensions to compute the results.
|
|
||||||
input_rank = in.dims();
|
input_rank = in.dims();
|
||||||
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
|
||||||
num_batch_matrices *= in.dim_size(dim);
|
|
||||||
batch_shape.AddDim(in.dim_size(dim));
|
|
||||||
}
|
|
||||||
if (SupportsBatchOperation) {
|
if (SupportsBatchOperation) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input_rank >= 2,
|
context, input_rank >= 2,
|
||||||
@ -114,6 +130,13 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
|
|||||||
errors::InvalidArgument("Input tensor ", i,
|
errors::InvalidArgument("Input tensor ", i,
|
||||||
" must have rank == 2, got", input_rank));
|
" must have rank == 2, got", input_rank));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the tensor rank is greater than 2, we consider the inner-most
|
||||||
|
// dimensions as matrices, and loop over all the other outer ("batch")
|
||||||
|
// dimensions to compute the results.
|
||||||
|
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
||||||
|
batch_shape->AddDim(in.dim_size(dim));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Make sure that all inputs have the same rank and outer dimensions.
|
// Make sure that all inputs have the same rank and outer dimensions.
|
||||||
OP_REQUIRES(context, input_rank == in.dims(),
|
OP_REQUIRES(context, input_rank == in.dims(),
|
||||||
@ -121,7 +144,7 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
|
|||||||
"All input tensors must have the same rank."));
|
"All input tensors must have the same rank."));
|
||||||
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, in.dim_size(dim) == batch_shape.dim_size(dim),
|
context, in.dim_size(dim) == batch_shape->dim_size(dim),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"All input tensors must have the same outer dimensions."));
|
"All input tensors must have the same outer dimensions."));
|
||||||
}
|
}
|
||||||
@ -131,64 +154,59 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
|
|||||||
const int col_dimension = input_rank - 1;
|
const int col_dimension = input_rank - 1;
|
||||||
const int64 num_rows = in.dim_size(row_dimension);
|
const int64 num_rows = in.dim_size(row_dimension);
|
||||||
const int64 num_cols = in.dim_size(col_dimension);
|
const int64 num_cols = in.dim_size(col_dimension);
|
||||||
input_matrix_shapes.push_back(TensorShape({num_rows, num_cols}));
|
// TODO(rmlarsen): Use emplace_back when it is added to InlinedVector. Same
|
||||||
inputs.push_back(in);
|
// in several places below.
|
||||||
|
input_matrix_shapes->push_back(TensorShape({num_rows, num_cols}));
|
||||||
|
inputs->push_back(in);
|
||||||
}
|
}
|
||||||
// Have the derived class validate that the inputs are as expected.
|
// Have the derived class validate that the inputs are as expected.
|
||||||
ValidateInputMatrixShapes(context, input_matrix_shapes);
|
ValidateInputMatrixShapes(context, *input_matrix_shapes);
|
||||||
|
|
||||||
// Get shape for each of the matrix outputs.
|
|
||||||
const TensorShapes output_matrix_shapes =
|
|
||||||
GetOutputMatrixShapes(input_matrix_shapes);
|
|
||||||
// Make sure the number of outputs is what the derived class expects.
|
|
||||||
OP_REQUIRES(
|
|
||||||
context, output_matrix_shapes.size() == context->num_outputs(),
|
|
||||||
errors::Internal(
|
|
||||||
"Derived class expected (%d) output matrices for op, got (%d).",
|
|
||||||
output_matrix_shapes.size(), context->num_outputs()));
|
|
||||||
|
|
||||||
// Allocate outputs.
|
|
||||||
TensorShapes output_shapes;
|
|
||||||
TensorOutputs outputs;
|
|
||||||
for (int i = 0; i < context->num_outputs(); ++i) {
|
|
||||||
OP_REQUIRES(context, output_matrix_shapes[i].dims() <= 2,
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
|
|
||||||
i, output_matrix_shapes[i].dims()));
|
|
||||||
|
|
||||||
// The final output has the shape of the outer batch dimensions concatenated
|
|
||||||
// with the output_matrix_shape (if the output is not scalar).
|
|
||||||
TensorShape output_shape;
|
|
||||||
if (input_rank == 2) {
|
|
||||||
output_shape = output_matrix_shapes[i];
|
|
||||||
} else {
|
|
||||||
output_shape = batch_shape;
|
|
||||||
// Add the inner dimensions that depend on the operation implemented by
|
|
||||||
// the derived class.
|
|
||||||
for (int dim = 0; dim < output_matrix_shapes[i].dims(); ++dim) {
|
|
||||||
output_shape.AddDim(output_matrix_shapes[i].dim_size(dim));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output_shapes.push_back(output_shape);
|
|
||||||
Tensor* out = nullptr;
|
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(i, output_shape, &out));
|
|
||||||
outputs.push_back(out);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
|
|
||||||
&output_matrix_shapes, context](int64 begin, int64 end) {
|
|
||||||
for (int64 i = begin; i < end; ++i) {
|
|
||||||
ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
|
|
||||||
output_matrix_shapes);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
|
||||||
Shard(worker_threads.num_threads, worker_threads.workers, num_batch_matrices,
|
|
||||||
GetCostPerUnit(input_matrix_shapes), shard);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Scalar, bool SupportsBatchOperationT>
|
template <typename Scalar, bool SupportsBatchOperation>
|
||||||
void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeTensorSlice(
|
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::PrepareOutputs(
|
||||||
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes,
|
||||||
|
const TensorShape& batch_shape, TensorOutputs* outputs,
|
||||||
|
TensorShapes* output_matrix_shapes) {
|
||||||
|
// Get shape for each of the matrix outputs produced by the derived class.
|
||||||
|
*output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
|
||||||
|
const int num_outputs = output_matrix_shapes->size();
|
||||||
|
|
||||||
|
// Make sure the number of op outputs is what the derived class expects.
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, num_outputs <= context->num_outputs(),
|
||||||
|
errors::Internal(
|
||||||
|
"Derived class expected more outputs (%d) that the op has (%d).",
|
||||||
|
num_outputs, context->num_outputs()));
|
||||||
|
|
||||||
|
// Allocate outputs.
|
||||||
|
for (int i = 0; i < context->num_outputs(); ++i) {
|
||||||
|
TensorShape output_tensor_shape({0});
|
||||||
|
if (i < num_outputs) {
|
||||||
|
// This output is used, set up output shape and allocate it.
|
||||||
|
const TensorShape& output_matrix_shape = output_matrix_shapes->at(i);
|
||||||
|
OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
|
||||||
|
i, output_matrix_shape.dims()));
|
||||||
|
|
||||||
|
// The final output has the shape of the outer batch dimensions
|
||||||
|
// concatenated with the output_matrix_shape (if the output is not
|
||||||
|
// scalar).
|
||||||
|
output_tensor_shape = batch_shape;
|
||||||
|
for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) {
|
||||||
|
output_tensor_shape.AddDim(output_matrix_shape.dim_size(dim));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor* out = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(i, output_tensor_shape, &out));
|
||||||
|
outputs->push_back(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar, bool SupportsBatchOperation>
|
||||||
|
void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
|
||||||
OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
|
OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
|
||||||
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
|
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
|
||||||
const TensorShapes& output_matrix_shapes) {
|
const TensorShapes& output_matrix_shapes) {
|
||||||
@ -204,7 +222,7 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeTensorSlice(
|
|||||||
}
|
}
|
||||||
|
|
||||||
MatrixMaps matrix_outputs;
|
MatrixMaps matrix_outputs;
|
||||||
for (int i = 0; i < outputs.size(); ++i) {
|
for (int i = 0; i < output_matrix_shapes.size(); ++i) {
|
||||||
// The output matrix shape may not be a matrix.
|
// The output matrix shape may not be a matrix.
|
||||||
int num_output_rows = output_matrix_shapes[i].dims() >= 1
|
int num_output_rows = output_matrix_shapes[i].dims() >= 1
|
||||||
? output_matrix_shapes[i].dim_size(0)
|
? output_matrix_shapes[i].dim_size(0)
|
||||||
|
@ -43,7 +43,7 @@ template <typename Scalar, bool SupportsBatchOperationT>
|
|||||||
class LinearAlgebraOp : public OpKernel {
|
class LinearAlgebraOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
~LinearAlgebraOp() override {}
|
|
||||||
void Compute(OpKernelContext* context) override;
|
void Compute(OpKernelContext* context) override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -81,18 +81,25 @@ class LinearAlgebraOp : public OpKernel {
|
|||||||
|
|
||||||
// Returns the output shapes of each individual matrix operation. Output
|
// Returns the output shapes of each individual matrix operation. Output
|
||||||
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
|
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
|
||||||
// For many ops the output dimensions are the same as the input dimensions,
|
//
|
||||||
|
// The derived class may return a number of shapes (N) less than
|
||||||
|
// context->num_outputs() (M) to indicate that a only leading subset of
|
||||||
|
// the outputs will be populated. In this case, a dummy scalar tensor with
|
||||||
|
// value zero will be return for the last M-N outputs.
|
||||||
|
//
|
||||||
|
// For many ops, the output dimensions are the same as the input dimensions,
|
||||||
// so we provide that as a default implementation for convenience.
|
// so we provide that as a default implementation for convenience.
|
||||||
virtual TensorShapes GetOutputMatrixShapes(
|
virtual TensorShapes GetOutputMatrixShapes(
|
||||||
const TensorShapes& input_matrix_shapes) const {
|
const TensorShapes& input_matrix_shapes) const {
|
||||||
return input_matrix_shapes;
|
return input_matrix_shapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the cost per matrix operation. Cost per unit is assumed to be
|
// Returns the cost per matrix operation. This is used to determine the
|
||||||
// roughly 1ns, based on comments in core/util/work_sharder.cc.
|
// number of threads to use for parallelizing calls to ComputeMatrix in
|
||||||
// Many linear algebra ops take roughly max(m,n) * min(m,n)^2, where the first
|
// batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
|
||||||
// input matrix is m-by-n. We provide that as a default implementation for
|
// in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
|
||||||
// convenience.
|
// * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
|
||||||
|
// default implementation for convenience.
|
||||||
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
|
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
|
||||||
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
||||||
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
|
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
|
||||||
@ -111,7 +118,9 @@ class LinearAlgebraOp : public OpKernel {
|
|||||||
// Performs a single matrix computation given input matrices, and
|
// Performs a single matrix computation given input matrices, and
|
||||||
// stores the result in outputs. For batch operations, this will be called
|
// stores the result in outputs. For batch operations, this will be called
|
||||||
// repeatedly for a single call to Compute() when multiple matrices exist in
|
// repeatedly for a single call to Compute() when multiple matrices exist in
|
||||||
// input Tensors with rank > 2.
|
// input Tensors with rank > 2. In this case the calls to ComputeMatrix are
|
||||||
|
// parallelized. The number of threads used is determined by a cost model from
|
||||||
|
// the value returned by GetCostPerUnit().
|
||||||
virtual void ComputeMatrix(OpKernelContext* context,
|
virtual void ComputeMatrix(OpKernelContext* context,
|
||||||
const ConstMatrixMaps& inputs,
|
const ConstMatrixMaps& inputs,
|
||||||
MatrixMaps* outputs) = 0;
|
MatrixMaps* outputs) = 0;
|
||||||
@ -142,6 +151,15 @@ class LinearAlgebraOp : public OpKernel {
|
|||||||
const TensorShapes& input_matrix_shapes,
|
const TensorShapes& input_matrix_shapes,
|
||||||
const TensorOutputs& outputs,
|
const TensorOutputs& outputs,
|
||||||
const TensorShapes& output_matrix_shapes);
|
const TensorShapes& output_matrix_shapes);
|
||||||
|
|
||||||
|
void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
|
||||||
|
TensorShapes* input_matrix_shapes,
|
||||||
|
TensorShape* batch_shape);
|
||||||
|
|
||||||
|
void PrepareOutputs(OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes,
|
||||||
|
const TensorShape& batch_shape, TensorOutputs* outputs,
|
||||||
|
TensorShapes* output_matrix_shapes);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Declare that LinearAlgebraOp is explicitly instantiated in
|
// Declare that LinearAlgebraOp is explicitly instantiated in
|
||||||
|
105
tensorflow/core/kernels/svd_op.cc
Normal file
105
tensorflow/core/kernels/svd_op.cc
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/SVD"
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
template <class Scalar, bool SupportsBatchOperation>
|
||||||
|
class SvdOp : public LinearAlgebraOp<Scalar, SupportsBatchOperation> {
|
||||||
|
public:
|
||||||
|
typedef LinearAlgebraOp<Scalar, SupportsBatchOperation> Base;
|
||||||
|
|
||||||
|
explicit SvdOp(OpKernelConstruction* context) : Base(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
|
||||||
|
}
|
||||||
|
|
||||||
|
using TensorShapes = typename Base::TensorShapes;
|
||||||
|
|
||||||
|
void ValidateInputMatrixShapes(
|
||||||
|
OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes) const final {
|
||||||
|
Base::ValidateSingleMatrix(context, input_matrix_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShapes GetOutputMatrixShapes(
|
||||||
|
const TensorShapes& input_matrix_shapes) const final {
|
||||||
|
int64 m = input_matrix_shapes[0].dim_size(0);
|
||||||
|
int64 n = input_matrix_shapes[0].dim_size(1);
|
||||||
|
int64 min_size = std::min(m, n);
|
||||||
|
if (compute_uv_) {
|
||||||
|
return TensorShapes({TensorShape({min_size}),
|
||||||
|
TensorShape({m, full_matrices_ ? m : min_size}),
|
||||||
|
TensorShape({n, full_matrices_ ? n : min_size})});
|
||||||
|
} else {
|
||||||
|
return TensorShapes({TensorShape({min_size})});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(rmlarsen): This should depend on compute_uv. See b/30409375.
|
||||||
|
int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
|
||||||
|
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
||||||
|
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
|
||||||
|
double cost = 12 * std::max(m, n) * std::min(m, n) * std::min(m, n);
|
||||||
|
return cost >= static_cast<double>(kint64max) ? kint64max
|
||||||
|
: static_cast<int64>(cost);
|
||||||
|
}
|
||||||
|
|
||||||
|
using Matrix = typename Base::Matrix;
|
||||||
|
using MatrixMaps = typename Base::MatrixMaps;
|
||||||
|
using ConstMatrixMap = typename Base::ConstMatrixMap;
|
||||||
|
using ConstMatrixMaps = typename Base::ConstMatrixMaps;
|
||||||
|
|
||||||
|
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
|
||||||
|
MatrixMaps* outputs) final {
|
||||||
|
Eigen::JacobiSVD<Matrix, Eigen::HouseholderQRPreconditioner> svd;
|
||||||
|
if (compute_uv_) {
|
||||||
|
svd.compute(inputs[0],
|
||||||
|
(full_matrices_ ? Eigen::ComputeFullU | Eigen::ComputeFullV
|
||||||
|
: Eigen::ComputeThinU | Eigen::ComputeThinV));
|
||||||
|
outputs->at(0) = svd.singularValues();
|
||||||
|
outputs->at(1) = svd.matrixU();
|
||||||
|
outputs->at(2) = svd.matrixV();
|
||||||
|
} else {
|
||||||
|
svd.compute(inputs[0]);
|
||||||
|
outputs->at(0) = svd.singularValues();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool compute_uv_;
|
||||||
|
bool full_matrices_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(SvdOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_LINALG_OP("Svd", (SvdOp<float, false>), float);
|
||||||
|
REGISTER_LINALG_OP("Svd", (SvdOp<double, false>), double);
|
||||||
|
REGISTER_LINALG_OP("BatchSvd", (SvdOp<float, true>), float);
|
||||||
|
REGISTER_LINALG_OP("BatchSvd", (SvdOp<double, true>), double);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -128,7 +128,7 @@ REGISTER_OP("MatrixDeterminant")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the determinant of a square matrix.
|
Computes the determinant of a square matrix.
|
||||||
|
|
||||||
input: A tensor of shape `[M, M]`.
|
input: A tensor of shape `[M, M]`.
|
||||||
output: A scalar, equal to the determinant of the input.
|
output: A scalar, equal to the determinant of the input.
|
||||||
@ -152,7 +152,7 @@ REGISTER_OP("BatchMatrixDeterminant")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the determinants for a batch of square matrices.
|
Computes the determinants for a batch of square matrices.
|
||||||
|
|
||||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||||
form square matrices. The output is a tensor containing the determinants
|
form square matrices. The output is a tensor containing the determinants
|
||||||
@ -169,7 +169,7 @@ REGISTER_OP("MatrixInverse")
|
|||||||
.Attr("T: {double, float}")
|
.Attr("T: {double, float}")
|
||||||
.SetShapeFn(UnchangedSquareShapeFn)
|
.SetShapeFn(UnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the inverse of a square invertible matrix or its adjoint (conjugate
|
Computes the inverse of a square invertible matrix or its adjoint (conjugate
|
||||||
transpose).
|
transpose).
|
||||||
|
|
||||||
The op uses LU decomposition with partial pivoting to compute the inverse.
|
The op uses LU decomposition with partial pivoting to compute the inverse.
|
||||||
@ -191,7 +191,7 @@ REGISTER_OP("BatchMatrixInverse")
|
|||||||
.Attr("T: {double, float}")
|
.Attr("T: {double, float}")
|
||||||
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the inverse of square invertible matrices or their adjoints
|
Computes the inverse of square invertible matrices or their adjoints
|
||||||
(conjugate transposes).
|
(conjugate transposes).
|
||||||
|
|
||||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||||
@ -214,7 +214,7 @@ REGISTER_OP("Cholesky")
|
|||||||
.Attr("T: {double, float}")
|
.Attr("T: {double, float}")
|
||||||
.SetShapeFn(UnchangedSquareShapeFn)
|
.SetShapeFn(UnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the Cholesky decomposition of a square matrix.
|
Computes the Cholesky decomposition of a square matrix.
|
||||||
|
|
||||||
The input has to be symmetric and positive definite. Only the lower-triangular
|
The input has to be symmetric and positive definite. Only the lower-triangular
|
||||||
part of the input will be used for this operation. The upper-triangular part
|
part of the input will be used for this operation. The upper-triangular part
|
||||||
@ -233,7 +233,7 @@ REGISTER_OP("BatchCholesky")
|
|||||||
.Attr("T: {double, float}")
|
.Attr("T: {double, float}")
|
||||||
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the Cholesky decomposition of a batch of square matrices.
|
Computes the Cholesky decomposition of a batch of square matrices.
|
||||||
|
|
||||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||||
form square matrices, with the same constraints as the single matrix Cholesky
|
form square matrices, with the same constraints as the single matrix Cholesky
|
||||||
@ -251,7 +251,7 @@ REGISTER_OP("CholeskyGrad")
|
|||||||
.Attr("T: {float, double}")
|
.Attr("T: {float, double}")
|
||||||
.SetShapeFn(UnchangedSquareShapeFn)
|
.SetShapeFn(UnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
|
Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
|
||||||
|
|
||||||
For an explanation see "Differentiation of the Cholesky algorithm" by
|
For an explanation see "Differentiation of the Cholesky algorithm" by
|
||||||
Iain Murray http://arxiv.org/abs/1602.07527.
|
Iain Murray http://arxiv.org/abs/1602.07527.
|
||||||
@ -270,7 +270,7 @@ REGISTER_OP("BatchCholeskyGrad")
|
|||||||
.Attr("T: {float, double}")
|
.Attr("T: {float, double}")
|
||||||
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
.SetShapeFn(BatchUnchangedSquareShapeFn)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
|
Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
|
||||||
|
|
||||||
For an explanation see "Differentiation of the Cholesky algorithm" by
|
For an explanation see "Differentiation of the Cholesky algorithm" by
|
||||||
Iain Murray http://arxiv.org/abs/1602.07527.
|
Iain Murray http://arxiv.org/abs/1602.07527.
|
||||||
@ -299,7 +299,7 @@ REGISTER_OP("SelfAdjointEig")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the Eigen Decomposition of a square Self-Adjoint matrix.
|
Computes the Eigen Decomposition of a square Self-Adjoint matrix.
|
||||||
|
|
||||||
Only the lower-triangular part of the input will be used in this case. The
|
Only the lower-triangular part of the input will be used in this case. The
|
||||||
upper-triangular part will not be read.
|
upper-triangular part will not be read.
|
||||||
@ -330,7 +330,7 @@ REGISTER_OP("BatchSelfAdjointEig")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Calculates the Eigen Decomposition of a batch of square self-adjoint matrices.
|
Computes the Eigen Decomposition of a batch of square self-adjoint matrices.
|
||||||
|
|
||||||
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||||
form square matrices, with the same constraints as the single matrix
|
form square matrices, with the same constraints as the single matrix
|
||||||
@ -526,10 +526,10 @@ REGISTER_OP("BatchMatrixSolveLs")
|
|||||||
Solves multiple linear least-squares problems.
|
Solves multiple linear least-squares problems.
|
||||||
|
|
||||||
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
||||||
form square matrices. Rhs is a tensor of shape `[..., M, K]`. The output
|
form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
|
||||||
is a tensor shape `[..., N, K]` where each output matrix solves each of
|
The output is a tensor shape `[..., N, K]` where each output matrix solves
|
||||||
the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the
|
each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
|
||||||
least squares sense.
|
in the least squares sense.
|
||||||
|
|
||||||
Below we will use the following notation for each pair of
|
Below we will use the following notation for each pair of
|
||||||
matrix and right-hand sides in the batch:
|
matrix and right-hand sides in the batch:
|
||||||
@ -563,4 +563,82 @@ rhs: Shape is `[..., M, K]`.
|
|||||||
output: Shape is `[..., N, K]`.
|
output: Shape is `[..., N, K]`.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("Svd")
|
||||||
|
.Input("input: T")
|
||||||
|
.Output("s: T")
|
||||||
|
.Output("u: T")
|
||||||
|
.Output("v: T")
|
||||||
|
.Attr("compute_uv: bool = False")
|
||||||
|
.Attr("full_matrices: bool = False")
|
||||||
|
.Attr("T: {double, float}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes the singular value decomposition of a matrix.
|
||||||
|
|
||||||
|
Computes the SVD of if `input` such that `input = u * diag(s) * transpose(v)`
|
||||||
|
|
||||||
|
```prettyprint
|
||||||
|
# a is a matrix.
|
||||||
|
# s is a vector of singular values.
|
||||||
|
# u is the matrix of left singular vectors.
|
||||||
|
# v is a matrix of right singular vectors.
|
||||||
|
s, _, _ = svd(a, compute_uv=False)
|
||||||
|
s, u, v = svd(a, compute_uv=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
input: Shape is `[M, N]`. Let `P` be the minimum of `M` and `N`.
|
||||||
|
s: Singular values. Shape is `[P]`.
|
||||||
|
u: Left singular vectors; if `full_matrices` is `False` then shape is `[M, M]`.
|
||||||
|
If `full_matrices` is `True` then shape is `[M, P]`.
|
||||||
|
Undefined if `compute_uv` is `False`.
|
||||||
|
v: Left singular vectors. If `full_matrices` is `False` then shape is `[N, N]`.
|
||||||
|
If `full_matrices` is `True` then shape is `[N, P]`.
|
||||||
|
Undefined if `compute_uv` is false.
|
||||||
|
compute_uv: If true, left and right singular vectors will be
|
||||||
|
computed and returned in `u` and `v`, respectively.
|
||||||
|
If false, `u` and `v` are not set and should never referenced.
|
||||||
|
full_matrices: If true, compute full-sized `u` and `v`. If false
|
||||||
|
(the default), compute only the leading `P` singular vectors.
|
||||||
|
Ignored if `compute_uv` is `False`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("BatchSvd")
|
||||||
|
.Input("input: T")
|
||||||
|
.Output("s: T")
|
||||||
|
.Output("u: T")
|
||||||
|
.Output("v: T")
|
||||||
|
.Attr("compute_uv: bool = False")
|
||||||
|
.Attr("full_matrices: bool = False")
|
||||||
|
.Attr("T: {double, float}")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Computes the singular value decompositions of a batch of matrices.
|
||||||
|
|
||||||
|
Computes the SVD of each inner matrix in `input` such that
|
||||||
|
`input[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :, :])`
|
||||||
|
|
||||||
|
```prettyprint
|
||||||
|
# a is a tensor containing a batch of matrices.
|
||||||
|
# s is a tensor of singular values for each matrix.
|
||||||
|
# u is the tensor containing of left singular vectors for each matrix.
|
||||||
|
# v is the tensor containing of right singular vectors for each matrix.
|
||||||
|
s, _, _ = batch_svd(a, compute_uv=False)
|
||||||
|
s, u, v = batch_svd(a, compute_uv=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
||||||
|
form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
|
||||||
|
s: Singular values. Shape is `[..., P]`.
|
||||||
|
u: Left singular vectors. If `full_matrices` is `False` then shape is
|
||||||
|
`[..., M, M]`; if `full_matrices` is `True` then shape is
|
||||||
|
`[..., M, P]`. Undefined if `compute_uv` is `False`.
|
||||||
|
v: Left singular vectors. If `full_matrices` is `False` then shape is
|
||||||
|
`[..., N, N]`. If `full_matrices` is `True` then shape is `[..., N, P]`.
|
||||||
|
Undefined if `compute_uv` is false.
|
||||||
|
compute_uv: If true, left and right singular vectors will be
|
||||||
|
computed and returned in `u` and `v`, respectively.
|
||||||
|
If false, `u` and `v` are not set and should never referenced.
|
||||||
|
full_matrices: If true, compute full-sized `u` and `v`. If false
|
||||||
|
(the default), compute only the leading `P` singular vectors.
|
||||||
|
Ignored if `compute_uv` is `False`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -209,6 +209,7 @@ cuda_py_tests(
|
|||||||
"cwise_ops_test.py",
|
"cwise_ops_test.py",
|
||||||
"embedding_ops_test.py",
|
"embedding_ops_test.py",
|
||||||
"linalg_grad_test.py",
|
"linalg_grad_test.py",
|
||||||
|
"svd_op_test.py",
|
||||||
],
|
],
|
||||||
shard_count = 50,
|
shard_count = 50,
|
||||||
tags = ["notap"], # b/30226163
|
tags = ["notap"], # b/30226163
|
||||||
|
112
tensorflow/python/kernel_tests/svd_op_test.py
Normal file
112
tensorflow/python/kernel_tests/svd_op_test.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class SvdOpTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def testWrongDimensions(self):
|
||||||
|
# The input to svd should be 2-dimensional tensor.
|
||||||
|
scalar = tf.constant(1.)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.svd(scalar)
|
||||||
|
vector = tf.constant([1., 2.])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.svd(vector)
|
||||||
|
tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.svd(tensor)
|
||||||
|
|
||||||
|
# The input to batch_svd should be a tensor of at least rank 2.
|
||||||
|
scalar = tf.constant(1.)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.batch_svd(scalar)
|
||||||
|
vector = tf.constant([1., 2.])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tf.batch_svd(vector)
|
||||||
|
|
||||||
|
|
||||||
|
def _GetSvdOpTest(dtype_, shape_):
|
||||||
|
|
||||||
|
def _CompareSingularVectors(self, x, y, atol):
|
||||||
|
# Singular vectors are only unique up to sign (complex phase factor for
|
||||||
|
# complex matrices), so we normalize the signs first.
|
||||||
|
signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
|
||||||
|
x *= signs
|
||||||
|
self.assertAllClose(x, y, atol=atol)
|
||||||
|
|
||||||
|
def Test(self):
|
||||||
|
np.random.seed(1)
|
||||||
|
x = np.random.uniform(
|
||||||
|
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
|
||||||
|
if dtype_ == np.float32:
|
||||||
|
atol = 1e-4
|
||||||
|
else:
|
||||||
|
atol = 1e-14
|
||||||
|
for compute_uv in False, True:
|
||||||
|
for full_matrices in False, True:
|
||||||
|
with self.test_session():
|
||||||
|
if x.ndim == 2:
|
||||||
|
if compute_uv:
|
||||||
|
tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
else:
|
||||||
|
tf_s = tf.svd(tf.constant(x),
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
else:
|
||||||
|
if compute_uv:
|
||||||
|
tf_s, tf_u, tf_v = tf.batch_svd(
|
||||||
|
tf.constant(x),
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
else:
|
||||||
|
tf_s = tf.batch_svd(
|
||||||
|
tf.constant(x),
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
if compute_uv:
|
||||||
|
np_u, np_s, np_v = np.linalg.svd(x,
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
else:
|
||||||
|
np_s = np.linalg.svd(x,
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
self.assertAllClose(np_s, tf_s.eval(), atol=atol)
|
||||||
|
if compute_uv:
|
||||||
|
_CompareSingularVectors(self, np_u, tf_u.eval(), atol)
|
||||||
|
_CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
|
||||||
|
tf_v.eval(), atol)
|
||||||
|
|
||||||
|
return Test
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
for dtype in np.float32, np.float64:
|
||||||
|
for m in 1, 2, 5, 10:
|
||||||
|
for n in 1, 2, 5, 10:
|
||||||
|
for batch_dims in [(), (3,)] + [(3, 2)] * (max(m, n) < 10):
|
||||||
|
shape = batch_dims + (m, n)
|
||||||
|
name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
|
||||||
|
setattr(SvdOpTest, 'testSvd_' + name, _GetSvdOpTest(dtype, shape))
|
||||||
|
tf.test.main()
|
@ -32,6 +32,10 @@ from tensorflow.python.ops import math_ops
|
|||||||
|
|
||||||
ops.NoGradient("CholeskyGrad")
|
ops.NoGradient("CholeskyGrad")
|
||||||
ops.NoGradient("BatchCholeskyGrad")
|
ops.NoGradient("BatchCholeskyGrad")
|
||||||
|
ops.NoGradient("SelfAdjointEig")
|
||||||
|
ops.NoGradient("BatchSelfAdjointEig")
|
||||||
|
ops.NoGradient("Svd")
|
||||||
|
ops.NoGradient("BatchSvd")
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("MatrixInverse")
|
@ops.RegisterGradient("MatrixInverse")
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
|
|||||||
@ops.RegisterShape("CholeskyGrad")
|
@ops.RegisterShape("CholeskyGrad")
|
||||||
@ops.RegisterShape("MatrixInverse")
|
@ops.RegisterShape("MatrixInverse")
|
||||||
def _UnchangedSquare(op):
|
def _UnchangedSquare(op):
|
||||||
|
"""Shape function for matrix ops with output equal to input shape."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(2)
|
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
# The matrix must be square.
|
# The matrix must be square.
|
||||||
input_shape[0].assert_is_compatible_with(input_shape[1])
|
input_shape[0].assert_is_compatible_with(input_shape[1])
|
||||||
@ -41,6 +42,7 @@ def _UnchangedSquare(op):
|
|||||||
@ops.RegisterShape("BatchCholeskyGrad")
|
@ops.RegisterShape("BatchCholeskyGrad")
|
||||||
@ops.RegisterShape("BatchMatrixInverse")
|
@ops.RegisterShape("BatchMatrixInverse")
|
||||||
def _BatchUnchangedSquare(op):
|
def _BatchUnchangedSquare(op):
|
||||||
|
"""Shape function for batch matrix ops with output equal to input shape."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
# The matrices in the batch must be square.
|
# The matrices in the batch must be square.
|
||||||
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
||||||
@ -48,6 +50,7 @@ def _BatchUnchangedSquare(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("MatrixDeterminant")
|
@ops.RegisterShape("MatrixDeterminant")
|
||||||
def _MatrixDeterminantShape(op):
|
def _MatrixDeterminantShape(op):
|
||||||
|
"""Shape function for determinant op."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(2)
|
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
# The matrix must be square.
|
# The matrix must be square.
|
||||||
input_shape[0].assert_is_compatible_with(input_shape[1])
|
input_shape[0].assert_is_compatible_with(input_shape[1])
|
||||||
@ -59,6 +62,7 @@ def _MatrixDeterminantShape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("BatchMatrixDeterminant")
|
@ops.RegisterShape("BatchMatrixDeterminant")
|
||||||
def _BatchMatrixDeterminantShape(op):
|
def _BatchMatrixDeterminantShape(op):
|
||||||
|
"""Shape function for batch determinant op."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
# The matrices in the batch must be square.
|
# The matrices in the batch must be square.
|
||||||
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
||||||
@ -70,6 +74,7 @@ def _BatchMatrixDeterminantShape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("SelfAdjointEig")
|
@ops.RegisterShape("SelfAdjointEig")
|
||||||
def _SelfAdjointEigShape(op):
|
def _SelfAdjointEigShape(op):
|
||||||
|
"""Shape function for self-adjoint eigensolver op."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(2)
|
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
# The matrix must be square.
|
# The matrix must be square.
|
||||||
input_shape[0].assert_is_compatible_with(input_shape[1])
|
input_shape[0].assert_is_compatible_with(input_shape[1])
|
||||||
@ -80,6 +85,7 @@ def _SelfAdjointEigShape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("BatchSelfAdjointEig")
|
@ops.RegisterShape("BatchSelfAdjointEig")
|
||||||
def _BatchSelfAdjointEigShape(op):
|
def _BatchSelfAdjointEigShape(op):
|
||||||
|
"""Shape function for batch self-adjoint eigensolver op."""
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
# The matrices in the batch must be square.
|
# The matrices in the batch must be square.
|
||||||
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
input_shape[-1].assert_is_compatible_with(input_shape[-2])
|
||||||
@ -89,9 +95,63 @@ def _BatchSelfAdjointEigShape(op):
|
|||||||
return [out_shape]
|
return [out_shape]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterShape("Svd")
|
||||||
|
def _SvdShape(op):
|
||||||
|
"""Shape function for SVD op."""
|
||||||
|
input_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
|
unknown = tensor_shape.unknown_shape()
|
||||||
|
compute_uv = op.get_attr("compute_uv")
|
||||||
|
if input_shape.ndims is not None:
|
||||||
|
return [unknown, unknown, unknown]
|
||||||
|
full_matrices = op.get_attr("full_matrices")
|
||||||
|
m = input_shape.dims[0]
|
||||||
|
n = input_shape.dims[1]
|
||||||
|
p = min(m, n)
|
||||||
|
s_shape = tensor_shape.TensorShape([p])
|
||||||
|
if compute_uv:
|
||||||
|
if full_matrices:
|
||||||
|
u_shape = tensor_shape.TensorShape([m, m])
|
||||||
|
v_shape = tensor_shape.TensorShape([n, n])
|
||||||
|
else:
|
||||||
|
u_shape = tensor_shape.TensorShape([m, p])
|
||||||
|
v_shape = tensor_shape.TensorShape([n, p])
|
||||||
|
else:
|
||||||
|
u_shape = [0]
|
||||||
|
v_shape = [0]
|
||||||
|
return [s_shape, u_shape, v_shape]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterShape("BatchSvd")
|
||||||
|
def _BatchSvdShape(op):
|
||||||
|
"""Shape function for batch SVD op."""
|
||||||
|
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
|
unknown = tensor_shape.unknown_shape()
|
||||||
|
if input_shape.ndims is not None:
|
||||||
|
return [unknown, unknown, unknown]
|
||||||
|
compute_uv = op.get_attr("compute_uv")
|
||||||
|
full_matrices = op.get_attr("full_matrices")
|
||||||
|
m = input_shape.dims[-2]
|
||||||
|
n = input_shape.dims[-1]
|
||||||
|
p = min(m, n)
|
||||||
|
batch_shape = input_shape.dims[:-2]
|
||||||
|
s_shape = batch_shape.concatenate([p])
|
||||||
|
if compute_uv:
|
||||||
|
if full_matrices:
|
||||||
|
u_shape = batch_shape.concatenate([m, m])
|
||||||
|
v_shape = batch_shape.concatenate([n, n])
|
||||||
|
else:
|
||||||
|
u_shape = batch_shape.concatenate([m, p])
|
||||||
|
v_shape = batch_shape.concatenate([n, p])
|
||||||
|
else:
|
||||||
|
u_shape = [0]
|
||||||
|
v_shape = [0]
|
||||||
|
return [s_shape, u_shape, v_shape]
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("MatrixSolve")
|
@ops.RegisterShape("MatrixSolve")
|
||||||
@ops.RegisterShape("MatrixTriangularSolve")
|
@ops.RegisterShape("MatrixTriangularSolve")
|
||||||
def _SquareMatrixSolveShape(op):
|
def _SquareMatrixSolveShape(op):
|
||||||
|
"""Shape function for square matrix solver ops."""
|
||||||
lhs_shape = op.inputs[0].get_shape().with_rank(2)
|
lhs_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
rhs_shape = op.inputs[1].get_shape().with_rank(2)
|
rhs_shape = op.inputs[1].get_shape().with_rank(2)
|
||||||
# The matrix must be square.
|
# The matrix must be square.
|
||||||
@ -104,6 +164,7 @@ def _SquareMatrixSolveShape(op):
|
|||||||
@ops.RegisterShape("BatchMatrixSolve")
|
@ops.RegisterShape("BatchMatrixSolve")
|
||||||
@ops.RegisterShape("BatchMatrixTriangularSolve")
|
@ops.RegisterShape("BatchMatrixTriangularSolve")
|
||||||
def _BatchSquareMatrixSolveShape(op):
|
def _BatchSquareMatrixSolveShape(op):
|
||||||
|
"""Shape function for batch square matrix solver ops."""
|
||||||
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
|
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
|
||||||
# The matrices must be square.
|
# The matrices must be square.
|
||||||
@ -116,6 +177,7 @@ def _BatchSquareMatrixSolveShape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("MatrixSolveLs")
|
@ops.RegisterShape("MatrixSolveLs")
|
||||||
def _MatrixSolveLsShape(op):
|
def _MatrixSolveLsShape(op):
|
||||||
|
"""Shape function for least-squares matrix solver op."""
|
||||||
lhs_shape = op.inputs[0].get_shape().with_rank(2)
|
lhs_shape = op.inputs[0].get_shape().with_rank(2)
|
||||||
rhs_shape = op.inputs[1].get_shape().with_rank(2)
|
rhs_shape = op.inputs[1].get_shape().with_rank(2)
|
||||||
# The matrix and right-hand side must have the same number of rows.
|
# The matrix and right-hand side must have the same number of rows.
|
||||||
@ -125,6 +187,7 @@ def _MatrixSolveLsShape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("BatchMatrixSolveLs")
|
@ops.RegisterShape("BatchMatrixSolveLs")
|
||||||
def _BatchMatrixSolveLsShape(op):
|
def _BatchMatrixSolveLsShape(op):
|
||||||
|
"""Shape function for batch least-squares matrix solver op."""
|
||||||
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
||||||
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
|
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
|
||||||
# The matrices and right-hand sides in the batch must have the same number of
|
# The matrices and right-hand sides in the batch must have the same number of
|
||||||
@ -331,4 +394,92 @@ def batch_matrix_solve_ls(matrix,
|
|||||||
fast=fast,
|
fast=fast,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def svd(matrix, compute_uv=False, full_matrices=False, name=None):
|
||||||
|
"""Computes the singular value decomposition of a matrix.
|
||||||
|
|
||||||
|
Computes the SVD of if `matrix` such that `matrix = u * diag(s) *
|
||||||
|
transpose(v)`
|
||||||
|
|
||||||
|
```prettyprint
|
||||||
|
# a is a matrix.
|
||||||
|
# s is a vector of singular values.
|
||||||
|
# u is the matrix of left singular vectors.
|
||||||
|
# v is a matrix of right singular vectors.
|
||||||
|
s = svd(a, compute_uv=False)
|
||||||
|
s, u, v = svd(a, compute_uv=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matrix: `Tensor` of shape `[M, N]`. Let `P` be the minimum of `M` and `N`.
|
||||||
|
compute_uv: If `True` then left and right singular vectors will be
|
||||||
|
computed and returned in `u` and `v`, respectively. Otherwise, only the
|
||||||
|
singular values will be computed.
|
||||||
|
full_matrices: If true, compute full-sized `u` and `v`. If false
|
||||||
|
(the default), compute only the leading `P` singular vectors.
|
||||||
|
Ignored if `compute_uv` is `False`.
|
||||||
|
name: string, optional name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
s: Singular values. Shape is `[P]`.
|
||||||
|
u: Right singular vectors. If `full_matrices` is `False` (default) then
|
||||||
|
shape is `[M, P]`; if `full_matrices` is `True` then shape is
|
||||||
|
`[M, M]`. Not returned if `compute_uv` is `False`.
|
||||||
|
v: Left singular vectors. If `full_matrices` is `False` (default) then
|
||||||
|
shape is `[N, P]`. If `full_matrices` is `True` then shape is
|
||||||
|
`[N, N]`. Not returned if `compute_uv` is `False`.
|
||||||
|
"""
|
||||||
|
s, u, v = gen_linalg_ops.svd(matrix,
|
||||||
|
compute_uv=compute_uv,
|
||||||
|
full_matrices=full_matrices)
|
||||||
|
if compute_uv:
|
||||||
|
return s, u, v
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def batch_svd(tensor, compute_uv=False, full_matrices=False, name=None):
|
||||||
|
"""Computes the singular value decompositions of a batch of matrices.
|
||||||
|
|
||||||
|
Computes the SVD of each inner matrix in `tensor` such that
|
||||||
|
`tensor[..., :, :] = u[..., :, :] * diag(s[..., :, :]) * transpose(v[..., :,
|
||||||
|
:])`
|
||||||
|
|
||||||
|
```prettyprint
|
||||||
|
# a is a tensor.
|
||||||
|
# s is a tensor of singular values.
|
||||||
|
# u is a tensor of left singular vectors.
|
||||||
|
# v is a tensor of right singular vectors.
|
||||||
|
s = batch_svd(a, compute_uv=False)
|
||||||
|
s, u, v = batch_svd(a, compute_uv=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matrix: `Tensor` of shape `[..., M, N]`. Let `P` be the minimum of `M` and
|
||||||
|
`N`.
|
||||||
|
compute_uv: If `True` then left and right singular vectors will be
|
||||||
|
computed and returned in `u` and `v`, respectively. Otherwise, only the
|
||||||
|
singular values will be computed.
|
||||||
|
full_matrices: If true, compute full-sized `u` and `v`. If false
|
||||||
|
(the default), compute only the leading `P` singular vectors.
|
||||||
|
Ignored if `compute_uv` is `False`.
|
||||||
|
name: string, optional name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
s: Singular values. Shape is `[..., P]`.
|
||||||
|
u: Right singular vectors. If `full_matrices` is `False` (default) then
|
||||||
|
shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
|
||||||
|
`[..., M, M]`. Not returned if `compute_uv` is `False`.
|
||||||
|
v: Left singular vectors. If `full_matrices` is `False` (default) then
|
||||||
|
shape is `[..., N, P]`. If `full_matrices` is `True` then shape is
|
||||||
|
`[..., N, N]`. Not returned if `compute_uv` is `False`.
|
||||||
|
"""
|
||||||
|
s, u, v = gen_linalg_ops.batch_svd(
|
||||||
|
tensor, compute_uv=compute_uv, full_matrices=full_matrices)
|
||||||
|
if compute_uv:
|
||||||
|
return s, u, v
|
||||||
|
else:
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
# pylint: enable=invalid-name
|
# pylint: enable=invalid-name
|
||||||
|
@ -98,9 +98,6 @@ functions on matrices to your graph.
|
|||||||
@@cholesky_solve
|
@@cholesky_solve
|
||||||
@@batch_cholesky_solve
|
@@batch_cholesky_solve
|
||||||
|
|
||||||
@@self_adjoint_eig
|
|
||||||
@@batch_self_adjoint_eig
|
|
||||||
|
|
||||||
@@matrix_solve
|
@@matrix_solve
|
||||||
@@batch_matrix_solve
|
@@batch_matrix_solve
|
||||||
|
|
||||||
@ -110,6 +107,12 @@ functions on matrices to your graph.
|
|||||||
@@matrix_solve_ls
|
@@matrix_solve_ls
|
||||||
@@batch_matrix_solve_ls
|
@@batch_matrix_solve_ls
|
||||||
|
|
||||||
|
@@self_adjoint_eig
|
||||||
|
@@batch_self_adjoint_eig
|
||||||
|
|
||||||
|
@@svd
|
||||||
|
@@batch_svd
|
||||||
|
|
||||||
## Complex Number Functions
|
## Complex Number Functions
|
||||||
|
|
||||||
TensorFlow provides several operations that you can use to add complex number
|
TensorFlow provides several operations that you can use to add complex number
|
||||||
@ -1637,20 +1640,22 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
|
|||||||
"""
|
"""
|
||||||
with ops.op_scope([x], name, "Cumsum") as name:
|
with ops.op_scope([x], name, "Cumsum") as name:
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
return gen_math_ops.cumsum(x, axis, exclusive=exclusive,
|
return gen_math_ops.cumsum(
|
||||||
reverse=reverse, name=name)
|
x, axis, exclusive=exclusive, reverse=reverse, name=name)
|
||||||
|
|
||||||
|
|
||||||
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
||||||
"""Compute the cumulative product of the tensor `x` along `axis`.
|
"""Compute the cumulative product of the tensor `x` along `axis`.
|
||||||
|
|
||||||
By default, this op performs an inclusive cumprod, which means that the first
|
By default, this op performs an inclusive cumprod, which means that the
|
||||||
|
first
|
||||||
element of the input is identical to the first element of the output:
|
element of the input is identical to the first element of the output:
|
||||||
```prettyprint
|
```prettyprint
|
||||||
tf.cumprod([a, b, c]) ==> [a, a * b, a * b * c]
|
tf.cumprod([a, b, c]) ==> [a, a * b, a * b * c]
|
||||||
```
|
```
|
||||||
|
|
||||||
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is performed
|
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
|
||||||
|
performed
|
||||||
instead:
|
instead:
|
||||||
```prettyprint
|
```prettyprint
|
||||||
tf.cumprod([a, b, c], exclusive=True) ==> [0, a, a * b]
|
tf.cumprod([a, b, c], exclusive=True) ==> [0, a, a * b]
|
||||||
@ -1681,8 +1686,8 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
|
|||||||
"""
|
"""
|
||||||
with ops.op_scope([x], name, "Cumprod") as name:
|
with ops.op_scope([x], name, "Cumprod") as name:
|
||||||
x = ops.convert_to_tensor(x, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
return gen_math_ops.cumprod(x, axis, exclusive=exclusive,
|
return gen_math_ops.cumprod(
|
||||||
reverse=reverse, name=name)
|
x, axis, exclusive=exclusive, reverse=reverse, name=name)
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
|
ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
|
||||||
|
1
third_party/eigen3/BUILD
vendored
1
third_party/eigen3/BUILD
vendored
@ -8,6 +8,7 @@ cc_library(
|
|||||||
"Eigen/Cholesky",
|
"Eigen/Cholesky",
|
||||||
"Eigen/Eigenvalues",
|
"Eigen/Eigenvalues",
|
||||||
"Eigen/QR",
|
"Eigen/QR",
|
||||||
|
"Eigen/SVD",
|
||||||
"unsupported/Eigen/SpecialFunctions",
|
"unsupported/Eigen/SpecialFunctions",
|
||||||
"unsupported/Eigen/CXX11/Tensor",
|
"unsupported/Eigen/CXX11/Tensor",
|
||||||
"unsupported/Eigen/CXX11/FixedPoint",
|
"unsupported/Eigen/CXX11/FixedPoint",
|
||||||
|
38
third_party/eigen3/Eigen/SVD
vendored
38
third_party/eigen3/Eigen/SVD
vendored
@ -1,37 +1 @@
|
|||||||
#ifndef EIGEN_SVD_MODULE_H
|
#include "Eigen/SVD"
|
||||||
#define EIGEN_SVD_MODULE_H
|
|
||||||
|
|
||||||
#include "QR"
|
|
||||||
#include "Householder"
|
|
||||||
#include "Jacobi"
|
|
||||||
|
|
||||||
#include "src/Core/util/DisableStupidWarnings.h"
|
|
||||||
|
|
||||||
/** \defgroup SVD_Module SVD module
|
|
||||||
*
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* This module provides SVD decomposition for matrices (both real and complex).
|
|
||||||
* This decomposition is accessible via the following MatrixBase method:
|
|
||||||
* - MatrixBase::jacobiSvd()
|
|
||||||
*
|
|
||||||
* \code
|
|
||||||
* #include <Eigen/SVD>
|
|
||||||
* \endcode
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "src/misc/Solve.h"
|
|
||||||
#include "src/SVD/JacobiSVD.h"
|
|
||||||
#if defined(EIGEN_USE_LAPACKE) && !defined(EIGEN_USE_LAPACKE_STRICT)
|
|
||||||
#include "src/SVD/JacobiSVD_MKL.h"
|
|
||||||
#endif
|
|
||||||
#include "src/SVD/UpperBidiagonalization.h"
|
|
||||||
|
|
||||||
#ifdef EIGEN2_SUPPORT
|
|
||||||
#include "src/Eigen2Support/SVD.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "src/Core/util/ReenableStupidWarnings.h"
|
|
||||||
|
|
||||||
#endif // EIGEN_SVD_MODULE_H
|
|
||||||
/* vim: set filetype=cpp et sw=2 ts=2 ai: */
|
|
||||||
|
Loading…
Reference in New Issue
Block a user