diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a078488dd18..f0cb90053e4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1023,6 +1023,7 @@ tf_kernel_libraries(
         "matrix_solve_ls_op",
         "matrix_solve_op",
         "matrix_triangular_solve_op",
+        "svd_op",
     ],
     deps = [
         ":linalg_ops_common",
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index 9fbb6db9cf0..575c7e2e7c2 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -90,19 +90,35 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
   TensorInputs inputs;
   TensorShapes input_matrix_shapes;
   TensorShape batch_shape;
+  AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
+
+  TensorShapes output_matrix_shapes;
+  TensorOutputs outputs;
+  PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
+                 &output_matrix_shapes);
+
+  // Process the individual matrix problems in parallel using a threadpool.
+  auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
+                &output_matrix_shapes, context](int64 begin, int64 end) {
+    for (int64 i = begin; i < end; ++i) {
+      ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
+                         output_matrix_shapes);
+    }
+  };
+  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+  Shard(worker_threads.num_threads, worker_threads.workers,
+        batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
+}
+
+template <typename Scalar, bool SupportsBatchOperation>
+void LinearAlgebraOp<Scalar, SupportsBatchOperation>::AnalyzeInputs(
+    OpKernelContext* context, TensorInputs* inputs,
+    TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
   int input_rank = -1;
-  int num_batch_matrices = 1;
   for (int i = 0; i < NumMatrixInputs(context); ++i) {
     const Tensor& in = context->input(i);
     if (i == 0) {
-      // If the tensor rank is greater than 2, we consider the inner-most
-      // dimensions as matrices, and loop over all the other outer ("batch")
-      // dimensions to compute the results.
       input_rank = in.dims();
-      for (int dim = 0; dim < input_rank - 2; ++dim) {
-        num_batch_matrices *= in.dim_size(dim);
-        batch_shape.AddDim(in.dim_size(dim));
-      }
       if (SupportsBatchOperation) {
         OP_REQUIRES(
             context, input_rank >= 2,
@@ -114,6 +130,13 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
             errors::InvalidArgument("Input tensor ", i,
                                     " must have rank == 2, got", input_rank));
       }
+
+      // If the tensor rank is greater than 2, we consider the inner-most
+      // dimensions as matrices, and loop over all the other outer ("batch")
+      // dimensions to compute the results.
+      for (int dim = 0; dim < input_rank - 2; ++dim) {
+        batch_shape->AddDim(in.dim_size(dim));
+      }
     } else {
       // Make sure that all inputs have the same rank and outer dimensions.
       OP_REQUIRES(context, input_rank == in.dims(),
@@ -121,7 +144,7 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
                       "All input tensors must have the same rank."));
       for (int dim = 0; dim < input_rank - 2; ++dim) {
         OP_REQUIRES(
-            context, in.dim_size(dim) == batch_shape.dim_size(dim),
+            context, in.dim_size(dim) == batch_shape->dim_size(dim),
             errors::InvalidArgument(
                 "All input tensors must have the same outer dimensions."));
       }
@@ -131,64 +154,59 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::Compute(
     const int col_dimension = input_rank - 1;
     const int64 num_rows = in.dim_size(row_dimension);
     const int64 num_cols = in.dim_size(col_dimension);
-    input_matrix_shapes.push_back(TensorShape({num_rows, num_cols}));
-    inputs.push_back(in);
+    // TODO(rmlarsen): Use emplace_back when it is added to InlinedVector. Same
+    // in several places below.
+    input_matrix_shapes->push_back(TensorShape({num_rows, num_cols}));
+    inputs->push_back(in);
   }
 
   // Have the derived class validate that the inputs are as expected.
-  ValidateInputMatrixShapes(context, input_matrix_shapes);
-
-  // Get shape for each of the matrix outputs.
-  const TensorShapes output_matrix_shapes =
-      GetOutputMatrixShapes(input_matrix_shapes);
-  // Make sure the number of outputs is what the derived class expects.
-  OP_REQUIRES(
-      context, output_matrix_shapes.size() == context->num_outputs(),
-      errors::Internal(
-          "Derived class expected (%d) output matrices for op, got (%d).",
-          output_matrix_shapes.size(), context->num_outputs()));
-
-  // Allocate outputs.
-  TensorShapes output_shapes;
-  TensorOutputs outputs;
-  for (int i = 0; i < context->num_outputs(); ++i) {
-    OP_REQUIRES(context, output_matrix_shapes[i].dims() <= 2,
-                errors::InvalidArgument(
-                    "Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
-                    i, output_matrix_shapes[i].dims()));
-
-    // The final output has the shape of the outer batch dimensions concatenated
-    // with the output_matrix_shape (if the output is not scalar).
-    TensorShape output_shape;
-    if (input_rank == 2) {
-      output_shape = output_matrix_shapes[i];
-    } else {
-      output_shape = batch_shape;
-      // Add the inner dimensions that depend on the operation implemented by
-      // the derived class.
-      for (int dim = 0; dim < output_matrix_shapes[i].dims(); ++dim) {
-        output_shape.AddDim(output_matrix_shapes[i].dim_size(dim));
-      }
-    }
-    output_shapes.push_back(output_shape);
-    Tensor* out = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(i, output_shape, &out));
-    outputs.push_back(out);
-  }
-
-  auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
-                &output_matrix_shapes, context](int64 begin, int64 end) {
-    for (int64 i = begin; i < end; ++i) {
-      ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
-                         output_matrix_shapes);
-    }
-  };
-  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
-  Shard(worker_threads.num_threads, worker_threads.workers, num_batch_matrices,
-        GetCostPerUnit(input_matrix_shapes), shard);
+  ValidateInputMatrixShapes(context, *input_matrix_shapes);
 }
 
-template <typename Scalar, bool SupportsBatchOperation>
-void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
+template <typename Scalar, bool SupportsBatchOperation>
+void LinearAlgebraOp<Scalar, SupportsBatchOperation>::PrepareOutputs(
+    OpKernelContext* context, const TensorShapes& input_matrix_shapes,
+    const TensorShape& batch_shape, TensorOutputs* outputs,
+    TensorShapes* output_matrix_shapes) {
+  // Get shape for each of the matrix outputs produced by the derived class.
+  *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
+  const int num_outputs = output_matrix_shapes->size();
+
+  // Make sure the number of op outputs is what the derived class expects.
+  OP_REQUIRES(
+      context, num_outputs <= context->num_outputs(),
+      errors::Internal(
+          "Derived class expected more outputs (%d) than the op has (%d).",
+          num_outputs, context->num_outputs()));
+
+  // Allocate outputs.
+  for (int i = 0; i < context->num_outputs(); ++i) {
+    TensorShape output_tensor_shape({0});
+    if (i < num_outputs) {
+      // This output is used, set up output shape and allocate it.
+      const TensorShape& output_matrix_shape = output_matrix_shapes->at(i);
+      OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
+                  errors::InvalidArgument(
+                      "Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
+                      i, output_matrix_shape.dims()));
+
+      // The final output has the shape of the outer batch dimensions
+      // concatenated with the output_matrix_shape (if the output is not
+      // scalar).
+      output_tensor_shape = batch_shape;
+      for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) {
+        output_tensor_shape.AddDim(output_matrix_shape.dim_size(dim));
+      }
+    }
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(i, output_tensor_shape, &out));
+    outputs->push_back(out);
+  }
+}
+
+template <typename Scalar, bool SupportsBatchOperation>
+void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
     OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
     const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
     const TensorShapes& output_matrix_shapes) {
@@ -204,7 +222,7 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperation>::ComputeTensorSlice(
   }
 
   MatrixMaps matrix_outputs;
-  for (int i = 0; i < outputs.size(); ++i) {
+  for (int i = 0; i < output_matrix_shapes.size(); ++i) {
     // The output matrix shape may not be a matrix.
     int num_output_rows = output_matrix_shapes[i].dims() >= 1
                               ? output_matrix_shapes[i].dim_size(0)
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index dda83ad2d12..3be9853c6cf 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -43,7 +43,7 @@ template <typename Scalar, bool SupportsBatchOperation>
 class LinearAlgebraOp : public OpKernel {
  public:
   explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
-  ~LinearAlgebraOp() override {}
+
   void Compute(OpKernelContext* context) override;
 
  protected:
@@ -80,19 +80,26 @@ class LinearAlgebraOp : public OpKernel {
       const TensorShapes& input_matrix_shapes);
 
   // Returns the output shapes of each individual matrix operation. Output
-  // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
-  // For many ops the output dimensions are the same as the input dimensions,
+  // matrix shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
+  //
+  // The derived class may return a number of shapes (N) less than
+  // context->num_outputs() (M) to indicate that only a leading subset of
+  // the outputs will be populated. In this case, a dummy scalar tensor with
+  // value zero will be returned for the last M-N outputs.
+  //
+  // For many ops, the output dimensions are the same as the input dimensions,
   // so we provide that as a default implementation for convenience.
   virtual TensorShapes GetOutputMatrixShapes(
       const TensorShapes& input_matrix_shapes) const {
     return input_matrix_shapes;
   }
 
-  // Returns the cost per matrix operation. Cost per unit is assumed to be
-  // roughly 1ns, based on comments in core/util/work_sharder.cc.
-  // Many linear algebra ops take roughly max(m,n) * min(m,n)^2, where the first
-  // input matrix is m-by-n. We provide that as a default implementation for
-  // convenience.
+  // Returns the cost per matrix operation. This is used to determine the
+  // number of threads to use for parallelizing calls to ComputeMatrix in
+  // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
+  // in core/util/work_sharder.cc. Many linear algebra ops take roughly
+  // max(m,n) * min(m,n)^2, where the first input matrix is m-by-n. We provide
+  // that as a default implementation for convenience.
   virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
     double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
     double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
@@ -111,7 +118,9 @@ class LinearAlgebraOp : public OpKernel {
   // Performs a single matrix computation given input matrices, and
   // stores the result in outputs. For batch operations, this will be called
   // repeatedly for a single call to Compute() when multiple matrices exist in
-  // input Tensors with rank > 2.
+  // input Tensors with rank > 2. In this case, the calls to ComputeMatrix are
+  // parallelized. The number of threads used is determined by a cost model
+  // from the value returned by GetCostPerUnit().
   virtual void ComputeMatrix(OpKernelContext* context,
                              const ConstMatrixMaps& inputs,
                              MatrixMaps* outputs) = 0;
@@ -142,6 +151,15 @@ class LinearAlgebraOp : public OpKernel {
                           const TensorShapes& input_matrix_shapes,
                           const TensorOutputs& outputs,
                           const TensorShapes& output_matrix_shapes);
+
+  void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
+                     TensorShapes* input_matrix_shapes,
+                     TensorShape* batch_shape);
+
+  void PrepareOutputs(OpKernelContext* context,
+                      const TensorShapes& input_matrix_shapes,
+                      const TensorShape& batch_shape, TensorOutputs* outputs,
+                      TensorShapes* output_matrix_shapes);
 };
 
 // Declare that LinearAlgebraOp is explicitly instantiated in
diff --git a/tensorflow/core/kernels/svd_op.cc b/tensorflow/core/kernels/svd_op.cc
new file mode 100644
index 00000000000..c3686947dda
--- /dev/null
+++ b/tensorflow/core/kernels/svd_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+#include <algorithm>
+
+#include "third_party/eigen3/Eigen/SVD"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class SvdOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+  typedef LinearAlgebraOp<Scalar, SupportsBatchOperationT> Base;
+
+  explicit SvdOp(OpKernelConstruction* context) : Base(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_));
+    OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
+  }
+
+  using TensorShapes = typename Base::TensorShapes;
+
+  void ValidateInputMatrixShapes(
+      OpKernelContext* context,
+      const TensorShapes& input_matrix_shapes) const final {
+    Base::ValidateSingleMatrix(context, input_matrix_shapes);
+  }
+
+  TensorShapes GetOutputMatrixShapes(
+      const TensorShapes& input_matrix_shapes) const final {
+    int64 m = input_matrix_shapes[0].dim_size(0);
+    int64 n = input_matrix_shapes[0].dim_size(1);
+    int64 min_size = std::min(m, n);
+    if (compute_uv_) {
+      return TensorShapes({TensorShape({min_size}),
+                           TensorShape({m, full_matrices_ ? m : min_size}),
+                           TensorShape({n, full_matrices_ ? n : min_size})});
+    } else {
+      return TensorShapes({TensorShape({min_size})});
+    }
+  }
+
+  // TODO(rmlarsen): This should depend on compute_uv. See b/30409375.
+  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
+    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
+    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
+    double cost = 12 * std::max(m, n) * std::min(m, n) * std::min(m, n);
+    return cost >= static_cast<double>(kint64max) ? kint64max
+                                                  : static_cast<int64>(cost);
+  }
+
+  using Matrix = typename Base::Matrix;
+  using MatrixMaps = typename Base::MatrixMaps;
+  using ConstMatrixMap = typename Base::ConstMatrixMap;
+  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
+
+  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
+                     MatrixMaps* outputs) final {
+    Eigen::JacobiSVD<Matrix> svd;
+    if (compute_uv_) {
+      svd.compute(inputs[0],
+                  (full_matrices_ ? Eigen::ComputeFullU | Eigen::ComputeFullV
+                                  : Eigen::ComputeThinU | Eigen::ComputeThinV));
+      outputs->at(0) = svd.singularValues();
+      outputs->at(1) = svd.matrixU();
+      outputs->at(2) = svd.matrixV();
+    } else {
+      svd.compute(inputs[0]);
+      outputs->at(0) = svd.singularValues();
+    }
+  }
+
+ private:
+  bool compute_uv_;
+  bool full_matrices_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SvdOp);
+};
+
+REGISTER_LINALG_OP("Svd", (SvdOp<float, false>), float);
+REGISTER_LINALG_OP("Svd", (SvdOp<double, false>), double);
+REGISTER_LINALG_OP("BatchSvd", (SvdOp<float, true>), float);
+REGISTER_LINALG_OP("BatchSvd", (SvdOp<double, true>), double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index ab4b2644b24..0ea31ddca33 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -128,7 +128,7 @@ REGISTER_OP("MatrixDeterminant")
       return Status::OK();
     })
     .Doc(R"doc(
-Calculates the determinant of a square matrix.
+Computes the determinant of a square matrix.
 
 input: A tensor of shape `[M, M]`.
 output: A scalar, equal to the determinant of the input.
@@ -152,7 +152,7 @@ REGISTER_OP("BatchMatrixDeterminant")
       return Status::OK();
     })
     .Doc(R"doc(
-Calculates the determinants for a batch of square matrices.
+Computes the determinants for a batch of square matrices.
 
 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
 form square matrices. The output is a tensor containing the determinants
@@ -169,7 +169,7 @@ REGISTER_OP("MatrixInverse")
     .Attr("T: {double, float}")
     .SetShapeFn(UnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the inverse of a square invertible matrix or its adjoint (conjugate
+Computes the inverse of a square invertible matrix or its adjoint (conjugate
 transpose).
 
 The op uses LU decomposition with partial pivoting to compute the inverse.
@@ -191,7 +191,7 @@ REGISTER_OP("BatchMatrixInverse")
     .Attr("T: {double, float}")
     .SetShapeFn(BatchUnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the inverse of square invertible matrices or their adjoints
+Computes the inverse of square invertible matrices or their adjoints
 (conjugate transposes).
 
 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
@@ -214,7 +214,7 @@ REGISTER_OP("Cholesky")
     .Attr("T: {double, float}")
     .SetShapeFn(UnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the Cholesky decomposition of a square matrix.
+Computes the Cholesky decomposition of a square matrix.
 
 The input has to be symmetric and positive definite. Only the lower-triangular
 part of the input will be used for this operation. The upper-triangular part
@@ -233,7 +233,7 @@ REGISTER_OP("BatchCholesky")
     .Attr("T: {double, float}")
     .SetShapeFn(BatchUnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the Cholesky decomposition of a batch of square matrices.
+Computes the Cholesky decomposition of a batch of square matrices.
 
 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
 form square matrices, with the same constraints as the single matrix Cholesky
@@ -251,7 +251,7 @@ REGISTER_OP("CholeskyGrad")
     .Attr("T: {float, double}")
     .SetShapeFn(UnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
+Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
 
 For an explanation see "Differentiation of the Cholesky algorithm" by
 Iain Murray http://arxiv.org/abs/1602.07527.
@@ -270,7 +270,7 @@ REGISTER_OP("BatchCholeskyGrad")
     .Attr("T: {float, double}")
     .SetShapeFn(BatchUnchangedSquareShapeFn)
     .Doc(R"doc(
-Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
+Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
 
 For an explanation see "Differentiation of the Cholesky algorithm" by
 Iain Murray http://arxiv.org/abs/1602.07527.
@@ -299,7 +299,7 @@ REGISTER_OP("SelfAdjointEig")
       return Status::OK();
     })
     .Doc(R"doc(
-Calculates the Eigen Decomposition of a square Self-Adjoint matrix.
+Computes the Eigen Decomposition of a square Self-Adjoint matrix.
 
 Only the lower-triangular part of the input will be used in this case. The
 upper-triangular part will not be read.
@@ -330,7 +330,7 @@ REGISTER_OP("BatchSelfAdjointEig")
       return Status::OK();
     })
     .Doc(R"doc(
-Calculates the Eigen Decomposition of a batch of square self-adjoint matrices.
+Computes the Eigen Decomposition of a batch of square self-adjoint matrices.
 
 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
 form square matrices, with the same constraints as the single matrix
@@ -526,10 +526,10 @@ REGISTER_OP("BatchMatrixSolveLs")
 Solves multiple linear least-squares problems.
 
 `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
-form square matrices. Rhs is a tensor of shape `[..., M, K]`. The output
-is a tensor shape `[..., N, K]` where each output matrix solves each of
-the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the
-least squares sense.
+form matrices of size `[M, N]`. Rhs is a tensor of shape `[..., M, K]`.
+The output is a tensor shape `[..., N, K]` where each output matrix solves
+each of the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]
+in the least squares sense.
 
 Below we will use the following notation for each pair of matrix and
 right-hand sides in the batch:
@@ -563,4 +563,82 @@ rhs: Shape is `[..., M, K]`.
 output: Shape is `[..., N, K]`.
 )doc");
 
+REGISTER_OP("Svd")
+    .Input("input: T")
+    .Output("s: T")
+    .Output("u: T")
+    .Output("v: T")
+    .Attr("compute_uv: bool = False")
+    .Attr("full_matrices: bool = False")
+    .Attr("T: {double, float}")
+    .Doc(R"doc(
+Computes the singular value decomposition of a matrix.
+
+Computes the SVD of `input` such that `input = u * diag(s) * transpose(v)`
+
+```prettyprint
+# a is a matrix.
+# s is a vector of singular values.
+# u is the matrix of left singular vectors.
+# v is a matrix of right singular vectors.
+s, _, _ = svd(a, compute_uv=False)
+s, u, v = svd(a, compute_uv=True)
+```
+
+input: Shape is `[M, N]`. Let `P` be the minimum of `M` and `N`.
+s: Singular values. Shape is `[P]`.
+u: Left singular vectors; if `full_matrices` is `False` then shape is `[M, P]`.
+  If `full_matrices` is `True` then shape is `[M, M]`.
+  Undefined if `compute_uv` is `False`.
+v: Right singular vectors. If `full_matrices` is `False` then shape is `[N, P]`.
+  If `full_matrices` is `True` then shape is `[N, N]`.
+  Undefined if `compute_uv` is `False`.
+compute_uv: If true, left and right singular vectors will be
+  computed and returned in `u` and `v`, respectively.
+  If false, `u` and `v` are not set and should never be referenced.
+full_matrices: If true, compute full-sized `u` and `v`. If false
+  (the default), compute only the leading `P` singular vectors.
+  Ignored if `compute_uv` is `False`.
+)doc");
+
+REGISTER_OP("BatchSvd")
+    .Input("input: T")
+    .Output("s: T")
+    .Output("u: T")
+    .Output("v: T")
+    .Attr("compute_uv: bool = False")
+    .Attr("full_matrices: bool = False")
+    .Attr("T: {double, float}")
+    .Doc(R"doc(
+Computes the singular value decompositions of a batch of matrices.
+
+Computes the SVD of each inner matrix in `input` such that
+`input[..., :, :] = u[..., :, :] * diag(s[..., :]) * transpose(v[..., :, :])`
+
+```prettyprint
+# a is a tensor containing a batch of matrices.
+# s is a tensor of singular values for each matrix.
+# u is the tensor containing the left singular vectors for each matrix.
+# v is the tensor containing the right singular vectors for each matrix.
+s, _, _ = batch_svd(a, compute_uv=False)
+s, u, v = batch_svd(a, compute_uv=True)
+```
+
+input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+  form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+s: Singular values. Shape is `[..., P]`.
+u: Left singular vectors. If `full_matrices` is `False` then shape is
+  `[..., M, P]`; if `full_matrices` is `True` then shape is
+  `[..., M, M]`. Undefined if `compute_uv` is `False`.
+v: Right singular vectors. If `full_matrices` is `False` then shape is
+  `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
+  Undefined if `compute_uv` is `False`.
+compute_uv: If true, left and right singular vectors will be
+  computed and returned in `u` and `v`, respectively.
+  If false, `u` and `v` are not set and should never be referenced.
+full_matrices: If true, compute full-sized `u` and `v`. If false
+  (the default), compute only the leading `P` singular vectors.
+  Ignored if `compute_uv` is `False`.
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7e11f17211b..16c260f154b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -209,6 +209,7 @@ cuda_py_tests(
         "cwise_ops_test.py",
         "embedding_ops_test.py",
        "linalg_grad_test.py",
+        "svd_op_test.py",
     ],
     shard_count = 50,
     tags = ["notap"],  # b/30226163
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
new file mode 100644
index 00000000000..6c2d8369799
--- /dev/null
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.linalg_ops.svd and batch_svd."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class SvdOpTest(tf.test.TestCase):
+
+  def testWrongDimensions(self):
+    # The input to svd should be a 2-dimensional tensor.
+    scalar = tf.constant(1.)
+    with self.assertRaises(ValueError):
+      tf.svd(scalar)
+    vector = tf.constant([1., 2.])
+    with self.assertRaises(ValueError):
+      tf.svd(vector)
+    tensor = tf.constant([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]])
+    with self.assertRaises(ValueError):
+      tf.svd(tensor)
+
+    # The input to batch_svd should be a tensor of at least rank 2.
+    scalar = tf.constant(1.)
+    with self.assertRaises(ValueError):
+      tf.batch_svd(scalar)
+    vector = tf.constant([1., 2.])
+    with self.assertRaises(ValueError):
+      tf.batch_svd(vector)
+
+
+def _GetSvdOpTest(dtype_, shape_):
+
+  def _CompareSingularVectors(self, x, y, atol):
+    # Singular vectors are only unique up to sign (complex phase factor for
+    # complex matrices), so we normalize the signs first.
+    signs = np.sign(np.sum(np.divide(x, y), -2, keepdims=True))
+    x *= signs
+    self.assertAllClose(x, y, atol=atol)
+
+  def Test(self):
+    np.random.seed(1)
+    x = np.random.uniform(
+        low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
+    if dtype_ == np.float32:
+      atol = 1e-4
+    else:
+      atol = 1e-14
+    for compute_uv in False, True:
+      for full_matrices in False, True:
+        with self.test_session():
+          if x.ndim == 2:
+            if compute_uv:
+              tf_s, tf_u, tf_v = tf.svd(tf.constant(x),
+                                        compute_uv=compute_uv,
+                                        full_matrices=full_matrices)
+            else:
+              tf_s = tf.svd(tf.constant(x),
+                            compute_uv=compute_uv,
+                            full_matrices=full_matrices)
+          else:
+            if compute_uv:
+              tf_s, tf_u, tf_v = tf.batch_svd(
+                  tf.constant(x),
+                  compute_uv=compute_uv,
+                  full_matrices=full_matrices)
+            else:
+              tf_s = tf.batch_svd(
+                  tf.constant(x),
+                  compute_uv=compute_uv,
+                  full_matrices=full_matrices)
+          if compute_uv:
+            np_u, np_s, np_v = np.linalg.svd(x,
+                                             compute_uv=compute_uv,
+                                             full_matrices=full_matrices)
+          else:
+            np_s = np.linalg.svd(x,
+                                 compute_uv=compute_uv,
+                                 full_matrices=full_matrices)
+          self.assertAllClose(np_s, tf_s.eval(), atol=atol)
+          if compute_uv:
+            _CompareSingularVectors(self, np_u, tf_u.eval(), atol)
+            _CompareSingularVectors(self, np.swapaxes(np_v, -2, -1),
+                                    tf_v.eval(), atol)
+
+  return Test
+
+
+if __name__ == '__main__':
+  for dtype in np.float32, np.float64:
+    for m in 1, 2, 5, 10:
+      for n in 1, 2, 5, 10:
+        for batch_dims in [(), (3,)] + [(3, 2)] * (max(m, n) < 10):
+          shape = batch_dims + (m, n)
+          name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
+          setattr(SvdOpTest, 'testSvd_' + name, _GetSvdOpTest(dtype, shape))
+  tf.test.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 67fadc12cdc..908e04df7c8 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -32,6 +32,10 @@ from tensorflow.python.ops import math_ops
 
 ops.NoGradient("CholeskyGrad")
 ops.NoGradient("BatchCholeskyGrad")
+ops.NoGradient("SelfAdjointEig")
+ops.NoGradient("BatchSelfAdjointEig")
+ops.NoGradient("Svd")
+ops.NoGradient("BatchSvd")
 
 
 @ops.RegisterGradient("MatrixInverse")
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 0e76f772caf..60707800207 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops.gen_linalg_ops import *
 @ops.RegisterShape("CholeskyGrad")
 @ops.RegisterShape("MatrixInverse")
 def _UnchangedSquare(op):
+  """Shape function for matrix ops with output equal to input shape."""
   input_shape = op.inputs[0].get_shape().with_rank(2)
   # The matrix must be square.
   input_shape[0].assert_is_compatible_with(input_shape[1])
@@ -41,6 +42,7 @@ def _UnchangedSquare(op):
 @ops.RegisterShape("BatchCholeskyGrad")
 @ops.RegisterShape("BatchMatrixInverse")
 def _BatchUnchangedSquare(op):
+  """Shape function for batch matrix ops with output equal to input shape."""
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
@@ -48,6 +50,7 @@ def _BatchUnchangedSquare(op):
 
 @ops.RegisterShape("MatrixDeterminant")
 def _MatrixDeterminantShape(op):
+  """Shape function for determinant op."""
   input_shape = op.inputs[0].get_shape().with_rank(2)
   # The matrix must be square.
   input_shape[0].assert_is_compatible_with(input_shape[1])
@@ -59,6 +62,7 @@ def _MatrixDeterminantShape(op):
 
 @ops.RegisterShape("BatchMatrixDeterminant")
 def _BatchMatrixDeterminantShape(op):
+  """Shape function for batch determinant op."""
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
@@ -70,6 +74,7 @@ def _BatchMatrixDeterminantShape(op):
 
 @ops.RegisterShape("SelfAdjointEig")
 def _SelfAdjointEigShape(op):
+  """Shape function for self-adjoint eigensolver op."""
   input_shape = op.inputs[0].get_shape().with_rank(2)
   # The matrix must be square.
   input_shape[0].assert_is_compatible_with(input_shape[1])
@@ -80,6 +85,7 @@ def _SelfAdjointEigShape(op):
 
 @ops.RegisterShape("BatchSelfAdjointEig")
 def _BatchSelfAdjointEigShape(op):
+  """Shape function for batch self-adjoint eigensolver op."""
   input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   # The matrices in the batch must be square.
   input_shape[-1].assert_is_compatible_with(input_shape[-2])
@@ -89,9 +95,63 @@ def _BatchSelfAdjointEigShape(op):
   return [out_shape]
 
 
+@ops.RegisterShape("Svd")
+def _SvdShape(op):
+  """Shape function for SVD op."""
+  input_shape = op.inputs[0].get_shape().with_rank(2)
+  unknown = tensor_shape.unknown_shape()
+  compute_uv = op.get_attr("compute_uv")
+  if input_shape.ndims is None:
+    return [unknown, unknown, unknown]
+  full_matrices = op.get_attr("full_matrices")
+  m = input_shape.dims[0]
+  n = input_shape.dims[1]
+  p = min(m, n)
+  s_shape = tensor_shape.TensorShape([p])
+  if compute_uv:
+    if full_matrices:
+      u_shape = tensor_shape.TensorShape([m, m])
+      v_shape = tensor_shape.TensorShape([n, n])
+    else:
+      u_shape = tensor_shape.TensorShape([m, p])
+      v_shape = tensor_shape.TensorShape([n, p])
+  else:
+    u_shape = [0]
+    v_shape = [0]
+  return [s_shape, u_shape, v_shape]
+
+
+@ops.RegisterShape("BatchSvd")
+def _BatchSvdShape(op):
+  """Shape function for batch SVD op."""
+  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
+  unknown = tensor_shape.unknown_shape()
+  if input_shape.ndims is None:
+    return [unknown, unknown, unknown]
+  compute_uv = op.get_attr("compute_uv")
+  full_matrices = op.get_attr("full_matrices")
+  m = input_shape.dims[-2]
+  n = input_shape.dims[-1]
+  p = min(m, n)
+  batch_shape = input_shape[:-2]
+  s_shape = batch_shape.concatenate([p])
+  if compute_uv:
+    if full_matrices:
+      u_shape = batch_shape.concatenate([m, m])
+      v_shape = batch_shape.concatenate([n, n])
+    else:
+      u_shape = batch_shape.concatenate([m, p])
+      v_shape = batch_shape.concatenate([n, p])
+  else:
+    u_shape = [0]
+    v_shape = [0]
+  return [s_shape, u_shape, v_shape]
+
+
 @ops.RegisterShape("MatrixSolve")
 @ops.RegisterShape("MatrixTriangularSolve")
 def _SquareMatrixSolveShape(op):
+  """Shape function for square matrix solver ops."""
   lhs_shape = op.inputs[0].get_shape().with_rank(2)
   rhs_shape = op.inputs[1].get_shape().with_rank(2)
   # The matrix must be square.
@@ -104,6 +164,7 @@ def _SquareMatrixSolveShape(op):
 
 @ops.RegisterShape("BatchMatrixSolve")
 @ops.RegisterShape("BatchMatrixTriangularSolve")
 def _BatchSquareMatrixSolveShape(op):
+  """Shape function for batch square matrix solver ops."""
   lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices must be square.
@@ -116,6 +177,7 @@ def _BatchSquareMatrixSolveShape(op):
 
 @ops.RegisterShape("MatrixSolveLs")
 def _MatrixSolveLsShape(op):
+  """Shape function for least-squares matrix solver op."""
   lhs_shape = op.inputs[0].get_shape().with_rank(2)
   rhs_shape = op.inputs[1].get_shape().with_rank(2)
   # The matrix and right-hand side must have the same number of rows.
@@ -125,6 +187,7 @@ def _MatrixSolveLsShape(op):
 
 @ops.RegisterShape("BatchMatrixSolveLs")
 def _BatchMatrixSolveLsShape(op):
+  """Shape function for batch least-squares matrix solver op."""
   lhs_shape = op.inputs[0].get_shape().with_rank_at_least(2)
   rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
   # The matrices and right-hand sides in the batch must have the same number of
@@ -331,4 +394,92 @@ def batch_matrix_solve_ls(matrix,
                             fast=fast,
                             name=name)
 
+
+def svd(matrix, compute_uv=False, full_matrices=False, name=None):
+  """Computes the singular value decomposition of a matrix.
+
+  Computes the SVD of `matrix` such that
+  `matrix = u * diag(s) * transpose(v)`
+
+  ```prettyprint
+  # a is a matrix.
+  # s is a vector of singular values.
+  # u is the matrix of left singular vectors.
+  # v is a matrix of right singular vectors.
+  s = svd(a, compute_uv=False)
+  s, u, v = svd(a, compute_uv=True)
+  ```
+
+  Args:
+    matrix: `Tensor` of shape `[M, N]`. Let `P` be the minimum of `M` and `N`.
+    compute_uv: If `True` then left and right singular vectors will be
+      computed and returned in `u` and `v`, respectively. Otherwise, only the
+      singular values will be computed.
+    full_matrices: If true, compute full-sized `u` and `v`. If false
+      (the default), compute only the leading `P` singular vectors.
+      Ignored if `compute_uv` is `False`.
+    name: string, optional name of the operation.
+
+  Returns:
+    s: Singular values. Shape is `[P]`.
+    u: Left singular vectors. If `full_matrices` is `False` (default) then
+      shape is `[M, P]`; if `full_matrices` is `True` then shape is
+      `[M, M]`. Not returned if `compute_uv` is `False`.
+    v: Right singular vectors. If `full_matrices` is `False` (default) then
+      shape is `[N, P]`. If `full_matrices` is `True` then shape is
+      `[N, N]`. Not returned if `compute_uv` is `False`.
+  """
+  s, u, v = gen_linalg_ops.svd(matrix,
+                               compute_uv=compute_uv,
+                               full_matrices=full_matrices, name=name)
+  if compute_uv:
+    return s, u, v
+  else:
+    return s
+
+
+def batch_svd(tensor, compute_uv=False, full_matrices=False, name=None):
+  """Computes the singular value decompositions of a batch of matrices.
+
+  Computes the SVD of each inner matrix in `tensor` such that
+  `tensor[..., :, :] = u[..., :, :] * diag(s[..., :]) * transpose(v[..., :,
+  :])`
+
+  ```prettyprint
+  # a is a tensor.
+  # s is a tensor of singular values.
+  # u is a tensor of left singular vectors.
+  # v is a tensor of right singular vectors.
+  s = batch_svd(a, compute_uv=False)
+  s, u, v = batch_svd(a, compute_uv=True)
+  ```
+
+  Args:
+    tensor: `Tensor` of shape `[..., M, N]`. Let `P` be the minimum of `M` and
+      `N`.
+    compute_uv: If `True` then left and right singular vectors will be
+      computed and returned in `u` and `v`, respectively. Otherwise, only the
+      singular values will be computed.
+    full_matrices: If true, compute full-sized `u` and `v`. If false
+      (the default), compute only the leading `P` singular vectors.
+      Ignored if `compute_uv` is `False`.
+    name: string, optional name of the operation.
+
+  Returns:
+    s: Singular values. Shape is `[..., P]`.
+    u: Left singular vectors. If `full_matrices` is `False` (default) then
+      shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
+      `[..., M, M]`. Not returned if `compute_uv` is `False`.
+    v: Right singular vectors. If `full_matrices` is `False` (default) then
+      shape is `[..., N, P]`. If `full_matrices` is `True` then shape is
+      `[..., N, N]`. Not returned if `compute_uv` is `False`.
+  """
+  s, u, v = gen_linalg_ops.batch_svd(
+      tensor, compute_uv=compute_uv, full_matrices=full_matrices, name=name)
+  if compute_uv:
+    return s, u, v
+  else:
+    return s
+
+
 # pylint: enable=invalid-name
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index cd7e92401d2..981a951e662 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -98,9 +98,6 @@ functions on matrices to your graph.
 @@cholesky_solve
 @@batch_cholesky_solve
 
-@@self_adjoint_eig
-@@batch_self_adjoint_eig
-
 @@matrix_solve
 @@batch_matrix_solve
 
@@ -110,6 +107,12 @@ functions on matrices to your graph.
 @@matrix_solve_ls
 @@batch_matrix_solve_ls
 
+@@self_adjoint_eig
+@@batch_self_adjoint_eig
+
+@@svd
+@@batch_svd
+
 ## Complex Number Functions
 
 TensorFlow provides several operations that you can use to add complex number
@@ -1598,91 +1601,93 @@ def tanh(x, name=None):
 
 
 def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
-  """Compute the cumulative sum of the tensor `x` along `axis`.
+  """Compute the cumulative sum of the tensor `x` along `axis`.
 
-  By default, this op performs an inclusive cumsum, which means that the first
-  element of the input is identical to the first element of the output:
-  ```prettyprint
-  tf.cumsum([a, b, c]) ==> [a, a + b, a + b + c]
-  ```
+  By default, this op performs an inclusive cumsum, which means that the first
+  element of the input is identical to the first element of the output:
+  ```prettyprint
+  tf.cumsum([a, b, c]) ==> [a, a + b, a + b + c]
+  ```
 
-  By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
-  instead:
-  ```prettyprint
-  tf.cumsum([a, b, c], exclusive=True) ==> [0, a, a + b]
-  ```
+  By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
+  instead:
+  ```prettyprint
+  tf.cumsum([a, b, c], exclusive=True) ==> [0, a, a + b]
+  ```
 
-  By setting the `reverse` kwarg to `True`, the cumsum is performed in the
-  opposite direction:
-  ```prettyprint
-  tf.cumsum([a, b, c], reverse=True) ==> [a + b + c, b + c, c]
-  ```
-  This is more efficient than using separate `tf.reverse` ops.
+  By setting the `reverse` kwarg to `True`, the cumsum is performed in the
+  opposite direction:
+  ```prettyprint
+  tf.cumsum([a, b, c], reverse=True) ==> [a + b + c, b + c, c]
+  ```
+  This is more efficient than using separate `tf.reverse` ops.
 
-  The `reverse` and `exclusive` kwargs can also be combined:
-  ```prettyprint
-  tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0]
-  ```
+  The `reverse` and `exclusive` kwargs can also be combined:
+  ```prettyprint
+  tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0]
+  ```
 
-  Args:
-    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+  Args:
+    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
-    axis: A `Tensor` of type `int32` (default: 0).
-    reverse: A `bool` (default: False).
-    name: A name for the operation (optional).
+    axis: A `Tensor` of type `int32` (default: 0).
+    reverse: A `bool` (default: False).
+    name: A name for the operation (optional).
 
-  Returns:
-    A `Tensor`. Has the same type as `x`.
-  """
-  with ops.op_scope([x], name, "Cumsum") as name:
-    x = ops.convert_to_tensor(x, name="x")
-    return gen_math_ops.cumsum(x, axis, exclusive=exclusive,
-                               reverse=reverse, name=name)
+  Returns:
+    A `Tensor`. Has the same type as `x`.
+  """
+  with ops.op_scope([x], name, "Cumsum") as name:
+    x = ops.convert_to_tensor(x, name="x")
+    return gen_math_ops.cumsum(
+        x, axis, exclusive=exclusive, reverse=reverse, name=name)
 
 
 def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
-  """Compute the cumulative product of the tensor `x` along `axis`.
+  """Compute the cumulative product of the tensor `x` along `axis`.
 
-  By default, this op performs an inclusive cumprod, which means that the first
-  element of the input is identical to the first element of the output:
-  ```prettyprint
-  tf.cumprod([a, b, c]) ==> [a, a * b, a * b * c]
-  ```
+  By default, this op performs an inclusive cumprod, which means that the
+  first
+  element of the input is identical to the first element of the output:
+  ```prettyprint
+  tf.cumprod([a, b, c]) ==> [a, a * b, a * b * c]
+  ```
 
-  By setting the `exclusive` kwarg to `True`, an exclusive cumprod is performed
-  instead:
-  ```prettyprint
-  tf.cumprod([a, b, c], exclusive=True) ==> [0, a, a * b]
-  ```
+  By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
+  performed
+  instead:
+  ```prettyprint
+  tf.cumprod([a, b, c], exclusive=True) ==> [0, a, a * b]
+  ```
 
-  By setting the `reverse` kwarg to `True`, the cumprod is performed in the
-  opposite direction:
-  ```prettyprint
-  tf.cumprod([a, b, c], reverse=True) ==> [a * b * c, b * c, c]
-  ```
-  This is more efficient than using separate `tf.reverse` ops.
+  By setting the `reverse` kwarg to `True`, the cumprod is performed in the
+  opposite direction:
+  ```prettyprint
+  tf.cumprod([a, b, c], reverse=True) ==> [a * b * c, b * c, c]
+  ```
+  This is more efficient than using separate `tf.reverse` ops.
 
-  The `reverse` and `exclusive` kwargs can also be combined:
-  ```prettyprint
-  tf.cumprod([a, b, c], exclusive=True, reverse=True) ==> [b * c, c, 0]
-  ```
+  The `reverse` and `exclusive` kwargs can also be combined:
+  ```prettyprint
+  tf.cumprod([a, b, c], exclusive=True, reverse=True) ==> [b * c, c, 0]
+  ```
 
-  Args:
-    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+  Args:
+    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
-    axis: A `Tensor` of type `int32` (default: 0).
-    reverse: A `bool` (default: False).
-    name: A name for the operation (optional).
+    axis: A `Tensor` of type `int32` (default: 0).
+    reverse: A `bool` (default: False).
+    name: A name for the operation (optional).
 
-  Returns:
-    A `Tensor`. Has the same type as `x`.
-  """
-  with ops.op_scope([x], name, "Cumprod") as name:
-    x = ops.convert_to_tensor(x, name="x")
-    return gen_math_ops.cumprod(x, axis, exclusive=exclusive,
-                                reverse=reverse, name=name)
+  Returns:
+    A `Tensor`. Has the same type as `x`.
+  """
+  with ops.op_scope([x], name, "Cumprod") as name:
+    x = ops.convert_to_tensor(x, name="x")
+    return gen_math_ops.cumprod(
+        x, axis, exclusive=exclusive, reverse=reverse, name=name)
 
 
 ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index 9062ed2ec0d..15534fa9612 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -8,6 +8,7 @@ cc_library(
         "Eigen/Cholesky",
         "Eigen/Eigenvalues",
         "Eigen/QR",
+        "Eigen/SVD",
         "unsupported/Eigen/SpecialFunctions",
         "unsupported/Eigen/CXX11/Tensor",
         "unsupported/Eigen/CXX11/FixedPoint",
diff --git a/third_party/eigen3/Eigen/SVD b/third_party/eigen3/Eigen/SVD
index fd310017ad1..eecf47c1031 100644
--- a/third_party/eigen3/Eigen/SVD
+++ b/third_party/eigen3/Eigen/SVD
@@ -1,37 +1 @@
-#ifndef EIGEN_SVD_MODULE_H
-#define EIGEN_SVD_MODULE_H
-
-#include "QR"
-#include "Householder"
-#include "Jacobi"
-
-#include "src/Core/util/DisableStupidWarnings.h"
-
-/** \defgroup SVD_Module SVD module
-  *
-  *
-  *
-  * This module provides SVD decomposition for matrices (both real and complex).
-  * This decomposition is accessible via the following MatrixBase method:
-  *  - MatrixBase::jacobiSvd()
-  *
-  * \code
-  * #include <Eigen/SVD>
-  * \endcode
-  */
-
-#include "src/misc/Solve.h"
-#include "src/SVD/JacobiSVD.h"
-#if defined(EIGEN_USE_LAPACKE) && !defined(EIGEN_USE_LAPACKE_STRICT)
-#include "src/SVD/JacobiSVD_MKL.h"
-#endif
-#include "src/SVD/UpperBidiagonalization.h"
-
-#ifdef EIGEN2_SUPPORT
-#include "src/Eigen2Support/SVD.h"
-#endif
-
-#include "src/Core/util/ReenableStupidWarnings.h"
-
-#endif // EIGEN_SVD_MODULE_H
-/* vim: set filetype=cpp et sw=2 ts=2 ai: */
+#include "Eigen/SVD"
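For reference, here is a small usage sketch of the Python API this patch adds. It is not part of the patch; it assumes the new functions are exported as `tf.svd` and `tf.batch_svd` (that is how `svd_op_test.py` calls them), and the shapes follow the docstrings above.

```python
import numpy as np
import tensorflow as tf

# A single 4x3 matrix and a batch of three 4x3 matrices.
a = np.random.rand(4, 3).astype(np.float32)
batch = np.random.rand(3, 4, 3).astype(np.float32)

# Singular values only (compute_uv defaults to False).
s_only = tf.svd(tf.constant(a))

# Thin SVD: s is [3], u is [4, 3], v is [3, 3].
s, u, v = tf.svd(tf.constant(a), compute_uv=True)

# Full SVD: u is [4, 4], v is [3, 3].
s_f, u_f, v_f = tf.svd(tf.constant(a), compute_uv=True, full_matrices=True)

# Batched variant: every output gains the leading batch dimension [3].
s_b, u_b, v_b = tf.batch_svd(tf.constant(batch), compute_uv=True)

with tf.Session() as sess:
  s_val, u_val, v_val = sess.run([s, u, v])
  # Reconstruct a from the thin factors: a ~= u * diag(s) * v^T.
  print(np.allclose(a, np.dot(u_val * s_val, v_val.T), atol=1e-4))
```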