Add banded triangular solve op.
PiperOrigin-RevId: 317124054
Change-Id: I54f090d7583b21fa18788a2deb02262d9c8231be
parent 18f54c42c6
commit 89b80c5fb9
@ -0,0 +1,4 @@
op {
  graph_op_name: "BandedTriangularSolve"
  visibility: HIDDEN
}
@ -0,0 +1,4 @@
op {
  graph_op_name: "BandedTriangularSolve"
  visibility: HIDDEN
}
@ -3577,6 +3577,7 @@ tf_cc_tests(
cc_library(
    name = "linalg",
    deps = [
        ":banded_triangular_solve_op",
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
@ -3750,6 +3751,12 @@ tf_kernel_library(
    deps = LINALG_DEPS,
)

tf_kernel_library(
    name = "banded_triangular_solve_op",
    prefix = "banded_triangular_solve_op",
    deps = LINALG_DEPS + [":fill_functor"],
)

tf_kernel_library(
    name = "matrix_triangular_solve_op",
    hdrs = ["matrix_triangular_solve_op_impl.h"],
@ -4425,6 +4432,26 @@ tf_cuda_cc_test(
    ],
)

tf_cuda_cc_test(
    name = "banded_triangular_solve_op_test",
    size = "small",
    srcs = ["banded_triangular_solve_op_test.cc"],
    deps = [
        ":banded_triangular_solve_op",
        ":matrix_set_diag_op",
        ":matrix_triangular_solve_op",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_cuda_cc_test(
    name = "matrix_triangular_solve_op_test",
    size = "small",
tensorflow/core/kernels/banded_triangular_solve_op.cc (new file, 293 lines)
@ -0,0 +1,293 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
  return Eigen::numext::conj<Scalar>(scalar);
}

// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
                  bool adjoint, const MatMulBCast& bcast, Tensor* out,
                  int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    int num_bands = in_x.dim_size(1);
    int matrix_size = in_x.dim_size(2);

    for (int64 i = start; i < limit; ++i) {
      const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto output = TensorSliceToEigenMatrix(out, i);
      // Below, we use the standard algorithm for computing a triangular solve,
      // except we band limit it.
      // Given A x = b, where A is lower triangular,
      // x_i = (b_i - sum a_ij * x_j) / a_ii, where the sum is from
      // j = 0 to i - 1.
      //
      // Now, in a banded triangular matrix, when i exceeds the band size,
      // then the sum goes from j = i - band_size to i - 1, since the other
      // elements are zero.
      //
      // Finally, given the band storage format, we'll need to change the
      // indexing.
      if (lower) {
        if (!adjoint) {
          output.row(0) = rhs.row(0) / matrix(0, 0);
          for (int i = 1; i < matrix_size; ++i) {
            if (i < num_bands) {
              output.row(i).noalias() =
                  (rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
                                    output.topRows(i)) /
                  matrix(0, i);
            } else {
              output.row(i).noalias() =
                  (rhs.row(i) -
                   matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(i - (num_bands - 1), num_bands - 1)) /
                  matrix(0, i);
            }
          }
        } else {
          // In the adjoint case, here and below, we now have an upper (lower)
          // triangular matrix, and thus need to work through the other
          // case. We can't simply conjugate `matrix` and reuse the upper
          // (lower) algorithm because the band storage formats for upper and
          // lower triangular matrices are different (in the lower case,
          // entries are padded on the left, and in the upper case they are
          // padded on the right).
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
          for (int i = matrix_size - 1; i >= 0; --i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(j - i, j)) * output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(0, i));
          }
        }
      } else {
        if (!adjoint) {
          output.row(matrix_size - 1) =
              rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
          for (int i = 1; i < matrix_size; ++i) {
            int k = matrix_size - 1 - i;
            if (i < num_bands) {
              output.row(k).noalias() =
                  (rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
                                    .reverse()
                                    .transpose() *
                                    output.bottomRows(i)) /
                  matrix(num_bands - 1, k);
            } else {
              output.row(k).noalias() =
                  (rhs.row(k) -
                   matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
                       output.middleRows(k + 1, num_bands - 1)) /
                  matrix(num_bands - 1, k);
            }
          }
        } else {
          output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
          for (int i = 1; i < matrix_size; ++i) {
            output.row(i).noalias() = rhs.row(i);
            for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
              output.row(i).noalias() -=
                  eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
                  output.row(j);
            }
            output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
          }
        }
      }
    }
  }
};

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;

template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adjoint, bool lower,
                     const MatMulBCast& bcast, Tensor* out) {
    // Number of banded matrix triangular solves, i.e. the batch size.
    const int64 batch_size = bcast.output_batch_size();
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    using Matrix =
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    using ConstMatrixMap = Eigen::Map<const Matrix>;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    // Check the diagonal before doing any solves. It is stored in the first
    // row in the lower case and in the last row otherwise.
    auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
                                 in_x.dim_size(2));
    RealScalar min_abs_pivot;
    if (lower) {
      min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
    } else {
      min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
    }
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));

    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          cost_per_unit,
          [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
            SequentialBandedTriangularSolveKernel<Scalar>::Run(
                in_x, in_y, lower, adjoint, bcast, out, start, limit);
          });
  }
};

template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
 public:
  explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  ~BandedTriangularSolveOpCpu() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    ValidateInputTensors(ctx, in0, in1);

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);  // Band size.
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", lower_, " ", adjoint_));
    out_shape.AddDim(d1);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    LaunchBatchBandedTriangularSolve<Scalar>::Launch(
        ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
        &out_reshaped);
  }

 private:
  void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                            const Tensor& in1) {
    OP_REQUIRES(
        ctx, in0.dims() >= 2,
        errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));

    OP_REQUIRES(
        ctx, in1.dims() >= 2,
        errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
  }
  bool lower_;
  bool adjoint_;
};

#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE)          \
  REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve")     \
                              .Device(DEVICE_CPU)           \
                              .TypeConstraint<TYPE>("T"),   \
                          BandedTriangularSolveOpCpu<TYPE>);

REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);

}  // namespace tensorflow
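For reference, the band-limited forward substitution that the kernel comments above describe can be written directly in NumPy. The sketch below is illustrative only: it assumes NumPy, covers just the lower-triangular, non-adjoint case, and the helper name `banded_lower_solve_reference` is made up here rather than part of this change.

import numpy as np

def banded_lower_solve_reference(bands, rhs):
  """Band-limited forward substitution. bands is [K, M]; rhs is [M, N]."""
  num_bands, n = bands.shape
  x = np.zeros_like(rhs)
  # Row 0 only involves the diagonal entry, stored in bands[0, 0].
  x[0] = rhs[0] / bands[0, 0]
  for i in range(1, n):
    # Only the previous num_bands - 1 solved rows can contribute, because
    # A[i, j] == 0 whenever i - j >= num_bands; A[i, j] is stored at
    # bands[i - j, i] (subdiagonals are left-padded).
    acc = rhs[i].copy()
    for j in range(max(0, i - (num_bands - 1)), i):
      acc = acc - bands[i - j, i] * x[j]
    x[i] = acc / bands[0, i]
  return x

# The 3x3 example from the op's Python docstring: two bands of
# [[2, 0, 0], [2, 3, 0], [0, 3, 4]]; bands[1, 0] is padding and is ignored.
bands = np.array([[2., 3., 4.], [1., 2., 3.]])
rhs = np.ones((3, 1))
x = banded_lower_solve_reference(bands, rhs)
dense = np.array([[2., 0., 0.], [2., 3., 0.], [0., 3., 4.]])
np.testing.assert_allclose(dense @ x, rhs)  # x == [[0.5], [0.], [0.25]]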
tensorflow/core/kernels/banded_triangular_solve_op_test.cc (new file, 180 lines)
@ -0,0 +1,180 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

Node* SetDiag(int num_bands, Graph* g, Node* bands, Node* triangular) {
  Node* ret;
  Tensor bandwidth(DT_INT32, TensorShape({2}));
  bandwidth.flat<int32>()(0) = -(num_bands - 1);
  bandwidth.flat<int32>()(1) = 0;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixSetDiagV3")
                  .Input(triangular)
                  .Input(bands)
                  .Input(test::graph::Constant(g, bandwidth))
                  .Attr("align", "RIGHT_LEFT")
                  .Finalize(g, &ret));
  return ret;
}

Node* BandedTriangularSolve(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BandedTriangularSolve")
                  .Input(in0)
                  .Input(in1)
                  .Attr("lower", true)
                  .Attr("adjoint", false)
                  .Finalize(g, &ret));
  return ret;
}

Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
                  .Input(in0)
                  .Input(in1)
                  .Attr("lower", true)
                  .Attr("adjoint", false)
                  .Finalize(g, &ret));
  return ret;
}

template <typename T>
static Graph* BandedTriangularSolve(int64 num_bands, int64 n, int64 m,
                                    bool use_banded_solver, DataType type) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor in0(type, TensorShape({num_bands, n}));
  // Set diagonal to nonzero to guarantee invertibility.
  in0.flat<T>().setRandom();
  in0.flat<T>() =
      in0.flat<T>().abs() + in0.flat<T>().constant(static_cast<T>(0.5));
  Tensor in1(type, TensorShape({n, m}));
  in1.flat<T>().setRandom();
  if (use_banded_solver) {
    BandedTriangularSolve(g, test::graph::Constant(g, in0),
                          test::graph::Constant(g, in1));
  } else {
    // Create a zero tensor.
    Tensor in2(type, TensorShape({n, n}));
    in2.flat<T>().setZero();
    Node* triangular_matrix =
        SetDiag(num_bands, g, test::graph::Constant(g, in0),
                test::graph::Constant(g, in2));
    MatrixTriangularSolve(g, triangular_matrix, test::graph::Constant(g, in1));
  }
  return g;
}

// Macro argument names: ----------------------------------------------------//
//   K: Number of bands
//   N: Inner dimension of LHS, inner dimension of RHS.
//   M: Outer dimension of RHS
//   BS: boolean indicating whether to use the banded solver
//   T: C++ type of scalars (e.g. float, std::complex)
//   TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128)
#define BM_BandedTriangularSolveDev(K, N, M, BS, T, TT, D)                     \
  static void BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT(        \
      int iters) {                                                             \
    testing::UseRealTime();                                                    \
    testing::ItemsProcessed(static_cast<int64>(iters) * K * N + N * M);        \
    test::Benchmark(#D, BandedTriangularSolve<T>(K, N, M, BS, TT)).Run(iters); \
  }                                                                            \
  BENCHMARK(BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT);

#define BM_BandedTriangularSolve(K, N, M, BS, D)                  \
  BM_BandedTriangularSolveDev(K, N, M, BS, float, DT_FLOAT, D);   \
  BM_BandedTriangularSolveDev(K, N, M, BS, double, DT_DOUBLE, D);

// Small number of bands, few rhs
BM_BandedTriangularSolve(2, 32, 1, true, cpu);
BM_BandedTriangularSolve(2, 32, 1, false, cpu);
BM_BandedTriangularSolve(4, 32, 1, true, cpu);
BM_BandedTriangularSolve(4, 32, 1, false, cpu);
BM_BandedTriangularSolve(8, 32, 1, true, cpu);
BM_BandedTriangularSolve(8, 32, 1, false, cpu);
BM_BandedTriangularSolve(16, 32, 1, true, cpu);
BM_BandedTriangularSolve(16, 32, 1, false, cpu);
BM_BandedTriangularSolve(2, 128, 1, true, cpu);
BM_BandedTriangularSolve(2, 128, 1, false, cpu);
BM_BandedTriangularSolve(4, 128, 1, true, cpu);
BM_BandedTriangularSolve(4, 128, 1, false, cpu);
BM_BandedTriangularSolve(8, 128, 1, true, cpu);
BM_BandedTriangularSolve(8, 128, 1, false, cpu);
BM_BandedTriangularSolve(16, 128, 1, true, cpu);
BM_BandedTriangularSolve(16, 128, 1, false, cpu);
BM_BandedTriangularSolve(2, 512, 1, true, cpu);
BM_BandedTriangularSolve(2, 512, 1, false, cpu);
BM_BandedTriangularSolve(4, 512, 1, true, cpu);
BM_BandedTriangularSolve(4, 512, 1, false, cpu);
BM_BandedTriangularSolve(8, 512, 1, true, cpu);
BM_BandedTriangularSolve(8, 512, 1, false, cpu);
BM_BandedTriangularSolve(16, 512, 1, true, cpu);
BM_BandedTriangularSolve(16, 512, 1, false, cpu);

// Larger # rhs
BM_BandedTriangularSolve(2, 32, 32, true, cpu);
BM_BandedTriangularSolve(2, 32, 32, false, cpu);
BM_BandedTriangularSolve(4, 32, 32, true, cpu);
BM_BandedTriangularSolve(4, 32, 32, false, cpu);
BM_BandedTriangularSolve(8, 32, 32, true, cpu);
BM_BandedTriangularSolve(8, 32, 32, false, cpu);
BM_BandedTriangularSolve(16, 32, 32, true, cpu);
BM_BandedTriangularSolve(16, 32, 32, false, cpu);
BM_BandedTriangularSolve(2, 128, 128, true, cpu);
BM_BandedTriangularSolve(2, 128, 128, false, cpu);
BM_BandedTriangularSolve(4, 128, 128, true, cpu);
BM_BandedTriangularSolve(4, 128, 128, false, cpu);
BM_BandedTriangularSolve(8, 128, 128, true, cpu);
BM_BandedTriangularSolve(8, 128, 128, false, cpu);
BM_BandedTriangularSolve(16, 128, 128, true, cpu);
BM_BandedTriangularSolve(16, 128, 128, false, cpu);
BM_BandedTriangularSolve(2, 512, 512, true, cpu);
BM_BandedTriangularSolve(2, 512, 512, false, cpu);
BM_BandedTriangularSolve(4, 512, 512, true, cpu);
BM_BandedTriangularSolve(4, 512, 512, false, cpu);
BM_BandedTriangularSolve(8, 512, 512, true, cpu);
BM_BandedTriangularSolve(8, 512, 512, false, cpu);
BM_BandedTriangularSolve(16, 512, 512, true, cpu);
BM_BandedTriangularSolve(16, 512, 512, false, cpu);

BM_BandedTriangularSolve(2, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(2, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(4, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(4, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(8, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(8, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(16, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(16, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(32, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(32, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(64, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(64, 2048, 2048, false, cpu);

}  // namespace
}  // namespace tensorflow
@ -47,6 +47,49 @@ Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
  return Status::OK();
}

// The first input is [...,K,M] and second input is [...,M,N].
Status BandedTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;

  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Check K > 0.
  DimensionHandle num_bands = c->Dim(lhs, -2);
  DimensionHandle m = c->Dim(lhs, -1);
  if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) {
    return errors::InvalidArgument("Number of bands must be positive, but is ",
                                   c->Value(num_bands));
  }
  if (c->ValueKnown(num_bands) && c->ValueKnown(m) &&
      c->Value(num_bands) > c->Value(m)) {
    return errors::InvalidArgument("Number of bands ", c->Value(num_bands),
                                   " cannot exceed the size of the matrix ",
                                   c->Value(m));
  }

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));

  // lhs and rhs must have the same value for M to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m));

  // Build final shape (batch_shape + m + n) in <out>.
  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));

  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
// Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
@ -446,6 +489,17 @@ REGISTER_OP("MatrixSolve")
      return MatrixSolveShapeFn(c, true /* square (*/);
    });

REGISTER_OP("BandedTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return BandedTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
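The shape function above enforces that the bands input is [..., K, M] with K > 0 and K <= M, that M matches the second-to-last dimension of rhs, and that the batch dimensions broadcast. A small Python mirror of those rules, illustrative only (the helper name and the use of NumPy's broadcast_shapes are assumptions, not part of this change):

import numpy as np

def banded_triangular_solve_output_shape(bands_shape, rhs_shape):
  """Mirrors BandedTriangularSolveShapeFn: bands is [..., K, M], rhs is [..., M, N]."""
  *bands_batch, k, m = bands_shape
  *rhs_batch, m_rhs, n = rhs_shape
  if k <= 0:
    raise ValueError("Number of bands must be positive, but is %d" % k)
  if k > m:
    raise ValueError("Number of bands %d cannot exceed matrix size %d" % (k, m))
  if m != m_rhs:
    raise ValueError("Inner dimensions must match: %d vs. %d" % (m, m_rhs))
  batch = np.broadcast_shapes(tuple(bands_batch), tuple(rhs_batch))
  return batch + (m, n)

assert banded_triangular_solve_output_shape((2, 3, 10), (3, 1, 10, 4)) == (3, 2, 10, 4)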
@ -762,6 +762,17 @@ cuda_py_test(
    ],
)

cuda_py_test(
    name = "banded_triangular_solve_op_test",
    size = "small",
    srcs = ["banded_triangular_solve_op_test.py"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:linalg_ops",
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "matrix_triangular_solve_op_test",
    size = "medium",
@ -0,0 +1,232 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg_ops.banded_triangular_solve."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test


class BandedTriangularSolveOpTest(test.TestCase):

  def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
    for lower in (False,):
      for adjoint in (False, True):
        for use_placeholder in True, False:
          self._verifySolve(
              x,
              y,
              lower=lower,
              adjoint=adjoint,
              batch_dims=batch_dims,
              use_placeholder=use_placeholder,
              dtypes=dtypes)

  def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
    self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)

  def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
    self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)

  def _verifySolve(self,
                   x,
                   y,
                   lower=True,
                   adjoint=False,
                   batch_dims=None,
                   use_placeholder=False,
                   dtypes=(np.float32, np.float64)):
    for np_type in dtypes:
      a = x.astype(np_type)
      b = y.astype(np_type)

      # Now we need to convert a to a dense triangular matrix.
      def make_diags(diags, lower=True):
        n = len(diags[0])
        a = np.zeros(n * n, dtype=diags.dtype)
        if lower:
          for i, diag in enumerate(diags):
            a[n * i:n * n:n + 1] = diag[i:]
        else:
          diags_flip = np.flip(diags, 0)
          for i, diag in enumerate(diags_flip):
            a[i:(n - i) * n:n + 1] = diag[:(n - i)]
        return a.reshape(n, n)

      # For numpy.solve we have to explicitly zero out the strictly
      # upper or lower triangle.
      if a.size > 0:
        a_np = make_diags(a, lower=lower)
      else:
        a_np = a
      if adjoint:
        a_np = np.conj(np.transpose(a_np))

      if batch_dims is not None:
        a = np.tile(a, batch_dims + [1, 1])
        a_np = np.tile(a_np, batch_dims + [1, 1])
        b = np.tile(b, batch_dims + [1, 1])

      with self.cached_session(use_gpu=True):
        a_tf = a
        b_tf = b
        if use_placeholder:
          a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
          b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
        tf_ans = linalg_ops.banded_triangular_solve(
            a_tf, b_tf, lower=lower, adjoint=adjoint)
        tf_val = self.evaluate(tf_ans)
        np_ans = np.linalg.solve(a_np, b)
        self.assertEqual(np_ans.shape, tf_val.shape)
        self.assertAllClose(np_ans, tf_val)

  @test_util.run_deprecated_v1
  def testSolve(self):
    # 1x1 matrix, single rhs.
    matrix = np.array([[0.1]])
    rhs0 = np.array([[1.]])
    self._verifySolveAllWaysReal(matrix, rhs0)
    # 2x2 matrix with 2 bands, single right-hand side.
    # Corresponds to the lower triangular
    # [[1., 0.], [3., 4.]]
    # and upper triangular
    # [[2., 1.], [0., 3.]]
    matrix = np.array([[1., 4.], [2., 3.]])
    rhs0 = np.array([[1.], [1.]])
    self._verifySolveAllWaysReal(matrix, rhs0)
    # 2x2 matrix with 2 bands, 3 right-hand sides.
    rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
    self._verifySolveAllWaysReal(matrix, rhs1)
    # 4x4 matrix with 2 bands, 3 right-hand sides.
    # Corresponds to the lower triangular
    # [[1., 0., 0., 0.],
    #  [-1., 2., 0., 0.],
    #  [0., -2., 3., 0.],
    #  [0., 0., -3., 4.]]
    # and upper triangular
    # [[1., 1., 0., 0.],
    #  [0., -1., 2., 0.],
    #  [0., 0., -2., 3.],
    #  [0., 0., 0., -3.]]
    matrix = np.array([[1., 2., 3., 4.], [1., -1., -2., -3.]])
    rhs0 = np.array([[1., 0., 1.], [0., 1., 1.], [-1., 2., 1.], [0., -1., -1.]])
    self._verifySolveAllWaysReal(matrix, rhs0)

  def testSolveBandSizeSmaller(self):
    rhs0 = np.random.randn(6, 4)

    # 6 x 6 matrix with 2 bands. Ensure all non-zero entries.
    matrix = 2. * np.random.uniform(size=[3, 6]) + 1.
    self._verifySolveAllWaysReal(matrix, rhs0)

    # 6 x 6 matrix with 3 bands. Ensure all non-zero entries.
    matrix = 2. * np.random.uniform(size=[3, 6]) + 1.
    self._verifySolveAllWaysReal(matrix, rhs0)

  @test_util.run_deprecated_v1
  def testSolveComplex(self):
    if test.is_built_with_rocm():
      self.skipTest("ROCm does not support BLAS operations for complex types")
    # 1x1 matrix, single rhs.
    matrix = np.array([[0.1 + 1j * 0.1]])
    rhs0 = np.array([[1. + 1j]])
    self._verifySolveAllWaysComplex(matrix, rhs0)
    # 2x2 matrix with 2 bands, single right-hand side.
    # Corresponds to
    # [[1. + 1j, 0.], [4 + 1j, 2 + 1j]]
    matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
    matrix += 1j * matrix
    rhs0 = np.array([[1.], [1.]]).astype(np.complex64)
    rhs0 += 1j * rhs0
    self._verifySolveAllWaysComplex(matrix, rhs0)
    # 2x2 matrix with 2 bands, 3 right-hand sides.
    rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
    rhs1 += 1j * rhs1
    self._verifySolveAllWaysComplex(matrix, rhs1)

  @test_util.run_deprecated_v1
  def testSolveBatch(self):
    matrix = np.array([[1., 2.], [3., 4.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
    # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])

    matrix = np.array([[1., 2., 3., 4.], [-1., -2., -3., -4.],
                       [-1., 1., 2., 3.]])
    rhs = np.array([[-1., 2.], [1., 1.], [0., 1.], [2., 3.]])
    # Batch of 2x3x4x4 matrices with 3 bands, 2x3x4x2 right-hand sides.
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
    # Batch of 3x2x4x4 matrices with 3 bands, 3x2x4x2 right-hand sides.
    self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])

  @test_util.run_deprecated_v1
  def testSolveBatchComplex(self):
    if test.is_built_with_rocm():
      self.skipTest("ROCm does not support BLAS operations for complex types")
    matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
    matrix += 1j * matrix
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
    rhs += 1j * rhs
    # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
    self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[2, 3])
    # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
    self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[3, 2])

  @test_util.run_deprecated_v1
  def testWrongDimensions(self):
    # The matrix should have the same number of rows as the
    # right-hand sides.
    matrix = np.array([[1., 1.], [1., 1.]])
    rhs = np.array([[1., 0.]])
    with self.cached_session(use_gpu=True):
      with self.assertRaises(ValueError):
        self._verifySolve(matrix, rhs)
      with self.assertRaises(ValueError):
        self._verifySolve(matrix, rhs, batch_dims=[2, 3])

    # Number of bands exceeds the dimension of the matrix.
    matrix = np.ones((6, 4))
    rhs = np.ones((4, 2))
    with self.cached_session(use_gpu=True):
      with self.assertRaises(ValueError):
        self._verifySolve(matrix, rhs)
      with self.assertRaises(ValueError):
        self._verifySolve(matrix, rhs, batch_dims=[2, 3])

  @test_util.run_deprecated_v1
  @test_util.disable_xla("XLA cannot throw assertion errors during a kernel.")
  def testNotInvertible(self):
    # The input should be invertible.
    # The matrix is singular because it has a zero on the diagonal.
    # FIXME(rmlarsen): The GPU kernel does not check for singularity.
    singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
    with self.cached_session():
      with self.assertRaisesOpError("Input matrix is not invertible."):
        self._verifySolve(singular_matrix, singular_matrix)
      with self.assertRaisesOpError("Input matrix is not invertible."):
        self._verifySolve(singular_matrix, singular_matrix, batch_dims=[2, 3])


if __name__ == "__main__":
  test.main()
@ -132,6 +132,44 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
  return Test


def _GetBandedTriangularSolveGradientTest(
    functor_,
    dtype_,
    shape_,
    float32_tol_fudge=1.0,  # pylint: disable=redefined-outer-name
    **kwargs_):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def Test(self):
    n = shape_[-1]

    np.random.seed(1)
    # Make sure the matrix is invertible.
    a_np = np.random.uniform(low=1.0, high=2.0, size=shape_).astype(dtype_)
    a = constant_op.constant(a_np)

    b_np = np.random.uniform(low=-1.0, high=1.0, size=[n, n]).astype(dtype_)
    b = constant_op.constant(b_np)

    epsilon = np.finfo(dtype_).eps
    delta = epsilon**(1.0 / 3.0)
    # Tolerance obtained by looking at actual differences using
    # np.linalg.norm(theoretical - numerical, np.inf) on a -mavx build.
    tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05

    # Check the gradient w.r.t. the left argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda x: functor_(x, b, **kwargs_), [a], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

    # Check the gradient w.r.t. the right argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda y: functor_(a, y, **kwargs_), [b], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test


if __name__ == '__main__':
  # Tests for gradients of binary matrix operations.
  for dtype in np.float32, np.float64:
@ -166,6 +204,20 @@ if __name__ == '__main__':
                adjoint=adjoint,
                lower=lower))

            band_shape = extra + (size // 2 + 1, size)
            name = '%s_%s_adj_%s_low_%s' % (dtype.__name__, '_'.join(
                map(str, band_shape)), str(adjoint), lower)
            _AddTest(
                MatrixBinaryFunctorGradientTest,
                'BandedTriangularSolveGradient', name,
                _GetBandedTriangularSolveGradientTest(
                    linalg_ops.banded_triangular_solve,
                    dtype,
                    band_shape,
                    float32_tol_fudge=4.0,
                    adjoint=adjoint,
                    lower=lower))

  # Tests for gradients of unary matrix operations.
  for dtype in np.float32, np.float64:
    for size in 2, 5, 10:
@ -340,6 +340,102 @@ def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))


@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
    bands,
    rhs,
    lower=True,
    adjoint=False,  # pylint: disable=redefined-outer-name
    name=None):
  r"""Solve triangular systems of equations with a banded solver.

  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
  `K` subdiagonals (when `lower` is `True`) are stored.

  This operator broadcasts the batch dimensions of `bands` and the batch
  dimensions of `rhs`.

  Examples:

  Storing 2 bands of a 3x3 matrix.
  Note that the first element in the second row is ignored due to
  the 'LEFT_RIGHT' padding.

  >>> x = [[2., 3., 4.], [1., 2., 3.]]
  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
  >>> y = tf.zeros([3, 3])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
  >>> z
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[2., 0., 0.],
         [2., 3., 0.],
         [0., 3., 4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
  >>> soln
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[0.5 ],
         [0. ],
         [0.25]], dtype=float32)>
  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True
  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Storing 2 bands (the diagonal and one superdiagonal) of a 4x4 matrix.
  Because of the 'LEFT_RIGHT' padding the last element of the first row is
  ignored.

  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
  >>> y = tf.zeros([4, 4])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
  >>> z
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[-1., 2., 0., 0.],
         [ 0., -2., 3., 0.],
         [ 0., 0., -3., 4.],
         [ 0., 0., -0., -4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
  >>> soln
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[-4. ],
         [-1.5 ],
         [-0.6666667],
         [-0.25 ]], dtype=float32)>
  >>> are_equal = (soln == tf.linalg.triangular_solve(
  ...   z, tf.ones([4, 1]), lower=False))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Args:
    bands: A `Tensor` describing the bands of the left hand side, with shape
      `[..., K, M]`. The `K` rows correspond to the main diagonal through the
      `K - 1`-th subdiagonal (the diagonal is the top row) when `lower` is
      `True`, and to the `K - 1`-th superdiagonal through the main diagonal
      (the diagonal is the bottom row) when `lower` is `False`. The bands are
      stored with 'LEFT_RIGHT' alignment, where the superdiagonals are padded
      on the right and the subdiagonals are padded on the left. This is the
      alignment cuSPARSE uses. See `tf.linalg.set_diag` for more details.
    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
      `bands`. Note that if the shape of `rhs` and/or `bands` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      `bands` represents a lower or upper triangular matrix.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating
      whether to solve with the matrix's block-wise adjoint.
    name: A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
  """
  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
    return gen_linalg_ops.banded_triangular_solve(
        bands, rhs, lower=lower, adjoint=adjoint)


@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
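The docstring above states that the batch dimensions of `bands` and `rhs` are broadcast against each other, but the doctests only show the unbatched case. A minimal sketch of the broadcasting behaviour, assuming eager TensorFlow 2.x (the shapes and values here are made up for illustration, not taken from this change):

import tensorflow as tf

# One set of bands (a single 3x3 lower-triangular matrix) solved against a
# batch of 4 right-hand sides: bands broadcasts from [2, 3] to [4, 2, 3].
bands = tf.constant([[2., 3., 4.], [1., 2., 3.]])
rhs = tf.ones([4, 3, 1])
x = tf.linalg.banded_triangular_solve(bands, rhs)
print(x.shape)  # (4, 3, 1)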
@ -607,6 +607,39 @@ def _MatrixSolveLsGrad(op, grad):
                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
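The gradient registered above follows the usual triangular-solve identities: for C = A^{-1} B with upstream gradient dC, dB = A^{-H} dC and dA = -A^{-H} dC C^H, after which dA is restricted to the stored band. A hedged NumPy finite-difference check of the dB rule on a small dense lower-triangular matrix (illustrative only; not part of this change):

import numpy as np

rng = np.random.default_rng(0)
n = 4
# Well-conditioned lower-triangular A, a small rhs B, and an upstream gradient g.
a = np.tril(rng.uniform(1.0, 2.0, size=(n, n)))
b = rng.normal(size=(n, 2))
g = rng.normal(size=(n, 2))

# Analytic rule used by the gradient above: dL/dB = A^{-H} dL/dC.
grad_b = np.linalg.solve(a.conj().T, g)

# Central finite differences of L = sum(g * solve(a, b)) with respect to b.
eps = 1e-6
numeric = np.zeros_like(b)
for idx in np.ndindex(*b.shape):
  db = np.zeros_like(b)
  db[idx] = eps
  numeric[idx] = (np.sum(g * np.linalg.solve(a, b + db)) -
                  np.sum(g * np.linalg.solve(a, b - db))) / (2 * eps)

np.testing.assert_allclose(numeric, grad_b, rtol=1e-4, atol=1e-6)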
@ -284,6 +284,10 @@ tf_module {
    name: "AvgPoolGrad"
    argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
  }
  member_method {
    name: "BandedTriangularSolve"
    argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
  }
  member_method {
    name: "Barrier"
    argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "
@ -96,6 +96,10 @@ tf_module {
    name: "band_part"
    argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "banded_triangular_solve"
    argspec: "args=[\'bands\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
  }
  member_method {
    name: "cholesky"
    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -284,6 +284,10 @@ tf_module {
    name: "AvgPoolGrad"
    argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
  }
  member_method {
    name: "BandedTriangularSolve"
    argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
  }
  member_method {
    name: "Barrier"
    argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "