Add TensorFlow op to compute the LU decomposition of a matrix.
For CPU, we use Eigen's PartialPivLU op, which has performance comparable to Cholesky decomposition. Unlike Cholesky this is also robust for non-SPD matrices. For GPU, we use getrf and getrfbatched from cuSolver. PiperOrigin-RevId: 224244390
This commit is contained in:
parent
f102c3c051
commit
5caa987005
51
tensorflow/core/api_def/base_api/api_def_Lu.pbtxt
Normal file
51
tensorflow/core/api_def/base_api/api_def_Lu.pbtxt
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Lu"
|
||||||
|
in_arg {
|
||||||
|
name: "input"
|
||||||
|
description: <<END
|
||||||
|
A tensor of shape `[..., M, M]` whose inner-most 2 dimensions form matrices of
|
||||||
|
size `[M, M]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "lu"
|
||||||
|
description: <<END
|
||||||
|
A tensor of shape `[..., M, M]` whose strictly lower triangular part denotes the
|
||||||
|
lower triangular factor `L` with unit diagonal, and whose upper triangular part
|
||||||
|
denotes the upper triangular factor `U`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "p"
|
||||||
|
description: <<END
|
||||||
|
Permutation of the rows encoded as a list of indices in `0..M-1`. Shape is
|
||||||
|
`[..., M]`.
|
||||||
|
@compatibility(scipy)
|
||||||
|
Similar to `scipy.linalg.lu`, except the triangular factors `L` and `U` are
|
||||||
|
packed into a single tensor, the permutation is applied to `input` instead of
|
||||||
|
the right hand side and the permutation `P` is returned as a list of indices
|
||||||
|
instead of a permutation matrix.
|
||||||
|
@end_compatibility
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Computes the LU decomposition of one or more square matrices."
|
||||||
|
description: <<END
|
||||||
|
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
|
||||||
|
form square matrices.
|
||||||
|
|
||||||
|
The input has to be invertible.
|
||||||
|
|
||||||
|
The output consists of two tensors LU and P containing the LU decomposition
|
||||||
|
of all input submatrices `[..., :, :]`. LU encodes the lower triangular and
|
||||||
|
upper triangular factors.
|
||||||
|
|
||||||
|
For each input submatrix of shape `[M, M]`, L is a lower triangular matrix of
|
||||||
|
shape `[M, M]` with unit diagonal whose entries correspond to the strictly lower
|
||||||
|
triangular part of LU. U is a upper triangular matrix of shape `[M, M]` whose
|
||||||
|
entries correspond to the upper triangular part, including the diagonal, of LU.
|
||||||
|
|
||||||
|
P represents a permutation matrix encoded as a list of indices each between `0`
|
||||||
|
and `M-1`, inclusive. If P_mat denotes the permutation matrix corresponding to
|
||||||
|
P, then the L, U and P satisfies P_mat * input = L * U.
|
||||||
|
END
|
||||||
|
}
|
6
tensorflow/core/api_def/python_api/api_def_Lu.pbtxt
Normal file
6
tensorflow/core/api_def/python_api/api_def_Lu.pbtxt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Lu"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.lu"
|
||||||
|
}
|
||||||
|
}
|
@ -2755,6 +2755,7 @@ cc_library(
|
|||||||
":cholesky_grad",
|
":cholesky_grad",
|
||||||
":cholesky_op",
|
":cholesky_op",
|
||||||
":determinant_op",
|
":determinant_op",
|
||||||
|
":lu_op",
|
||||||
":matrix_exponential_op",
|
":matrix_exponential_op",
|
||||||
":matrix_inverse_op",
|
":matrix_inverse_op",
|
||||||
":matrix_logarithm_op",
|
":matrix_logarithm_op",
|
||||||
@ -2900,6 +2901,19 @@ tf_kernel_library(
|
|||||||
deps = LINALG_DEPS,
|
deps = LINALG_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "lu_op",
|
||||||
|
prefix = "lu_op",
|
||||||
|
deps = if_cuda([
|
||||||
|
":cuda_solvers",
|
||||||
|
":transpose_functor",
|
||||||
|
]) + [
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "linalg_ops_common",
|
name = "linalg_ops_common",
|
||||||
srcs = ["linalg_ops_common.cc"],
|
srcs = ["linalg_ops_common.cc"],
|
||||||
|
193
tensorflow/core/kernels/lu_op.cc
Normal file
193
tensorflow/core/kernels/lu_op.cc
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "third_party/eigen3/Eigen/LU"
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/lib/math/math_util.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
|
||||||
|
template <typename Scalar, typename Tidx>
|
||||||
|
class LuOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
|
||||||
|
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
|
||||||
|
|
||||||
|
using Matrix =
|
||||||
|
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||||
|
using MatrixMap = Eigen::Map<Matrix>;
|
||||||
|
|
||||||
|
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
|
||||||
|
|
||||||
|
using Indices =
|
||||||
|
Eigen::Matrix<Tidx, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using IndicesMap = Eigen::Map<Indices>;
|
||||||
|
using ConstIndicesMap = Eigen::Map<const Indices>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Returns the cost per matrix operation. This is used to determine the
|
||||||
|
// number of threads to use for parallelizing factorization in batch mode.
|
||||||
|
// Cost per unit is assumed to be roughly 1ns, based on comments
|
||||||
|
// in core/util/work_sharder.cc.
|
||||||
|
// LU decomposition for a square matrix takes roughly (2/3) * (num_rows)^3.
|
||||||
|
// TODO(anudhyan): Refine this estimate after taking constant factors into
|
||||||
|
// account.
|
||||||
|
int64 GetCostPerUnit(const TensorShape& input_matrix_shape) const {
|
||||||
|
double num_rows = static_cast<double>(input_matrix_shape.dim_size(0));
|
||||||
|
double cost = (2 / 3.0) * MathUtil::IPow(num_rows, 3);
|
||||||
|
return cost >= static_cast<double>(kint64max) ? kint64max
|
||||||
|
: static_cast<int64>(cost);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
OP_REQUIRES(context, context->num_inputs() == 1,
|
||||||
|
errors::InvalidArgument("Expecting exactly one input, got ",
|
||||||
|
context->num_inputs()));
|
||||||
|
|
||||||
|
const Tensor& input = context->input(0);
|
||||||
|
int input_rank = input.dims();
|
||||||
|
OP_REQUIRES(context, input_rank >= 2,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input tensor must have rank >= 2, got ", input_rank));
|
||||||
|
|
||||||
|
// If the tensor rank is greater than 2, we consider the inner-most
|
||||||
|
// dimensions as matrices, and loop over all the other outer ("batch")
|
||||||
|
// dimensions to compute the results.
|
||||||
|
TensorShape input_matrix_shape;
|
||||||
|
TensorShape batch_shape;
|
||||||
|
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
||||||
|
batch_shape.AddDim(input.dim_size(dim));
|
||||||
|
}
|
||||||
|
const int64 num_rows = input.dim_size(input_rank - 2);
|
||||||
|
const int64 num_cols = input.dim_size(input_rank - 1);
|
||||||
|
|
||||||
|
input_matrix_shape.AppendShape({num_rows, num_cols});
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shape),
|
||||||
|
errors::InvalidArgument("Input matrix must be square."));
|
||||||
|
|
||||||
|
// packed_triangular_factors is a matrix with the same shape as the input;
|
||||||
|
// permutation is a vector.
|
||||||
|
TensorShape permutation_shape = batch_shape;
|
||||||
|
permutation_shape.AddDim(num_rows);
|
||||||
|
|
||||||
|
TensorShapes output_matrix_shapes({input.shape(), permutation_shape});
|
||||||
|
|
||||||
|
TensorOutputs outputs;
|
||||||
|
Tensor* output_packed_triangular_factors = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context, context->forward_input_or_allocate_output(
|
||||||
|
{0}, 0, input.shape(), &output_packed_triangular_factors));
|
||||||
|
outputs.emplace_back(output_packed_triangular_factors);
|
||||||
|
|
||||||
|
Tensor* output_permutation = nullptr;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(1, permutation_shape,
|
||||||
|
&output_permutation));
|
||||||
|
outputs.emplace_back(output_permutation);
|
||||||
|
|
||||||
|
if (num_rows == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the individual matrix problems in parallel using a threadpool.
|
||||||
|
auto shard = [this, &input, &num_rows, &num_cols, &outputs,
|
||||||
|
&output_matrix_shapes, context](int64 begin, int64 end) {
|
||||||
|
for (int64 i = begin; i < end; ++i) {
|
||||||
|
ComputeTensorSlice(context, i, input, num_rows, num_cols, outputs,
|
||||||
|
output_matrix_shapes);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers,
|
||||||
|
batch_shape.num_elements(), GetCostPerUnit(input_matrix_shape),
|
||||||
|
shard);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
|
||||||
|
const Tensor& input, int64 num_rows, int64 num_cols,
|
||||||
|
const TensorOutputs& outputs,
|
||||||
|
const TensorShapes& output_matrix_shapes) {
|
||||||
|
// TODO(kalakris): Handle alignment if possible. Eigen::Map is
|
||||||
|
// unaligned by default.
|
||||||
|
ConstMatrixMap input_matrix(
|
||||||
|
input.flat<Scalar>().data() + matrix_index * num_rows * num_cols,
|
||||||
|
num_rows, num_cols);
|
||||||
|
|
||||||
|
// packed_triangular_factors has shape [num_rows, num_cols]
|
||||||
|
MatrixMap packed_triangular_factors(
|
||||||
|
outputs[0]->flat<Scalar>().data() + matrix_index * num_rows * num_cols,
|
||||||
|
num_rows, num_rows);
|
||||||
|
|
||||||
|
// permutation has shape [num_rows, 1]
|
||||||
|
IndicesMap permutation_indices(
|
||||||
|
outputs[1]->flat<Tidx>().data() + matrix_index * num_rows, num_rows, 1);
|
||||||
|
|
||||||
|
Eigen::PartialPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
lu_decomposition(input_matrix);
|
||||||
|
|
||||||
|
// Output the packed triangular factors in a dense form.
|
||||||
|
// The lower triangular factor L corresponds to the strictly lower
|
||||||
|
// triangular part of packed_triangular_factors with an implicit unit
|
||||||
|
// diagonal. The upper triangular factor U is the upper triangular part of
|
||||||
|
// packed_triangular_factors. The triangular factors satisfy the equation
|
||||||
|
// P * input_matrix = L * U
|
||||||
|
// where P is the permutation matrix corresponding to the indices in
|
||||||
|
// permutation_indices.
|
||||||
|
packed_triangular_factors = lu_decomposition.matrixLU();
|
||||||
|
// Output the permutation matrix used for pivoting.
|
||||||
|
Eigen::PermutationMatrix<-1, -1, Tidx> permutation =
|
||||||
|
lu_decomposition.permutationP().transpose();
|
||||||
|
permutation_indices = permutation.indices();
|
||||||
|
|
||||||
|
// PartialPivLU cannot give strong guarantees on invertibility,
|
||||||
|
// but we can at least guard against exact zero pivots. This can occur as
|
||||||
|
// a result of basic user mistakes such providing integer valued
|
||||||
|
// matrices that are exactly singular, or due to underflow if this
|
||||||
|
// code is run with denormals being flushed to zero.
|
||||||
|
const RealScalar min_abs_pivot =
|
||||||
|
packed_triangular_factors.diagonal().cwiseAbs().minCoeff();
|
||||||
|
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
|
||||||
|
errors::InvalidArgument("Input is not invertible."));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_LU(type, idx_type) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("Lu") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<type>("T") \
|
||||||
|
.TypeConstraint<idx_type>("output_idx_type"), \
|
||||||
|
LuOp<type, idx_type>);
|
||||||
|
|
||||||
|
REGISTER_LU(float, int32);
|
||||||
|
REGISTER_LU(double, int32);
|
||||||
|
REGISTER_LU(complex64, int32);
|
||||||
|
REGISTER_LU(complex128, int32);
|
||||||
|
|
||||||
|
REGISTER_LU(float, int64);
|
||||||
|
REGISTER_LU(double, int64);
|
||||||
|
REGISTER_LU(complex64, int64);
|
||||||
|
REGISTER_LU(complex128, int64);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
275
tensorflow/core/kernels/lu_op_gpu.cu.cc
Normal file
275
tensorflow/core/kernels/lu_op_gpu.cu.cc
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename Scalar>
|
||||||
|
__device__ void ComputePermutationFromTranspositions(
|
||||||
|
int64 num_rows, const int* pivots, Scalar* permutation_indices) {
|
||||||
|
// Fill in the output array with the identity permutation.
|
||||||
|
for (int i = 0; i < num_rows; ++i) {
|
||||||
|
permutation_indices[i] = Scalar(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the permutation from a sequence of transpositions encoded
|
||||||
|
// in the pivot array by applying the transpositions in order on the
|
||||||
|
// identity permutation.
|
||||||
|
for (int i = 0; i < num_rows; ++i) {
|
||||||
|
// Note: Internally, the cuBlas code uses Fortran convention (1-based)
|
||||||
|
// indexing so ith row was swapped with (pivots[i]-1)'th row in 0-based
|
||||||
|
// indexing.
|
||||||
|
Scalar t = permutation_indices[i];
|
||||||
|
permutation_indices[i] = permutation_indices[pivots[i] - 1];
|
||||||
|
permutation_indices[pivots[i] - 1] = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Kernel to compute the inverse of a permutation from a sequence of
|
||||||
|
// transpositions.
|
||||||
|
template <typename Scalar>
|
||||||
|
__global__ void ComputePermutationFromTranspositionsKernel(
|
||||||
|
CudaLaunchConfig config, const int64 num_rows, const int* all_pivots,
|
||||||
|
Scalar* all_permutation_indices) {
|
||||||
|
// We only parallelize over batches here. Performance is not critical,
|
||||||
|
// since this cheap O(num_rows) kernel always follows an O(num_rows^3)
|
||||||
|
// LU factorization.
|
||||||
|
CUDA_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
|
||||||
|
ComputePermutationFromTranspositions(
|
||||||
|
num_rows, all_pivots + index * num_rows,
|
||||||
|
all_permutation_indices + index * num_rows);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Scalar, class Tidx>
|
||||||
|
class LuOpGpu : public AsyncOpKernel {
|
||||||
|
public:
|
||||||
|
explicit LuOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {}
|
||||||
|
|
||||||
|
void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
|
||||||
|
const Tensor& input = context->input(0);
|
||||||
|
|
||||||
|
// Analyze shape and validate inputs.
|
||||||
|
const int input_rank = input.dims();
|
||||||
|
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
context, input_rank >= 2,
|
||||||
|
errors::InvalidArgument("Input must have rank >= 2, got ", input_rank),
|
||||||
|
done);
|
||||||
|
|
||||||
|
const int64 num_rows = input.dim_size(input_rank - 2);
|
||||||
|
const int64 num_cols = input.dim_size(input_rank - 1);
|
||||||
|
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
context, num_rows == num_cols,
|
||||||
|
errors::InvalidArgument("Input matrices must be squares, got", num_rows,
|
||||||
|
" != ", num_cols),
|
||||||
|
done);
|
||||||
|
|
||||||
|
TensorShape batch_shape;
|
||||||
|
for (int dim = 0; dim < input_rank - 2; ++dim) {
|
||||||
|
batch_shape.AddDim(input.dim_size(dim));
|
||||||
|
}
|
||||||
|
TensorShape permutation_indices_shape = batch_shape;
|
||||||
|
permutation_indices_shape.AddDim(num_rows);
|
||||||
|
|
||||||
|
const GPUDevice& device = context->eigen_device<GPUDevice>();
|
||||||
|
auto solver = absl::make_unique<CudaSolver>(context);
|
||||||
|
|
||||||
|
// We output the packed triangular factors in a dense form.
|
||||||
|
// The lower triangular factor L corresponds to the strictly lower
|
||||||
|
// triangular part of packed_triangular_factors with an implicit unit
|
||||||
|
// diagonal. The upper triangular factor U is the upper triangular part of
|
||||||
|
// packed_triangular_factors. The triangular factors satisfy the equation
|
||||||
|
// P * input_matrix = L * U
|
||||||
|
// where P is the permutation matrix corresponding to the indices in
|
||||||
|
// permutation_indices.
|
||||||
|
//
|
||||||
|
// Reuse the input buffer or make a copy for the factorization step,
|
||||||
|
// depending on whether this ops owns it exclusively.
|
||||||
|
Tensor* packed_triangular_factors;
|
||||||
|
OP_REQUIRES_OK_ASYNC(context,
|
||||||
|
context->forward_input_or_allocate_output(
|
||||||
|
{0}, 0, input.shape(), &packed_triangular_factors),
|
||||||
|
done);
|
||||||
|
if (!packed_triangular_factors->SharesBufferWith(input)) {
|
||||||
|
device.memcpy(packed_triangular_factors->flat<Scalar>().data(),
|
||||||
|
input.flat<Scalar>().data(),
|
||||||
|
input.NumElements() * sizeof(Scalar));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate output permutation.
|
||||||
|
Tensor* permutation_indices = nullptr;
|
||||||
|
OP_REQUIRES_OK_ASYNC(context,
|
||||||
|
context->allocate_output(1, permutation_indices_shape,
|
||||||
|
&permutation_indices),
|
||||||
|
done);
|
||||||
|
|
||||||
|
if (input.NumElements() == 0) {
|
||||||
|
done();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate a temporary Tensor to store the transposed packed triangular
|
||||||
|
// factors.
|
||||||
|
Tensor packed_triangular_factors_transpose;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
context,
|
||||||
|
context->allocate_temp(DataTypeToEnum<Scalar>::value, input.shape(),
|
||||||
|
&packed_triangular_factors_transpose),
|
||||||
|
done);
|
||||||
|
auto packed_triangular_factors_transpose_reshaped =
|
||||||
|
packed_triangular_factors_transpose
|
||||||
|
.template flat_inner_dims<Scalar, 3>();
|
||||||
|
const int64 batch_size =
|
||||||
|
packed_triangular_factors_transpose_reshaped.dimension(0);
|
||||||
|
|
||||||
|
// Allocate pivots on the device.
|
||||||
|
Tensor pivots;
|
||||||
|
OP_REQUIRES_OK_ASYNC(context,
|
||||||
|
solver->allocate_scoped_tensor(
|
||||||
|
DataTypeToEnum<int32>::value,
|
||||||
|
TensorShape{batch_size, num_rows}, &pivots),
|
||||||
|
done);
|
||||||
|
auto pivots_mat = pivots.template matrix<int32>();
|
||||||
|
|
||||||
|
// Transpose the input. This is necessary because cuBLAS assumes
|
||||||
|
// column-major storage while TensorFlow uses row-major.
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
context,
|
||||||
|
DoMatrixTranspose(device, *packed_triangular_factors,
|
||||||
|
&packed_triangular_factors_transpose),
|
||||||
|
done);
|
||||||
|
|
||||||
|
std::vector<DeviceLapackInfo> dev_info;
|
||||||
|
if (num_rows == num_cols && num_rows / batch_size <= 128) {
|
||||||
|
// For small matrices or large batch sizes, we use the batched
|
||||||
|
// interface from cuBlas.
|
||||||
|
auto packed_triangular_factors_ptrs = solver->GetScratchSpace<uint8>(
|
||||||
|
sizeof(Scalar*) * batch_size, "packed_triangular_factors_ptrs",
|
||||||
|
/* on_host */ true);
|
||||||
|
const Scalar** packed_triangular_factors_ptrs_base =
|
||||||
|
reinterpret_cast<const Scalar**>(
|
||||||
|
packed_triangular_factors_ptrs.mutable_data());
|
||||||
|
for (int batch = 0; batch < batch_size; ++batch) {
|
||||||
|
packed_triangular_factors_ptrs_base[batch] =
|
||||||
|
&packed_triangular_factors_transpose_reshaped(batch, 0, 0);
|
||||||
|
}
|
||||||
|
dev_info.push_back(
|
||||||
|
solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
context,
|
||||||
|
solver->GetrfBatched(num_rows, packed_triangular_factors_ptrs_base,
|
||||||
|
num_rows, pivots_mat.data(), &dev_info.back(),
|
||||||
|
batch_size),
|
||||||
|
done);
|
||||||
|
} else {
|
||||||
|
// For small batch sizes we use the non-batched interface from cuSolver,
|
||||||
|
// which is much faster for large matrices.
|
||||||
|
dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
|
||||||
|
for (int batch = 0; batch < batch_size; ++batch) {
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
context,
|
||||||
|
solver->Getrf(
|
||||||
|
num_rows, num_cols,
|
||||||
|
&packed_triangular_factors_transpose_reshaped(batch, 0, 0),
|
||||||
|
num_rows, &pivots_mat(batch, 0), &dev_info.back()(batch)),
|
||||||
|
done);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transpose the result since we had transposed the input.
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
context,
|
||||||
|
DoMatrixTranspose(device, packed_triangular_factors_transpose,
|
||||||
|
packed_triangular_factors),
|
||||||
|
done);
|
||||||
|
|
||||||
|
// Pivots encode the permutation of the rows as a sequences of row swaps.
|
||||||
|
// For each index i, row i is swapped with row pivots[i].
|
||||||
|
int* pivots_ptr = pivots.flat<int>().data();
|
||||||
|
Tidx* permutation_indices_ptr =
|
||||||
|
permutation_indices->template flat<Tidx>().data();
|
||||||
|
CudaLaunchConfig cfgPivots = GetCudaLaunchConfig(batch_size, device);
|
||||||
|
ComputePermutationFromTranspositionsKernel<<<cfgPivots.block_count,
|
||||||
|
cfgPivots.thread_per_block, 0,
|
||||||
|
device.stream()>>>(
|
||||||
|
cfgPivots, num_rows, pivots_ptr, permutation_indices_ptr);
|
||||||
|
|
||||||
|
// Callback for checking info after kernels finish. Also capture the
|
||||||
|
// temporary Tensors/ScratchSpace so they don't get deallocated before the
|
||||||
|
// kernels run.
|
||||||
|
// TODO(rmlarsen): Use move capture once C++14 becomes available.
|
||||||
|
auto info_checker = [context, done, dev_info](
|
||||||
|
const Status& status,
|
||||||
|
const std::vector<HostLapackInfo>& host_infos) {
|
||||||
|
if (!status.ok() && errors::IsInvalidArgument(status) &&
|
||||||
|
!host_infos.empty()) {
|
||||||
|
for (int i = 0; i < host_infos[0].size(); ++i) {
|
||||||
|
// Match the CPU error message for singular matrices. Otherwise
|
||||||
|
// just print the original error message from the status below.
|
||||||
|
OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
|
||||||
|
errors::InvalidArgument("Input is not invertible."),
|
||||||
|
done);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OP_REQUIRES_OK_ASYNC(context, status, done);
|
||||||
|
done();
|
||||||
|
};
|
||||||
|
|
||||||
|
CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
|
||||||
|
std::move(info_checker));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_LU_GPU(type, idx_type) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("Lu") \
|
||||||
|
.Device(DEVICE_GPU) \
|
||||||
|
.TypeConstraint<type>("T") \
|
||||||
|
.TypeConstraint<idx_type>("output_idx_type"), \
|
||||||
|
LuOpGpu<type, idx_type>);
|
||||||
|
|
||||||
|
REGISTER_LU_GPU(float, int32);
|
||||||
|
REGISTER_LU_GPU(double, int32);
|
||||||
|
REGISTER_LU_GPU(complex64, int32);
|
||||||
|
REGISTER_LU_GPU(complex128, int32);
|
||||||
|
|
||||||
|
REGISTER_LU_GPU(float, int64);
|
||||||
|
REGISTER_LU_GPU(double, int64);
|
||||||
|
REGISTER_LU_GPU(complex64, int64);
|
||||||
|
REGISTER_LU_GPU(complex128, int64);
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
@ -109,6 +109,30 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Input is [...,N,N].
|
||||||
|
// First and second outputs are:
|
||||||
|
// [...,N,N]; [...,N].
|
||||||
|
Status LuShapeFn(InferenceContext* c) {
|
||||||
|
ShapeHandle input;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
|
||||||
|
|
||||||
|
DimensionHandle n;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
|
||||||
|
|
||||||
|
ShapeHandle batch_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
|
||||||
|
|
||||||
|
ShapeHandle lu_shape;
|
||||||
|
ShapeHandle p_shape;
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
|
||||||
|
|
||||||
|
c->set_output(0, lu_shape);
|
||||||
|
c->set_output(1, p_shape);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Input is [...,M,N].
|
// Input is [...,M,N].
|
||||||
// First and second outputs are:
|
// First and second outputs are:
|
||||||
// [...,M,M]; [...,M,N], if full_matrices is true,
|
// [...,M,M]; [...,M,N], if full_matrices is true,
|
||||||
@ -289,6 +313,14 @@ REGISTER_OP("SelfAdjointEigV2")
|
|||||||
.Attr("T: {double, float, complex64, complex128}")
|
.Attr("T: {double, float, complex64, complex128}")
|
||||||
.SetShapeFn(SelfAdjointEigV2ShapeFn);
|
.SetShapeFn(SelfAdjointEigV2ShapeFn);
|
||||||
|
|
||||||
|
REGISTER_OP("Lu")
|
||||||
|
.Input("input: T")
|
||||||
|
.Output("lu: T")
|
||||||
|
.Output("p: output_idx_type")
|
||||||
|
.Attr("T: {double, float, complex64, complex128}")
|
||||||
|
.Attr("output_idx_type: {int32, int64} = DT_INT32")
|
||||||
|
.SetShapeFn(LuShapeFn);
|
||||||
|
|
||||||
REGISTER_OP("MatrixSolve")
|
REGISTER_OP("MatrixSolve")
|
||||||
.Input("matrix: T")
|
.Input("matrix: T")
|
||||||
.Input("rhs: T")
|
.Input("rhs: T")
|
||||||
|
@ -274,4 +274,23 @@ TEST(LinalgOpsTest, Svd_ShapeFn) {
|
|||||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
|
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(LinalgOpsTest, Lu_ShapeFn) {
|
||||||
|
ShapeInferenceTestOp op("Lu");
|
||||||
|
INFER_OK(op, "?", "?;?");
|
||||||
|
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
|
||||||
|
INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?,3,4,1,2]");
|
||||||
|
|
||||||
|
INFER_OK(op, "[?,?]", "[d0_0,d0_0];[d0_0]");
|
||||||
|
INFER_OK(op, "[1,?]", "[d0_0,d0_0];[d0_0]");
|
||||||
|
INFER_OK(op, "[?,1]", "[d0_1,d0_1];[d0_1]");
|
||||||
|
|
||||||
|
// Repeat previous block of tests with input rank > 2.
|
||||||
|
INFER_OK(op, "[1,?,3,4,?,?]",
|
||||||
|
"[d0_0,d0_1,d0_2,d0_3,d0_4,d0_4];[d0_0,d0_1,d0_2,d0_3,d0_4]");
|
||||||
|
INFER_OK(op, "[1,?,3,4,1,?]",
|
||||||
|
"[d0_0,d0_1,d0_2,d0_3,d0_4,d0_4];[d0_0,d0_1,d0_2,d0_3,d0_4]");
|
||||||
|
INFER_OK(op, "[1,?,3,4,?,1]",
|
||||||
|
"[d0_0,d0_1,d0_2,d0_3,d0_5,d0_5];[d0_0,d0_1,d0_2,d0_3,d0_5]");
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -1888,6 +1888,22 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "lu_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["lu_op_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:linalg_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//tensorflow/python/ops/linalg",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "manip_ops_test",
|
name = "manip_ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
285
tensorflow/python/kernel_tests/lu_op_test.py
Normal file
285
tensorflow/python/kernel_tests/lu_op_test.py
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.ops.tf.Lu."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import functional_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import benchmark
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class LuOpTest(test.TestCase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def float_types(self):
|
||||||
|
return set((np.float64, np.float32, np.complex64, np.complex128))
|
||||||
|
|
||||||
|
def _verifyLuBase(self, sess, x, lower, upper, perm, verification,
|
||||||
|
output_idx_type):
|
||||||
|
lower_np, upper_np, perm_np, verification_np = sess.run(
|
||||||
|
[lower, upper, perm, verification])
|
||||||
|
|
||||||
|
self.assertAllClose(x, verification_np)
|
||||||
|
self.assertShapeEqual(x, lower)
|
||||||
|
self.assertShapeEqual(x, upper)
|
||||||
|
|
||||||
|
self.assertAllEqual(x.shape[:-1], perm.shape.as_list())
|
||||||
|
|
||||||
|
# Check dtypes are as expected.
|
||||||
|
self.assertEqual(x.dtype, lower_np.dtype)
|
||||||
|
self.assertEqual(x.dtype, upper_np.dtype)
|
||||||
|
self.assertEqual(output_idx_type.as_numpy_dtype, perm_np.dtype)
|
||||||
|
|
||||||
|
# Check that the permutation is valid.
|
||||||
|
if perm_np.shape[-1] > 0:
|
||||||
|
perm_reshaped = np.reshape(perm_np, (-1, perm_np.shape[-1]))
|
||||||
|
for perm_vector in perm_reshaped:
|
||||||
|
self.assertAllClose(np.arange(len(perm_vector)), np.sort(perm_vector))
|
||||||
|
|
||||||
|
def _verifyLu(self, x, output_idx_type=dtypes.int64):
|
||||||
|
# Verify that Px = LU.
|
||||||
|
with self.cached_session(use_gpu=True) as sess:
|
||||||
|
|
||||||
|
lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)
|
||||||
|
|
||||||
|
# Prepare the lower factor of shape num_rows x num_rows
|
||||||
|
lu_shape = np.array(lu.shape.as_list())
|
||||||
|
batch_shape = lu_shape[:-2]
|
||||||
|
num_rows = lu_shape[-2]
|
||||||
|
num_cols = lu_shape[-1]
|
||||||
|
|
||||||
|
lower = array_ops.matrix_band_part(lu, -1, 0)
|
||||||
|
|
||||||
|
if num_rows > num_cols:
|
||||||
|
eye = linalg_ops.eye(
|
||||||
|
num_rows, batch_shape=batch_shape, dtype=lower.dtype)
|
||||||
|
lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
|
||||||
|
elif num_rows < num_cols:
|
||||||
|
lower = lower[..., :num_rows]
|
||||||
|
|
||||||
|
# Fill the diagonal with ones.
|
||||||
|
ones_diag = array_ops.ones(
|
||||||
|
np.append(batch_shape, num_rows), dtype=lower.dtype)
|
||||||
|
lower = array_ops.matrix_set_diag(lower, ones_diag)
|
||||||
|
|
||||||
|
# Prepare the upper factor.
|
||||||
|
upper = array_ops.matrix_band_part(lu, 0, -1)
|
||||||
|
|
||||||
|
verification = math_ops.matmul(lower, upper)
|
||||||
|
|
||||||
|
# Permute the rows of product of the Cholesky factors.
|
||||||
|
if num_rows > 0:
|
||||||
|
# Reshape the product of the triangular factors and permutation indices
|
||||||
|
# to a single batch dimension. This makes it easy to apply
|
||||||
|
# invert_permutation and gather_nd ops.
|
||||||
|
perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
|
||||||
|
verification_reshaped = array_ops.reshape(verification,
|
||||||
|
[-1, num_rows, num_cols])
|
||||||
|
# Invert the permutation in each batch.
|
||||||
|
inv_perm_reshaped = functional_ops.map_fn(array_ops.invert_permutation,
|
||||||
|
perm_reshaped)
|
||||||
|
batch_size = perm_reshaped.shape.as_list()[0]
|
||||||
|
# Prepare the batch indices with the same shape as the permutation.
|
||||||
|
# The corresponding batch index is paired with each of the `num_rows`
|
||||||
|
# permutation indices.
|
||||||
|
batch_indices = math_ops.cast(
|
||||||
|
array_ops.broadcast_to(
|
||||||
|
math_ops.range(batch_size)[:, None], perm_reshaped.shape),
|
||||||
|
dtype=output_idx_type)
|
||||||
|
permuted_verification_reshaped = array_ops.gather_nd(
|
||||||
|
verification_reshaped,
|
||||||
|
array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))
|
||||||
|
|
||||||
|
# Reshape the verification matrix back to the original shape.
|
||||||
|
verification = array_ops.reshape(permuted_verification_reshaped,
|
||||||
|
lu_shape)
|
||||||
|
|
||||||
|
self._verifyLuBase(sess, x, lower, upper, perm, verification,
|
||||||
|
output_idx_type)
|
||||||
|
|
||||||
|
def testBasic(self):
|
||||||
|
data = np.array([[4., -1., 2.], [-1., 6., 0], [10., 0., 5.]])
|
||||||
|
|
||||||
|
for dtype in (np.float32, np.float64):
|
||||||
|
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||||
|
self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)
|
||||||
|
|
||||||
|
for dtype in (np.complex64, np.complex128):
|
||||||
|
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||||
|
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||||
|
complex_data += np.triu(-1j * data, 1).astype(dtype)
|
||||||
|
complex_data += data
|
||||||
|
self._verifyLu(complex_data, output_idx_type=output_idx_type)
|
||||||
|
|
||||||
|
def testPivoting(self):
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
# This matrix triggers partial pivoting because the first diagonal entry
|
||||||
|
# is small.
|
||||||
|
data = np.array([[1e-9, 1., 0.], [1., 0., 0], [0., 1., 5]])
|
||||||
|
self._verifyLu(data.astype(np.float32))
|
||||||
|
|
||||||
|
for dtype in (np.float32, np.float64):
|
||||||
|
self._verifyLu(data.astype(dtype))
|
||||||
|
_, p = linalg_ops.lu(data)
|
||||||
|
p_val = sess.run([p])
|
||||||
|
# Make sure p_val is not the identity permutation.
|
||||||
|
self.assertNotAllClose(np.arange(3), p_val)
|
||||||
|
|
||||||
|
for dtype in (np.complex64, np.complex128):
|
||||||
|
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||||
|
complex_data += np.triu(-1j * data, 1).astype(dtype)
|
||||||
|
complex_data += data
|
||||||
|
self._verifyLu(complex_data)
|
||||||
|
_, p = linalg_ops.lu(data)
|
||||||
|
p_val = sess.run([p])
|
||||||
|
# Make sure p_val is not the identity permutation.
|
||||||
|
self.assertNotAllClose(np.arange(3), p_val)
|
||||||
|
|
||||||
|
def testInvalidMatrix(self):
|
||||||
|
# LU factorization gives an error when the input is singular.
|
||||||
|
# Note: A singular matrix may return without error but it won't be a valid
|
||||||
|
# factorization.
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
for dtype in self.float_types:
|
||||||
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
|
sess.run(
|
||||||
|
linalg_ops.lu(
|
||||||
|
np.array([[1., 2., 3.], [2., 4., 6.], [2., 3., 4.]],
|
||||||
|
dtype=dtype)))
|
||||||
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
|
sess.run(
|
||||||
|
linalg_ops.lu(
|
||||||
|
np.array([[[1., 2., 3.], [2., 4., 6.], [1., 2., 3.]],
|
||||||
|
[[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]],
|
||||||
|
dtype=dtype)))
|
||||||
|
|
||||||
|
def testBatch(self):
|
||||||
|
simple_array = np.array([[[1., -1.], [2., 5.]]]) # shape (1, 2, 2)
|
||||||
|
self._verifyLu(simple_array)
|
||||||
|
self._verifyLu(np.vstack((simple_array, simple_array)))
|
||||||
|
odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
|
||||||
|
self._verifyLu(np.vstack((odd_sized_array, odd_sized_array)))
|
||||||
|
|
||||||
|
batch_size = 200
|
||||||
|
|
||||||
|
# Generate random matrices.
|
||||||
|
np.random.seed(42)
|
||||||
|
matrices = np.random.rand(batch_size, 5, 5)
|
||||||
|
self._verifyLu(matrices)
|
||||||
|
|
||||||
|
# Generate random complex valued matrices.
|
||||||
|
np.random.seed(52)
|
||||||
|
matrices = np.random.rand(batch_size, 5,
|
||||||
|
5) + 1j * np.random.rand(batch_size, 5, 5)
|
||||||
|
self._verifyLu(matrices)
|
||||||
|
|
||||||
|
def testLargeMatrix(self):
|
||||||
|
# Generate random matrices.
|
||||||
|
n = 500
|
||||||
|
np.random.seed(64)
|
||||||
|
data = np.random.rand(n, n)
|
||||||
|
self._verifyLu(data)
|
||||||
|
|
||||||
|
# Generate random complex valued matrices.
|
||||||
|
np.random.seed(129)
|
||||||
|
data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
|
||||||
|
self._verifyLu(data)
|
||||||
|
|
||||||
|
def testEmpty(self):
|
||||||
|
self._verifyLu(np.empty([0, 2, 2]))
|
||||||
|
self._verifyLu(np.empty([2, 0, 0]))
|
||||||
|
|
||||||
|
def testConcurrentExecutesWithoutError(self):
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
matrix1 = random_ops.random_normal([5, 5], seed=42)
|
||||||
|
matrix2 = random_ops.random_normal([5, 5], seed=42)
|
||||||
|
lu1, p1 = linalg_ops.lu(matrix1)
|
||||||
|
lu2, p2 = linalg_ops.lu(matrix2)
|
||||||
|
lu1_val, p1_val, lu2_val, p2_val = sess.run([lu1, p1, lu2, p2])
|
||||||
|
self.assertAllEqual(lu1_val, lu2_val)
|
||||||
|
self.assertAllEqual(p1_val, p2_val)
|
||||||
|
|
||||||
|
|
||||||
|
class LuBenchmark(test.Benchmark):
|
||||||
|
shapes = [
|
||||||
|
(4, 4),
|
||||||
|
(10, 10),
|
||||||
|
(16, 16),
|
||||||
|
(101, 101),
|
||||||
|
(256, 256),
|
||||||
|
(1000, 1000),
|
||||||
|
(1024, 1024),
|
||||||
|
(2048, 2048),
|
||||||
|
(4096, 4096),
|
||||||
|
(513, 2, 2),
|
||||||
|
(513, 8, 8),
|
||||||
|
(513, 256, 256),
|
||||||
|
(4, 513, 2, 2),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _GenerateMatrix(self, shape):
|
||||||
|
batch_shape = shape[:-2]
|
||||||
|
shape = shape[-2:]
|
||||||
|
assert shape[0] == shape[1]
|
||||||
|
n = shape[0]
|
||||||
|
matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
|
||||||
|
np.ones(n).astype(np.float32))
|
||||||
|
return np.tile(matrix, batch_shape + (1, 1))
|
||||||
|
|
||||||
|
def benchmarkLuOp(self):
|
||||||
|
for shape in self.shapes:
|
||||||
|
with ops.Graph().as_default(), \
|
||||||
|
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||||
|
ops.device("/cpu:0"):
|
||||||
|
matrix = variables.Variable(self._GenerateMatrix(shape))
|
||||||
|
lu, p = linalg_ops.lu(matrix)
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.run_op_benchmark(
|
||||||
|
sess,
|
||||||
|
control_flow_ops.group(lu, p),
|
||||||
|
min_iters=25,
|
||||||
|
name="lu_cpu_{shape}".format(shape=shape))
|
||||||
|
|
||||||
|
if test.is_gpu_available(True):
|
||||||
|
with ops.Graph().as_default(), \
|
||||||
|
session.Session(config=benchmark.benchmark_config()) as sess, \
|
||||||
|
ops.device("/device:GPU:0"):
|
||||||
|
matrix = variables.Variable(self._GenerateMatrix(shape))
|
||||||
|
lu, p = linalg_ops.lu(matrix)
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.run_op_benchmark(
|
||||||
|
sess,
|
||||||
|
control_flow_ops.group(lu, p),
|
||||||
|
min_iters=25,
|
||||||
|
name="lu_gpu_{shape}".format(shape=shape))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -44,6 +44,7 @@ einsum = special_math_ops.einsum
|
|||||||
eye = linalg_ops.eye
|
eye = linalg_ops.eye
|
||||||
inv = linalg_ops.matrix_inverse
|
inv = linalg_ops.matrix_inverse
|
||||||
logm = gen_linalg_ops.matrix_logarithm
|
logm = gen_linalg_ops.matrix_logarithm
|
||||||
|
lu = gen_linalg_ops.lu
|
||||||
tf_export('linalg.logm')(logm)
|
tf_export('linalg.logm')(logm)
|
||||||
lstsq = linalg_ops.matrix_solve_ls
|
lstsq = linalg_ops.matrix_solve_ls
|
||||||
norm = linalg_ops.norm
|
norm = linalg_ops.norm
|
||||||
|
@ -132,6 +132,10 @@ tf_module {
|
|||||||
name: "lstsq"
|
name: "lstsq"
|
||||||
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
|
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lu"
|
||||||
|
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "matmul"
|
name: "matmul"
|
||||||
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
||||||
|
@ -132,6 +132,10 @@ tf_module {
|
|||||||
name: "lstsq"
|
name: "lstsq"
|
||||||
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
|
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "lu"
|
||||||
|
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "matmul"
|
name: "matmul"
|
||||||
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user