Add TensorFlow op to compute the LU decomposition of a matrix.

For CPU, we use Eigen's PartialPivLU op, which has performance comparable to Cholesky decomposition. Unlike Cholesky this is also robust for non-SPD matrices.

For GPU, we use getrf and getrfbatched from cuSolver.

PiperOrigin-RevId: 224244390
This commit is contained in:
Anudhyan Boral 2018-12-05 16:45:41 -08:00 committed by TensorFlower Gardener
parent f102c3c051
commit 5caa987005
12 changed files with 900 additions and 0 deletions

View File

@ -0,0 +1,51 @@
op {
graph_op_name: "Lu"
in_arg {
name: "input"
description: <<END
A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of
size `[M, M]`.
END
}
out_arg {
name: "lu"
description: <<END
A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the
lower triangular factor `L` with unit diagonal, and whose upper triangular part
denotes the upper triangular factor `U`.
END
}
out_arg {
name: "p"
description: <<END
Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is
`[..., M]`.
@compatibility(scipy)
Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are
packed into a single tensor, the permutation is applied to `input` instead of
the right hand side and the permutation `P` is returned as a list of indices
instead of a permutation matrix.
@end_compatibility
END
}
summary: "Computes the LU decomposition of one or more square matrices."
description: <<END
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices.
The input has to be invertible.
The output consists of two tensors LU and P containing the LU decomposition
of all input submatrices `[..., :, :]`. LU encodes the lower triangular and
upper triangular factors.
For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of
shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower
triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose
entries correspond to the upper triangular part, including the diagonal, of LU.
P represents a permutation matrix encoded as a list of indices each between `0`
and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to
P, then the L, U and P satisfies P_mat * input = L * U.
END
}

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "Lu"
endpoint {
name: "linalg.lu"
}
}

View File

@ -2755,6 +2755,7 @@ cc_library(
":cholesky_grad",
":cholesky_op",
":determinant_op",
":lu_op",
":matrix_exponential_op",
":matrix_inverse_op",
":matrix_logarithm_op",
@ -2900,6 +2901,19 @@ tf_kernel_library(
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "lu_op",
prefix = "lu_op",
deps = if_cuda([
":cuda_solvers",
":transpose_functor",
]) + [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
cc_library(
name = "linalg_ops_common",
srcs = ["linalg_ops_common.cc"],

View File

@ -0,0 +1,193 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar, typename Tidx>
class LuOp : public OpKernel {
public:
explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {}
protected:
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
using Indices =
Eigen::Matrix<Tidx, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using IndicesMap = Eigen::Map<Indices>;
using ConstIndicesMap = Eigen::Map<const Indices>;
public:
// Returns the cost per matrix operation. This is used to determine the
// number of threads to use for parallelizing factorization in batch mode.
// Cost per unit is assumed to be roughly 1ns, based on comments
// in core/util/work_sharder.cc.
// LU decomposition for a square matrix takes roughly (2/3) * (num_rows)^3.
// TODO(anudhyan): Refine this estimate after taking constant factors into
// account.
int64 GetCostPerUnit(const TensorShape& input_matrix_shape) const {
double num_rows = static_cast<double>(input_matrix_shape.dim_size(0));
double cost = (2 / 3.0) * MathUtil::IPow(num_rows, 3);
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
void Compute(OpKernelContext* context) override {
OP_REQUIRES(context, context->num_inputs() == 1,
errors::InvalidArgument("Expecting exactly one input, got ",
context->num_inputs()));
const Tensor& input = context->input(0);
int input_rank = input.dims();
OP_REQUIRES(context, input_rank >= 2,
errors::InvalidArgument(
"Input tensor must have rank >= 2, got ", input_rank));
// If the tensor rank is greater than 2, we consider the inner-most
// dimensions as matrices, and loop over all the other outer ("batch")
// dimensions to compute the results.
TensorShape input_matrix_shape;
TensorShape batch_shape;
for (int dim = 0; dim < input_rank - 2; ++dim) {
batch_shape.AddDim(input.dim_size(dim));
}
const int64 num_rows = input.dim_size(input_rank - 2);
const int64 num_cols = input.dim_size(input_rank - 1);
input_matrix_shape.AppendShape({num_rows, num_cols});
OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shape),
errors::InvalidArgument("Input matrix must be square."));
// packed_triangular_factors is a matrix with the same shape as the input;
// permutation is a vector.
TensorShape permutation_shape = batch_shape;
permutation_shape.AddDim(num_rows);
TensorShapes output_matrix_shapes({input.shape(), permutation_shape});
TensorOutputs outputs;
Tensor* output_packed_triangular_factors = nullptr;
OP_REQUIRES_OK(
context, context->forward_input_or_allocate_output(
{0}, 0, input.shape(), &output_packed_triangular_factors));
outputs.emplace_back(output_packed_triangular_factors);
Tensor* output_permutation = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, permutation_shape,
&output_permutation));
outputs.emplace_back(output_permutation);
if (num_rows == 0) {
return;
}
// Process the individual matrix problems in parallel using a threadpool.
auto shard = [this, &input, &num_rows, &num_cols, &outputs,
&output_matrix_shapes, context](int64 begin, int64 end) {
for (int64 i = begin; i < end; ++i) {
ComputeTensorSlice(context, i, input, num_rows, num_cols, outputs,
output_matrix_shapes);
}
};
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers,
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shape),
shard);
}
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
const Tensor& input, int64 num_rows, int64 num_cols,
const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes) {
// TODO(kalakris): Handle alignment if possible. Eigen::Map is
// unaligned by default.
ConstMatrixMap input_matrix(
input.flat<Scalar>().data() + matrix_index * num_rows * num_cols,
num_rows, num_cols);
// packed_triangular_factors has shape [num_rows, num_cols]
MatrixMap packed_triangular_factors(
outputs[0]->flat<Scalar>().data() + matrix_index * num_rows * num_cols,
num_rows, num_rows);
// permutation has shape [num_rows, 1]
IndicesMap permutation_indices(
outputs[1]->flat<Tidx>().data() + matrix_index * num_rows, num_rows, 1);
Eigen::PartialPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>
lu_decomposition(input_matrix);
// Output the packed triangular factors in a dense form.
// The lower triangular factor L corresponds to the strictly lower
// triangular part of packed_triangular_factors with an implicit unit
// diagonal. The upper triangular factor U is the upper triangular part of
// packed_triangular_factors. The triangular factors satisfy the equation
// P * input_matrix = L * U
// where P is the permutation matrix corresponding to the indices in
// permutation_indices.
packed_triangular_factors = lu_decomposition.matrixLU();
// Output the permutation matrix used for pivoting.
Eigen::PermutationMatrix<-1, -1, Tidx> permutation =
lu_decomposition.permutationP().transpose();
permutation_indices = permutation.indices();
// PartialPivLU cannot give strong guarantees on invertibility,
// but we can at least guard against exact zero pivots. This can occur as
// a result of basic user mistakes such providing integer valued
// matrices that are exactly singular, or due to underflow if this
// code is run with denormals being flushed to zero.
const RealScalar min_abs_pivot =
packed_triangular_factors.diagonal().cwiseAbs().minCoeff();
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input is not invertible."));
}
};
#define REGISTER_LU(type, idx_type) \
REGISTER_KERNEL_BUILDER(Name("Lu") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<idx_type>("output_idx_type"), \
LuOp<type, idx_type>);
REGISTER_LU(float, int32);
REGISTER_LU(double, int32);
REGISTER_LU(complex64, int32);
REGISTER_LU(complex128, int32);
REGISTER_LU(float, int64);
REGISTER_LU(double, int64);
REGISTER_LU(complex64, int64);
REGISTER_LU(complex128, int64);
} // namespace tensorflow

View File

@ -0,0 +1,275 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <algorithm>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Scalar>
__device__ void ComputePermutationFromTranspositions(
int64 num_rows, const int* pivots, Scalar* permutation_indices) {
// Fill in the output array with the identity permutation.
for (int i = 0; i < num_rows; ++i) {
permutation_indices[i] = Scalar(i);
}
// Compute the permutation from a sequence of transpositions encoded
// in the pivot array by applying the transpositions in order on the
// identity permutation.
for (int i = 0; i < num_rows; ++i) {
// Note: Internally, the cuBlas code uses Fortran convention (1-based)
// indexing so ith row was swapped with (pivots[i]-1)'th row in 0-based
// indexing.
Scalar t = permutation_indices[i];
permutation_indices[i] = permutation_indices[pivots[i] - 1];
permutation_indices[pivots[i] - 1] = t;
}
}
} // namespace
// Kernel to compute the inverse of a permutation from a sequence of
// transpositions.
template <typename Scalar>
__global__ void ComputePermutationFromTranspositionsKernel(
CudaLaunchConfig config, const int64 num_rows, const int* all_pivots,
Scalar* all_permutation_indices) {
// We only parallelize over batches here. Performance is not critical,
// since this cheap O(num_rows) kernel always follows an O(num_rows^3)
// LU factorization.
CUDA_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
ComputePermutationFromTranspositions(
num_rows, all_pivots + index * num_rows,
all_permutation_indices + index * num_rows);
}
}
template <class Scalar, class Tidx>
class LuOpGpu : public AsyncOpKernel {
public:
explicit LuOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {}
void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
const Tensor& input = context->input(0);
// Analyze shape and validate inputs.
const int input_rank = input.dims();
OP_REQUIRES_ASYNC(
context, input_rank >= 2,
errors::InvalidArgument("Input must have rank >= 2, got ", input_rank),
done);
const int64 num_rows = input.dim_size(input_rank - 2);
const int64 num_cols = input.dim_size(input_rank - 1);
OP_REQUIRES_ASYNC(
context, num_rows == num_cols,
errors::InvalidArgument("Input matrices must be squares, got", num_rows,
" != ", num_cols),
done);
TensorShape batch_shape;
for (int dim = 0; dim < input_rank - 2; ++dim) {
batch_shape.AddDim(input.dim_size(dim));
}
TensorShape permutation_indices_shape = batch_shape;
permutation_indices_shape.AddDim(num_rows);
const GPUDevice& device = context->eigen_device<GPUDevice>();
auto solver = absl::make_unique<CudaSolver>(context);
// We output the packed triangular factors in a dense form.
// The lower triangular factor L corresponds to the strictly lower
// triangular part of packed_triangular_factors with an implicit unit
// diagonal. The upper triangular factor U is the upper triangular part of
// packed_triangular_factors. The triangular factors satisfy the equation
// P * input_matrix = L * U
// where P is the permutation matrix corresponding to the indices in
// permutation_indices.
//
// Reuse the input buffer or make a copy for the factorization step,
// depending on whether this ops owns it exclusively.
Tensor* packed_triangular_factors;
OP_REQUIRES_OK_ASYNC(context,
context->forward_input_or_allocate_output(
{0}, 0, input.shape(), &packed_triangular_factors),
done);
if (!packed_triangular_factors->SharesBufferWith(input)) {
device.memcpy(packed_triangular_factors->flat<Scalar>().data(),
input.flat<Scalar>().data(),
input.NumElements() * sizeof(Scalar));
}
// Allocate output permutation.
Tensor* permutation_indices = nullptr;
OP_REQUIRES_OK_ASYNC(context,
context->allocate_output(1, permutation_indices_shape,
&permutation_indices),
done);
if (input.NumElements() == 0) {
done();
return;
}
// Allocate a temporary Tensor to store the transposed packed triangular
// factors.
Tensor packed_triangular_factors_transpose;
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<Scalar>::value, input.shape(),
&packed_triangular_factors_transpose),
done);
auto packed_triangular_factors_transpose_reshaped =
packed_triangular_factors_transpose
.template flat_inner_dims<Scalar, 3>();
const int64 batch_size =
packed_triangular_factors_transpose_reshaped.dimension(0);
// Allocate pivots on the device.
Tensor pivots;
OP_REQUIRES_OK_ASYNC(context,
solver->allocate_scoped_tensor(
DataTypeToEnum<int32>::value,
TensorShape{batch_size, num_rows}, &pivots),
done);
auto pivots_mat = pivots.template matrix<int32>();
// Transpose the input. This is necessary because cuBLAS assumes
// column-major storage while TensorFlow uses row-major.
OP_REQUIRES_OK_ASYNC(
context,
DoMatrixTranspose(device, *packed_triangular_factors,
&packed_triangular_factors_transpose),
done);
std::vector<DeviceLapackInfo> dev_info;
if (num_rows == num_cols && num_rows / batch_size <= 128) {
// For small matrices or large batch sizes, we use the batched
// interface from cuBlas.
auto packed_triangular_factors_ptrs = solver->GetScratchSpace<uint8>(
sizeof(Scalar*) * batch_size, "packed_triangular_factors_ptrs",
/* on_host */ true);
const Scalar** packed_triangular_factors_ptrs_base =
reinterpret_cast<const Scalar**>(
packed_triangular_factors_ptrs.mutable_data());
for (int batch = 0; batch < batch_size; ++batch) {
packed_triangular_factors_ptrs_base[batch] =
&packed_triangular_factors_transpose_reshaped(batch, 0, 0);
}
dev_info.push_back(
solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
OP_REQUIRES_OK_ASYNC(
context,
solver->GetrfBatched(num_rows, packed_triangular_factors_ptrs_base,
num_rows, pivots_mat.data(), &dev_info.back(),
batch_size),
done);
} else {
// For small batch sizes we use the non-batched interface from cuSolver,
// which is much faster for large matrices.
dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
for (int batch = 0; batch < batch_size; ++batch) {
OP_REQUIRES_OK_ASYNC(
context,
solver->Getrf(
num_rows, num_cols,
&packed_triangular_factors_transpose_reshaped(batch, 0, 0),
num_rows, &pivots_mat(batch, 0), &dev_info.back()(batch)),
done);
}
}
// Transpose the result since we had transposed the input.
OP_REQUIRES_OK_ASYNC(
context,
DoMatrixTranspose(device, packed_triangular_factors_transpose,
packed_triangular_factors),
done);
// Pivots encode the permutation of the rows as a sequences of row swaps.
// For each index i, row i is swapped with row pivots[i].
int* pivots_ptr = pivots.flat<int>().data();
Tidx* permutation_indices_ptr =
permutation_indices->template flat<Tidx>().data();
CudaLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device);
ComputePermutationFromTranspositionsKernel<<<cfgPivots.block_count,
cfgPivots.thread_per_block, 0,
device.stream()>>>(
cfgPivots, num_rows, pivots_ptr, permutation_indices_ptr);
// Callback for checking info after kernels finish. Also capture the
// temporary Tensors/ScratchSpace so they don't get deallocated before the
// kernels run.
// TODO(rmlarsen): Use move capture once C++14 becomes available.
auto info_checker = [context, done, dev_info](
const Status& status,
const std::vector<HostLapackInfo>& host_infos) {
if (!status.ok() && errors::IsInvalidArgument(status) &&
!host_infos.empty()) {
for (int i = 0; i < host_infos[0].size(); ++i) {
// Match the CPU error message for singular matrices. Otherwise
// just print the original error message from the status below.
OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
errors::InvalidArgument("Input is not invertible."),
done);
}
}
OP_REQUIRES_OK_ASYNC(context, status, done);
done();
};
CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
std::move(info_checker));
}
};
#define REGISTER_LU_GPU(type, idx_type) \
REGISTER_KERNEL_BUILDER(Name("Lu") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<idx_type>("output_idx_type"), \
LuOpGpu<type, idx_type>);
REGISTER_LU_GPU(float, int32);
REGISTER_LU_GPU(double, int32);
REGISTER_LU_GPU(complex64, int32);
REGISTER_LU_GPU(complex128, int32);
REGISTER_LU_GPU(float, int64);
REGISTER_LU_GPU(double, int64);
REGISTER_LU_GPU(complex64, int64);
REGISTER_LU_GPU(complex128, int64);
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -109,6 +109,30 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
return Status::OK();
}
// Input is [...,N,N].
// First and second outputs are:
// [...,N,N]; [...,N].
Status LuShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
DimensionHandle n;
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
ShapeHandle batch_shape;
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
ShapeHandle lu_shape;
ShapeHandle p_shape;
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
c->set_output(0, lu_shape);
c->set_output(1, p_shape);
return Status::OK();
}
// Input is [...,M,N].
// First and second outputs are:
// [...,M,M]; [...,M,N], if full_matrices is true,
@ -289,6 +313,14 @@ REGISTER_OP("SelfAdjointEigV2")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(SelfAdjointEigV2ShapeFn);
REGISTER_OP("Lu")
.Input("input: T")
.Output("lu: T")
.Output("p: output_idx_type")
.Attr("T: {double, float, complex64, complex128}")
.Attr("output_idx_type: {int32, int64} = DT_INT32")
.SetShapeFn(LuShapeFn);
REGISTER_OP("MatrixSolve")
.Input("matrix: T")
.Input("rhs: T")

View File

@ -274,4 +274,23 @@ TEST(LinalgOpsTest, Svd_ShapeFn) {
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
}
TEST(LinalgOpsTest, Lu_ShapeFn) {
ShapeInferenceTestOp op("Lu");
INFER_OK(op, "?", "?;?");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?,3,4,1,2]");
INFER_OK(op, "[?,?]", "[d0_0,d0_0];[d0_0]");
INFER_OK(op, "[1,?]", "[d0_0,d0_0];[d0_0]");
INFER_OK(op, "[?,1]", "[d0_1,d0_1];[d0_1]");
// Repeat previous block of tests with input rank > 2.
INFER_OK(op, "[1,?,3,4,?,?]",
"[d0_0,d0_1,d0_2,d0_3,d0_4,d0_4];[d0_0,d0_1,d0_2,d0_3,d0_4]");
INFER_OK(op, "[1,?,3,4,1,?]",
"[d0_0,d0_1,d0_2,d0_3,d0_4,d0_4];[d0_0,d0_1,d0_2,d0_3,d0_4]");
INFER_OK(op, "[1,?,3,4,?,1]",
"[d0_0,d0_1,d0_2,d0_3,d0_5,d0_5];[d0_0,d0_1,d0_2,d0_3,d0_5]");
}
} // end namespace tensorflow

View File

@ -1888,6 +1888,22 @@ cuda_py_test(
],
)
cuda_py_test(
name = "lu_op_test",
size = "small",
srcs = ["lu_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python/ops/linalg",
],
)
cuda_py_test(
name = "manip_ops_test",
size = "small",

View File

@ -0,0 +1,285 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.Lu."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
class LuOpTest(test.TestCase):
@property
def float_types(self):
return set((np.float64, np.float32, np.complex64, np.complex128))
def _verifyLuBase(self, sess, x, lower, upper, perm, verification,
output_idx_type):
lower_np, upper_np, perm_np, verification_np = sess.run(
[lower, upper, perm, verification])
self.assertAllClose(x, verification_np)
self.assertShapeEqual(x, lower)
self.assertShapeEqual(x, upper)
self.assertAllEqual(x.shape[:-1], perm.shape.as_list())
# Check dtypes are as expected.
self.assertEqual(x.dtype, lower_np.dtype)
self.assertEqual(x.dtype, upper_np.dtype)
self.assertEqual(output_idx_type.as_numpy_dtype, perm_np.dtype)
# Check that the permutation is valid.
if perm_np.shape[-1] > 0:
perm_reshaped = np.reshape(perm_np, (-1, perm_np.shape[-1]))
for perm_vector in perm_reshaped:
self.assertAllClose(np.arange(len(perm_vector)), np.sort(perm_vector))
def _verifyLu(self, x, output_idx_type=dtypes.int64):
# Verify that Px = LU.
with self.cached_session(use_gpu=True) as sess:
lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)
# Prepare the lower factor of shape num_rows x num_rows
lu_shape = np.array(lu.shape.as_list())
batch_shape = lu_shape[:-2]
num_rows = lu_shape[-2]
num_cols = lu_shape[-1]
lower = array_ops.matrix_band_part(lu, -1, 0)
if num_rows > num_cols:
eye = linalg_ops.eye(
num_rows, batch_shape=batch_shape, dtype=lower.dtype)
lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
elif num_rows < num_cols:
lower = lower[..., :num_rows]
# Fill the diagonal with ones.
ones_diag = array_ops.ones(
np.append(batch_shape, num_rows), dtype=lower.dtype)
lower = array_ops.matrix_set_diag(lower, ones_diag)
# Prepare the upper factor.
upper = array_ops.matrix_band_part(lu, 0, -1)
verification = math_ops.matmul(lower, upper)
# Permute the rows of product of the Cholesky factors.
if num_rows > 0:
# Reshape the product of the triangular factors and permutation indices
# to a single batch dimension. This makes it easy to apply
# invert_permutation and gather_nd ops.
perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
verification_reshaped = array_ops.reshape(verification,
[-1, num_rows, num_cols])
# Invert the permutation in each batch.
inv_perm_reshaped = functional_ops.map_fn(array_ops.invert_permutation,
perm_reshaped)
batch_size = perm_reshaped.shape.as_list()[0]
# Prepare the batch indices with the same shape as the permutation.
# The corresponding batch index is paired with each of the `num_rows`
# permutation indices.
batch_indices = math_ops.cast(
array_ops.broadcast_to(
math_ops.range(batch_size)[:, None], perm_reshaped.shape),
dtype=output_idx_type)
permuted_verification_reshaped = array_ops.gather_nd(
verification_reshaped,
array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))
# Reshape the verification matrix back to the original shape.
verification = array_ops.reshape(permuted_verification_reshaped,
lu_shape)
self._verifyLuBase(sess, x, lower, upper, perm, verification,
output_idx_type)
def testBasic(self):
data = np.array([[4., -1., 2.], [-1., 6., 0], [10., 0., 5.]])
for dtype in (np.float32, np.float64):
for output_idx_type in (dtypes.int32, dtypes.int64):
self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)
for dtype in (np.complex64, np.complex128):
for output_idx_type in (dtypes.int32, dtypes.int64):
complex_data = np.tril(1j * data, -1).astype(dtype)
complex_data += np.triu(-1j * data, 1).astype(dtype)
complex_data += data
self._verifyLu(complex_data, output_idx_type=output_idx_type)
def testPivoting(self):
with self.session(use_gpu=True) as sess:
# This matrix triggers partial pivoting because the first diagonal entry
# is small.
data = np.array([[1e-9, 1., 0.], [1., 0., 0], [0., 1., 5]])
self._verifyLu(data.astype(np.float32))
for dtype in (np.float32, np.float64):
self._verifyLu(data.astype(dtype))
_, p = linalg_ops.lu(data)
p_val = sess.run([p])
# Make sure p_val is not the identity permutation.
self.assertNotAllClose(np.arange(3), p_val)
for dtype in (np.complex64, np.complex128):
complex_data = np.tril(1j * data, -1).astype(dtype)
complex_data += np.triu(-1j * data, 1).astype(dtype)
complex_data += data
self._verifyLu(complex_data)
_, p = linalg_ops.lu(data)
p_val = sess.run([p])
# Make sure p_val is not the identity permutation.
self.assertNotAllClose(np.arange(3), p_val)
def testInvalidMatrix(self):
# LU factorization gives an error when the input is singular.
# Note: A singular matrix may return without error but it won't be a valid
# factorization.
with self.session(use_gpu=True) as sess:
for dtype in self.float_types:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
linalg_ops.lu(
np.array([[1., 2., 3.], [2., 4., 6.], [2., 3., 4.]],
dtype=dtype)))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
linalg_ops.lu(
np.array([[[1., 2., 3.], [2., 4., 6.], [1., 2., 3.]],
[[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]],
dtype=dtype)))
def testBatch(self):
simple_array = np.array([[[1., -1.], [2., 5.]]]) # shape (1, 2, 2)
self._verifyLu(simple_array)
self._verifyLu(np.vstack((simple_array, simple_array)))
odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
self._verifyLu(np.vstack((odd_sized_array, odd_sized_array)))
batch_size = 200
# Generate random matrices.
np.random.seed(42)
matrices = np.random.rand(batch_size, 5, 5)
self._verifyLu(matrices)
# Generate random complex valued matrices.
np.random.seed(52)
matrices = np.random.rand(batch_size, 5,
5) + 1j * np.random.rand(batch_size, 5, 5)
self._verifyLu(matrices)
def testLargeMatrix(self):
# Generate random matrices.
n = 500
np.random.seed(64)
data = np.random.rand(n, n)
self._verifyLu(data)
# Generate random complex valued matrices.
np.random.seed(129)
data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
self._verifyLu(data)
def testEmpty(self):
self._verifyLu(np.empty([0, 2, 2]))
self._verifyLu(np.empty([2, 0, 0]))
def testConcurrentExecutesWithoutError(self):
with self.session(use_gpu=True) as sess:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
lu1, p1 = linalg_ops.lu(matrix1)
lu2, p2 = linalg_ops.lu(matrix2)
lu1_val, p1_val, lu2_val, p2_val = sess.run([lu1, p1, lu2, p2])
self.assertAllEqual(lu1_val, lu2_val)
self.assertAllEqual(p1_val, p2_val)
class LuBenchmark(test.Benchmark):
shapes = [
(4, 4),
(10, 10),
(16, 16),
(101, 101),
(256, 256),
(1000, 1000),
(1024, 1024),
(2048, 2048),
(4096, 4096),
(513, 2, 2),
(513, 8, 8),
(513, 256, 256),
(4, 513, 2, 2),
]
def _GenerateMatrix(self, shape):
batch_shape = shape[:-2]
shape = shape[-2:]
assert shape[0] == shape[1]
n = shape[0]
matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
np.ones(n).astype(np.float32))
return np.tile(matrix, batch_shape + (1, 1))
def benchmarkLuOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
lu, p = linalg_ops.lu(matrix)
variables.global_variables_initializer().run()
self.run_op_benchmark(
sess,
control_flow_ops.group(lu, p),
min_iters=25,
name="lu_cpu_{shape}".format(shape=shape))
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/device:GPU:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
lu, p = linalg_ops.lu(matrix)
variables.global_variables_initializer().run()
self.run_op_benchmark(
sess,
control_flow_ops.group(lu, p),
min_iters=25,
name="lu_gpu_{shape}".format(shape=shape))
if __name__ == "__main__":
test.main()

View File

@ -44,6 +44,7 @@ einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(logm)
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm

View File

@ -132,6 +132,10 @@ tf_module {
name: "lstsq"
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
}
member_method {
name: "lu"
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "matmul"
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "

View File

@ -132,6 +132,10 @@ tf_module {
name: "lstsq"
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
}
member_method {
name: "lu"
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "matmul"
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "