Add banded triangular solve op.

PiperOrigin-RevId: 317124054
Change-Id: I54f090d7583b21fa18788a2deb02262d9c8231be
This commit is contained in:
Srinivas Vasudevan 2020-06-18 09:59:14 -07:00 committed by TensorFlower Gardener
parent 18f54c42c6
commit 89b80c5fb9
14 changed files with 998 additions and 0 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BandedTriangularSolve"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BandedTriangularSolve"
visibility: HIDDEN
}

View File

@ -3577,6 +3577,7 @@ tf_cc_tests(
cc_library( cc_library(
name = "linalg", name = "linalg",
deps = [ deps = [
":banded_triangular_solve_op",
":cholesky_grad", ":cholesky_grad",
":cholesky_op", ":cholesky_op",
":determinant_op", ":determinant_op",
@ -3750,6 +3751,12 @@ tf_kernel_library(
deps = LINALG_DEPS, deps = LINALG_DEPS,
) )
tf_kernel_library(
name = "banded_triangular_solve_op",
prefix = "banded_triangular_solve_op",
deps = LINALG_DEPS + [":fill_functor"],
)
tf_kernel_library( tf_kernel_library(
name = "matrix_triangular_solve_op", name = "matrix_triangular_solve_op",
hdrs = ["matrix_triangular_solve_op_impl.h"], hdrs = ["matrix_triangular_solve_op_impl.h"],
@ -4425,6 +4432,26 @@ tf_cuda_cc_test(
], ],
) )
tf_cuda_cc_test(
name = "banded_triangular_solve_op_test",
size = "small",
srcs = ["banded_triangular_solve_op_test.cc"],
deps = [
":banded_triangular_solve_op",
":matrix_set_diag_op",
":matrix_triangular_solve_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test( tf_cuda_cc_test(
name = "matrix_triangular_solve_op_test", name = "matrix_triangular_solve_op_test",
size = "small", size = "small",

View File

@ -0,0 +1,293 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Scalar>
Scalar eigen_conj(const Scalar& scalar) {
return Eigen::numext::conj<Scalar>(scalar);
}
// Sequential batch matrix triangular solve kernel that calls Eigen's
// matrix triangular solve.
template <typename Scalar>
struct SequentialBandedTriangularSolveKernel {
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
int slice) {
return ConstMatrixMap(
t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
t.dim_size(1), t.dim_size(2));
}
static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
return MatrixMap(
t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
t->dim_size(1), t->dim_size(2));
}
static void Run(const Tensor& in_x, const Tensor& in_y, bool lower,
bool adjoint, const MatMulBCast& bcast, Tensor* out,
int start, int limit) {
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& x_batch_indices = bcast.x_batch_indices();
const auto& y_batch_indices = bcast.y_batch_indices();
int num_bands = in_x.dim_size(1);
int matrix_size = in_x.dim_size(2);
for (int64 i = start; i < limit; ++i) {
const int64 x_batch_index = should_bcast ? x_batch_indices[i] : i;
const int64 y_batch_index = should_bcast ? y_batch_indices[i] : i;
auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
auto output = TensorSliceToEigenMatrix(out, i);
// Below, we use the standard algorithm for computing a triangular solve,
// except we band limit it.
// Given A x = b, where A is lower triangular,
// x_i = (b_i - sum a_ij * x_j) / a_ii, where the sum is from
// j = 0 to i - 1.
//
// Now, in a banded triangular matrix, when i exceeds the band size,
// then the sum goes from j = i - band_size to i - 1, since the other
// elements are zero.
//
// Finally, given the band storage format, we'll need to change the
// indexing.
if (lower) {
if (!adjoint) {
output.row(0) = rhs.row(0) / matrix(0, 0);
for (int i = 1; i < matrix_size; ++i) {
if (i < num_bands) {
output.row(i).noalias() =
(rhs.row(i) - matrix.block(1, i, i, 1).reverse().transpose() *
output.topRows(i)) /
matrix(0, i);
} else {
output.row(i).noalias() =
(rhs.row(i) -
matrix.block(1, i, num_bands - 1, 1).reverse().transpose() *
output.middleRows(i - (num_bands - 1), num_bands - 1)) /
matrix(0, i);
}
}
} else {
// In the adjoint case, here and below, we now have an upper (lower)
// triangular matrix, and thus need to work through with the other
// case. We can't simply conjugate `matrix` and use the upper (lower)
// algorithm because the band storage format for upper and lower
// triangular matrices are different (in the lower case, we pad
// entries on the left, and in the upper case we pad entries on the
// right.
output.row(matrix_size - 1) =
rhs.row(matrix_size - 1) / eigen_conj(matrix(0, matrix_size - 1));
for (int i = matrix_size - 1; i >= 0; --i) {
output.row(i).noalias() = rhs.row(i);
for (int j = i + 1; j < std::min(matrix_size, i + num_bands); ++j) {
output.row(i).noalias() -=
eigen_conj(matrix(j - i, j)) * output.row(j);
}
output.row(i) /= eigen_conj(matrix(0, i));
}
}
} else {
if (!adjoint) {
output.row(matrix_size - 1) =
rhs.row(matrix_size - 1) / matrix(num_bands - 1, matrix_size - 1);
for (int i = 1; i < matrix_size; ++i) {
int k = matrix_size - 1 - i;
if (i < num_bands) {
output.row(k).noalias() =
(rhs.row(k) - matrix.block(num_bands - 1 - i, k, i, 1)
.reverse()
.transpose() *
output.bottomRows(i)) /
matrix(num_bands - 1, k);
} else {
output.row(k).noalias() =
(rhs.row(k) -
matrix.block(0, k, num_bands - 1, 1).reverse().transpose() *
output.middleRows(k + 1, num_bands - 1)) /
matrix(num_bands - 1, k);
}
}
} else {
output.row(0) = rhs.row(0) / eigen_conj(matrix(num_bands - 1, 0));
for (int i = 1; i < matrix_size; ++i) {
output.row(i).noalias() = rhs.row(i);
for (int j = std::max(0, i - (num_bands - 1)); j < i; ++j) {
output.row(i).noalias() -=
eigen_conj(matrix(num_bands - 1 - (i - j), j)) *
output.row(j);
}
output.row(i) /= eigen_conj(matrix(num_bands - 1, i));
}
}
}
}
}
};
template <typename Scalar>
struct LaunchBatchBandedTriangularSolve;
template <typename Scalar>
struct LaunchBatchBandedTriangularSolve {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adjoint, bool lower,
const MatMulBCast& bcast, Tensor* out) {
// Number of banded matrix triangular solves i.e. size of the batch.
const int64 batch_size = bcast.output_batch_size();
const int64 cost_per_unit =
in_x.dim_size(1) * in_x.dim_size(2) * in_y.dim_size(2);
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
// Check diagonal before doing any solves. This is the first row in the
// lower case and else is the last row.
auto matrix = ConstMatrixMap(in_x.flat<Scalar>().data(), in_x.dim_size(1),
in_x.dim_size(2));
RealScalar min_abs_pivot;
if (lower) {
min_abs_pivot = matrix.row(0).cwiseAbs().minCoeff();
} else {
min_abs_pivot = matrix.row(in_x.dim_size(1) - 1).cwiseAbs().minCoeff();
}
OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
cost_per_unit,
[&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
SequentialBandedTriangularSolveKernel<Scalar>::Run(
in_x, in_y, lower, adjoint, bcast, out, start, limit);
});
}
};
template <typename Scalar>
class BandedTriangularSolveOpCpu : public OpKernel {
public:
explicit BandedTriangularSolveOpCpu(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
}
~BandedTriangularSolveOpCpu() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
ValidateInputTensors(ctx, in0, in1);
MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
OP_REQUIRES(
ctx, bcast.IsValid(),
errors::InvalidArgument(
"In[0] and In[1] must have compatible batch dimensions: ",
in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));
TensorShape out_shape = bcast.output_batch_shape();
auto batch_size = bcast.output_batch_size();
auto d0 = in0.dim_size(in0.dims() - 2); // Band size.
auto d1 = in0.dim_size(in0.dims() - 1);
Tensor in0_reshaped;
OP_REQUIRES(
ctx,
in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
errors::Internal("Failed to reshape In[0] from ",
in0.shape().DebugString()));
auto d2 = in1.dim_size(in1.dims() - 2);
auto d3 = in1.dim_size(in1.dims() - 1);
Tensor in1_reshaped;
OP_REQUIRES(
ctx,
in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
errors::Internal("Failed to reshape In[1] from ",
in1.shape().DebugString()));
OP_REQUIRES(ctx, d1 == d2,
errors::InvalidArgument(
"In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
in0.shape().DebugString(), " ", in1.shape().DebugString(),
" ", lower_, " ", adjoint_));
out_shape.AddDim(d1);
out_shape.AddDim(d3);
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
return;
}
Tensor out_reshaped;
OP_REQUIRES(ctx,
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d1, d3})),
errors::Internal("Failed to reshape output from ",
out->shape().DebugString()));
LaunchBatchBandedTriangularSolve<Scalar>::Launch(
ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast,
&out_reshaped);
}
private:
void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
const Tensor& in1) {
OP_REQUIRES(
ctx, in0.dims() >= 2,
errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()));
OP_REQUIRES(
ctx, in1.dims() >= 2,
errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()));
}
bool lower_;
bool adjoint_;
};
#define REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(TYPE) \
REGISTER_KERNEL_BUILDER(Name("BandedTriangularSolve") \
.Device(DEVICE_CPU) \
.TypeConstraint<TYPE>("T"), \
BandedTriangularSolveOpCpu<TYPE>);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(float);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(double);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex64);
REGISTER_BANDED_TRIANGULAR_SOLVE_CPU(complex128);
} // namespace tensorflow

View File

@ -0,0 +1,180 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Node* SetDiag(int num_bands, Graph* g, Node* bands, Node* triangular) {
Node* ret;
Tensor bandwidth(DT_INT32, TensorShape({2}));
bandwidth.flat<int32>()(0) = -(num_bands - 1);
bandwidth.flat<int32>()(1) = 0;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixSetDiagV3")
.Input(triangular)
.Input(bands)
.Input(test::graph::Constant(g, bandwidth))
.Attr("align", "RIGHT_LEFT")
.Finalize(g, &ret));
return ret;
}
Node* BandedTriangularSolve(Graph* g, Node* in0, Node* in1) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BandedTriangularSolve")
.Input(in0)
.Input(in1)
.Attr("lower", true)
.Attr("adjoint", false)
.Finalize(g, &ret));
return ret;
}
Node* MatrixTriangularSolve(Graph* g, Node* in0, Node* in1) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatrixTriangularSolve")
.Input(in0)
.Input(in1)
.Attr("lower", true)
.Attr("adjoint", false)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* BandedTriangularSolve(int64 num_bands, int64 n, int64 m,
bool use_banded_solver, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor in0(type, TensorShape({num_bands, n}));
// Set diagonal to nonzero to guarantee invertibility.
in0.flat<T>().setRandom();
in0.flat<T>() =
in0.flat<T>().abs() + in0.flat<T>().constant(static_cast<T>(0.5));
Tensor in1(type, TensorShape({n, m}));
in1.flat<T>().setRandom();
if (use_banded_solver) {
BandedTriangularSolve(g, test::graph::Constant(g, in0),
test::graph::Constant(g, in1));
} else {
// Create a zero tensor.
Tensor in2(type, TensorShape({n, n}));
in2.flat<T>().setZero();
Node* triangular_matrix =
SetDiag(num_bands, g, test::graph::Constant(g, in0),
test::graph::Constant(g, in2));
MatrixTriangularSolve(g, triangular_matrix, test::graph::Constant(g, in1));
}
return g;
}
// Macro arguments names: --------------------------------------------------- //
// K: Number of bands
// N: Inner dimension of LHS, Inner dimension of RHS.
// M: Outer dimensions of RHS
// BS: boolean indicating whether to use the banded solver
// T: C++ type of scalars (e.g. float, std::complex)
// TT: TensorFlow type of scalars (e.g. DT_FLOAT, DT_COMPLEX128
#define BM_BandedTriangularSolveDev(K, N, M, BS, T, TT, D) \
static void BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT( \
int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * K * N + N * M); \
test::Benchmark(#D, BandedTriangularSolve<T>(K, N, M, BS, TT)).Run(iters); \
} \
BENCHMARK(BM_BandedTriangularSolve##_##K##_##N##_##M##_##BS##_##TT);
#define BM_BandedTriangularSolve(K, N, M, BS, D) \
BM_BandedTriangularSolveDev(K, N, M, BS, float, DT_FLOAT, D); \
BM_BandedTriangularSolveDev(K, N, M, BS, double, DT_DOUBLE, D);
// Small number of bands, few rhs
BM_BandedTriangularSolve(2, 32, 1, true, cpu);
BM_BandedTriangularSolve(2, 32, 1, false, cpu);
BM_BandedTriangularSolve(4, 32, 1, true, cpu);
BM_BandedTriangularSolve(4, 32, 1, false, cpu);
BM_BandedTriangularSolve(8, 32, 1, true, cpu);
BM_BandedTriangularSolve(8, 32, 1, false, cpu);
BM_BandedTriangularSolve(16, 32, 1, true, cpu);
BM_BandedTriangularSolve(16, 32, 1, false, cpu);
BM_BandedTriangularSolve(2, 128, 1, true, cpu);
BM_BandedTriangularSolve(2, 128, 1, false, cpu);
BM_BandedTriangularSolve(4, 128, 1, true, cpu);
BM_BandedTriangularSolve(4, 128, 1, false, cpu);
BM_BandedTriangularSolve(8, 128, 1, true, cpu);
BM_BandedTriangularSolve(8, 128, 1, false, cpu);
BM_BandedTriangularSolve(16, 128, 1, true, cpu);
BM_BandedTriangularSolve(16, 128, 1, false, cpu);
BM_BandedTriangularSolve(2, 512, 1, true, cpu);
BM_BandedTriangularSolve(2, 512, 1, false, cpu);
BM_BandedTriangularSolve(4, 512, 1, true, cpu);
BM_BandedTriangularSolve(4, 512, 1, false, cpu);
BM_BandedTriangularSolve(8, 512, 1, true, cpu);
BM_BandedTriangularSolve(8, 512, 1, false, cpu);
BM_BandedTriangularSolve(16, 512, 1, true, cpu);
BM_BandedTriangularSolve(16, 512, 1, false, cpu);
// Larger # rhs
BM_BandedTriangularSolve(2, 32, 32, true, cpu);
BM_BandedTriangularSolve(2, 32, 32, false, cpu);
BM_BandedTriangularSolve(4, 32, 32, true, cpu);
BM_BandedTriangularSolve(4, 32, 32, false, cpu);
BM_BandedTriangularSolve(8, 32, 32, true, cpu);
BM_BandedTriangularSolve(8, 32, 32, false, cpu);
BM_BandedTriangularSolve(16, 32, 32, true, cpu);
BM_BandedTriangularSolve(16, 32, 32, false, cpu);
BM_BandedTriangularSolve(2, 128, 128, true, cpu);
BM_BandedTriangularSolve(2, 128, 128, false, cpu);
BM_BandedTriangularSolve(4, 128, 128, true, cpu);
BM_BandedTriangularSolve(4, 128, 128, false, cpu);
BM_BandedTriangularSolve(8, 128, 128, true, cpu);
BM_BandedTriangularSolve(8, 128, 128, false, cpu);
BM_BandedTriangularSolve(16, 128, 128, true, cpu);
BM_BandedTriangularSolve(16, 128, 128, false, cpu);
BM_BandedTriangularSolve(2, 512, 512, true, cpu);
BM_BandedTriangularSolve(2, 512, 512, false, cpu);
BM_BandedTriangularSolve(4, 512, 512, true, cpu);
BM_BandedTriangularSolve(4, 512, 512, false, cpu);
BM_BandedTriangularSolve(8, 512, 512, true, cpu);
BM_BandedTriangularSolve(8, 512, 512, false, cpu);
BM_BandedTriangularSolve(16, 512, 512, true, cpu);
BM_BandedTriangularSolve(16, 512, 512, false, cpu);
BM_BandedTriangularSolve(2, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(2, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(4, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(4, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(8, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(8, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(16, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(16, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(32, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(32, 2048, 2048, false, cpu);
BM_BandedTriangularSolve(64, 2048, 2048, true, cpu);
BM_BandedTriangularSolve(64, 2048, 2048, false, cpu);
} // namespace
} // namespace tensorflow

View File

@ -47,6 +47,49 @@ Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
return Status::OK(); return Status::OK();
} }
// The first input is [...,K,M] and second input is [...,M,N].
Status BandedTriangularSolveShapeFn(InferenceContext* c) {
ShapeHandle lhs;
ShapeHandle rhs;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
// Check K > 0.
DimensionHandle num_bands = c->Dim(lhs, -2);
DimensionHandle m = c->Dim(lhs, -1);
if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) {
return errors::InvalidArgument("Number of bands must be positive, but is ",
c->Value(num_bands));
}
if (c->ValueKnown(num_bands) && c->ValueKnown(m) &&
c->Value(num_bands) > c->Value(m)) {
return errors::InvalidArgument("Number of bands ", c->Value(num_bands),
" cannot exceed the size of the matrix ",
c->Value(m));
}
ShapeHandle lhs_batch_shape;
ShapeHandle rhs_batch_shape;
ShapeHandle output_batch_shape;
// Make the common batch subshape.
TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
// lhs and rhs have the same value for M to be compatible.
TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m));
// Build final shape (batch_shape + m + n) in <out>.
ShapeHandle out;
TF_RETURN_IF_ERROR(
c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
c->set_output(0, out);
return Status::OK();
}
// The first input is [...,M,N] and second input is either [...,M,K] or [...,M]. // The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
// Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M]. // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
Status MatrixSolveShapeFn(InferenceContext* c, bool square) { Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
@ -446,6 +489,17 @@ REGISTER_OP("MatrixSolve")
return MatrixSolveShapeFn(c, true /* square (*/); return MatrixSolveShapeFn(c, true /* square (*/);
}); });
REGISTER_OP("BandedTriangularSolve")
.Input("matrix: T")
.Input("rhs: T")
.Output("output: T")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
.Attr("T: {double, float, half, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return BandedTriangularSolveShapeFn(c);
});
REGISTER_OP("MatrixTriangularSolve") REGISTER_OP("MatrixTriangularSolve")
.Input("matrix: T") .Input("matrix: T")
.Input("rhs: T") .Input("rhs: T")

View File

@ -762,6 +762,17 @@ cuda_py_test(
], ],
) )
cuda_py_test(
name = "banded_triangular_solve_op_test",
size = "small",
srcs = ["banded_triangular_solve_op_test.py"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:linalg_ops",
"//third_party/py/numpy",
],
)
cuda_py_test( cuda_py_test(
name = "matrix_triangular_solve_op_test", name = "matrix_triangular_solve_op_test",
size = "medium", size = "medium",

View File

@ -0,0 +1,232 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.banded_triangular_solve."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test
class BandedTriangularSolveOpTest(test.TestCase):
def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
for lower in (False,):
for adjoint in (False, True):
for use_placeholder in True, False:
self._verifySolve(
x,
y,
lower=lower,
adjoint=adjoint,
batch_dims=batch_dims,
use_placeholder=use_placeholder,
dtypes=dtypes)
def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
def _verifySolve(self,
x,
y,
lower=True,
adjoint=False,
batch_dims=None,
use_placeholder=False,
dtypes=(np.float32, np.float64)):
for np_type in dtypes:
a = x.astype(np_type)
b = y.astype(np_type)
# Now we need to convert a to a dense triangular matrix.
def make_diags(diags, lower=True):
n = len(diags[0])
a = np.zeros(n * n, dtype=diags.dtype)
if lower:
for i, diag in enumerate(diags):
a[n * i:n * n:n + 1] = diag[i:]
else:
diags_flip = np.flip(diags, 0)
for i, diag in enumerate(diags_flip):
a[i:(n - i) * n:n + 1] = diag[:(n - i)]
return a.reshape(n, n)
# For numpy.solve we have to explicitly zero out the strictly
# upper or lower triangle.
if a.size > 0:
a_np = make_diags(a, lower=lower)
else:
a_np = a
if adjoint:
a_np = np.conj(np.transpose(a_np))
if batch_dims is not None:
a = np.tile(a, batch_dims + [1, 1])
a_np = np.tile(a_np, batch_dims + [1, 1])
b = np.tile(b, batch_dims + [1, 1])
with self.cached_session(use_gpu=True):
a_tf = a
b_tf = b
if use_placeholder:
a_tf = array_ops.placeholder_with_default(a_tf, shape=None)
b_tf = array_ops.placeholder_with_default(b_tf, shape=None)
tf_ans = linalg_ops.banded_triangular_solve(
a_tf, b_tf, lower=lower, adjoint=adjoint)
tf_val = self.evaluate(tf_ans)
np_ans = np.linalg.solve(a_np, b)
self.assertEqual(np_ans.shape, tf_val.shape)
self.assertAllClose(np_ans, tf_val)
@test_util.run_deprecated_v1
def testSolve(self):
# 1x1 matrix, single rhs.
matrix = np.array([[0.1]])
rhs0 = np.array([[1.]])
self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrix with 2 bands, single right-hand side.
# Corresponds to the lower triangular
# [[1., 0.], [3., 4.]]
# and upper triangular
# [[2., 1.], [0., 3.]]
matrix = np.array([[1., 4.], [2., 3.]])
rhs0 = np.array([[1.], [1.]])
self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrix with 2 bands, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
self._verifySolveAllWaysReal(matrix, rhs1)
# 4 x 4 matrix with 2 bands, 3 right hand sides.
# Corresponds to the lower triangular
# [[1., 0., 0., 0.],
# [-1., 2., 0., 0.],
# [0., -2., 3., 0.],
# [0., 0., -3., 4.]]
# and upper triangular
# [[1., 1., 0., 0.],
# [0., -1., 2., 0.],
# [0., 0., -2., 3.],
# [0., 0., 0., -3.]]
matrix = np.array([[1., 2., 3., 4.], [1., -1., -2., -3.]])
rhs0 = np.array([[1., 0., 1.], [0., 1., 1.], [-1., 2., 1.], [0., -1., -1.]])
self._verifySolveAllWaysReal(matrix, rhs0)
def testSolveBandSizeSmaller(self):
rhs0 = np.random.randn(6, 4)
# 6 x 6 matrix with 2 bands. Ensure all non-zero entries.
matrix = 2. * np.random.uniform(size=[3, 6]) + 1.
self._verifySolveAllWaysReal(matrix, rhs0)
# 6 x 6 matrix with 3 bands. Ensure all non-zero entries.
matrix = 2. * np.random.uniform(size=[3, 6]) + 1.
self._verifySolveAllWaysReal(matrix, rhs0)
@test_util.run_deprecated_v1
def testSolveComplex(self):
if test.is_built_with_rocm():
self.skipTest("ROCm does not support BLAS operations for complex types")
# 1x1 matrix, single rhs.
matrix = np.array([[0.1 + 1j * 0.1]])
rhs0 = np.array([[1. + 1j]])
self._verifySolveAllWaysComplex(matrix, rhs0)
# 2x2 matrix with 2 bands, single right-hand side.
# Corresponds to
# [[1. + 1j, 0.], [4 + 1j, 2 + 1j]]
matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
matrix += 1j * matrix
rhs0 = np.array([[1.], [1.]]).astype(np.complex64)
rhs0 += 1j * rhs0
self._verifySolveAllWaysComplex(matrix, rhs0)
# 2x2 matrix with 2 bands, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
rhs1 += 1j * rhs1
self._verifySolveAllWaysComplex(matrix, rhs1)
@test_util.run_deprecated_v1
def testSolveBatch(self):
matrix = np.array([[1., 2.], [3., 4.]])
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
# Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
matrix = np.array([[1., 2., 3., 4.], [-1., -2., -3., -4.],
[-1., 1., 2., 3.]])
rhs = np.array([[-1., 2.], [1., 1.], [0., 1.], [2., 3.]])
# Batch of 2x3x4x4 matrices with 3 bands, 2x3x4x2 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x4x4 matrices with 3 bands, 3x2x4x2 right-hand sides.
self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
@test_util.run_deprecated_v1
def testSolveBatchComplex(self):
if test.is_built_with_rocm():
self.skipTest("ROCm does not support BLAS operations for complex types")
matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
matrix += 1j * matrix
rhs = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
rhs += 1j * rhs
# Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[3, 2])
@test_util.run_deprecated_v1
def testWrongDimensions(self):
# The matrix should have the same number of rows as the
# right-hand sides.
matrix = np.array([[1., 1.], [1., 1.]])
rhs = np.array([[1., 0.]])
with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs, batch_dims=[2, 3])
# Number of bands exceeds the dimension of the matrix.
matrix = np.ones((6, 4))
rhs = np.ones((4, 2))
with self.cached_session(use_gpu=True):
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs, batch_dims=[2, 3])
@test_util.run_deprecated_v1
@test_util.disable_xla("XLA cannot throw assertion errors during a kernel.")
def testNotInvertible(self):
# The input should be invertible.
# The matrix is singular because it has a zero on the diagonal.
# FIXME(rmlarsen): The GPU kernel does not check for singularity.
singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
with self.cached_session():
with self.assertRaisesOpError("Input matrix is not invertible."):
self._verifySolve(singular_matrix, singular_matrix)
with self.assertRaisesOpError("Input matrix is not invertible."):
self._verifySolve(singular_matrix, singular_matrix, batch_dims=[2, 3])
if __name__ == "__main__":
test.main()

View File

@ -132,6 +132,44 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
return Test return Test
def _GetBandedTriangularSolveGradientTest(
functor_,
dtype_,
shape_,
float32_tol_fudge=1.0, # pylint: disable=redefined-outer-name
**kwargs_):
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def Test(self):
n = shape_[-1]
np.random.seed(1)
# Make sure invertible.
a_np = np.random.uniform(low=1.0, high=2.0, size=shape_).astype(dtype_)
a = constant_op.constant(a_np)
b_np = np.random.uniform(low=-1.0, high=1.0, size=[n, n]).astype(dtype_)
b = constant_op.constant(b_np)
epsilon = np.finfo(dtype_).eps
delta = epsilon**(1.0 / 3.0)
# tolerance obtained by looking at actual differences using
# np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05
# check gradient w.r.t. left argument.
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda x: functor_(x, b, **kwargs_), [a], delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
# check gradient w.r.t. right argument.
theoretical, numerical = gradient_checker_v2.compute_gradient(
lambda y: functor_(a, y, **kwargs_), [b], delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
return Test
if __name__ == '__main__': if __name__ == '__main__':
# Tests for gradients of binary matrix operations. # Tests for gradients of binary matrix operations.
for dtype in np.float32, np.float64: for dtype in np.float32, np.float64:
@ -166,6 +204,20 @@ if __name__ == '__main__':
adjoint=adjoint, adjoint=adjoint,
lower=lower)) lower=lower))
band_shape = extra + (size // 2 + 1, size)
name = '%s_%s_adj_%s_low_%s' % (dtype.__name__, '_'.join(
map(str, band_shape)), str(adjoint), lower)
_AddTest(
MatrixBinaryFunctorGradientTest,
'BandedTriangularSolveGradient', name,
_GetBandedTriangularSolveGradientTest(
linalg_ops.banded_triangular_solve,
dtype,
band_shape,
float32_tol_fudge=4.0,
adjoint=adjoint,
lower=lower))
# Tests for gradients of unary matrix operations. # Tests for gradients of unary matrix operations.
for dtype in np.float32, np.float64: for dtype in np.float32, np.float64:
for size in 2, 5, 10: for size in 2, 5, 10:

View File

@ -340,6 +340,102 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:])) return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
bands,
rhs,
lower=True,
adjoint=False, # pylint: disable=redefined-outer-name
name=None):
r"""Solve triangular systems of equations with a banded solver.
`bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
`K` subdiagonals (when `lower` is `True`) are stored.
This operator broadcasts the batch dimensions of `bands` and the batch
dimensions of `rhs`.
Examples:
Storing 2 bands of a 3x3 matrix.
Note that first element in the second row is ignored due to
the 'LEFT_RIGHT' padding.
>>> x = [[2., 3., 4.], [1., 2., 3.]]
>>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
>>> y = tf.zeros([3, 3])
>>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
>>> z
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[2., 0., 0.],
[2., 3., 0.],
[0., 3., 4.]], dtype=float32)>
>>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
>>> soln
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[0.5 ],
[0. ],
[0.25]], dtype=float32)>
>>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
>>> tf.reduce_all(are_equal).numpy()
True
>>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
>>> tf.reduce_all(are_equal).numpy()
True
Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding
the last element of the first row is ignored.
>>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
>>> y = tf.zeros([4, 4])
>>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
>>> z
<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[-1., 2., 0., 0.],
[ 0., -2., 3., 0.],
[ 0., 0., -3., 4.],
[ 0., 0., -0., -4.]], dtype=float32)>
>>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
>>> soln
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[-4. ],
[-1.5 ],
[-0.6666667],
[-0.25 ]], dtype=float32)>
>>> are_equal = (soln == tf.linalg.triangular_solve(
... z, tf.ones([4, 1]), lower=False))
>>> tf.reduce_all(are_equal).numpy()
True
Args:
bands: A `Tensor` describing the bands of the left hand side, with shape
`[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th
diagonal (the diagonal is the top row) when `lower` is `True` and
otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is
the bottom row) when `lower` is `False`. The bands are stored with
'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right
and subdiagonals are padded on the left. This is the alignment cuSPARSE
uses. See `tf.linalg.set_diag` for more details.
rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
`diagonals`. Note that if the shape of `rhs` and/or `diags` isn't known
statically, `rhs` will be treated as a matrix rather than a vector.
lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
`bands` represents a lower or upper triangular matrix.
adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
to solve with the matrix's block-wise adjoint.
name: A name to give this `Op` (optional).
Returns:
A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
"""
with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
return gen_linalg_ops.banded_triangular_solve(
bands, rhs, lower=lower, adjoint=adjoint)
@tf_export('linalg.tridiagonal_solve') @tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support @dispatch.add_dispatch_support
def tridiagonal_solve(diagonals, def tridiagonal_solve(diagonals,

View File

@ -607,6 +607,39 @@ def _MatrixSolveLsGrad(op, grad):
lambda: _Underdetermined(op, grad)) lambda: _Underdetermined(op, grad))
@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
"""Gradient for BandedTriangularSolve."""
a = op.inputs[0]
b = op.inputs[1]
num_bands = array_ops.shape(a)[-2]
adjoint_a = op.get_attr("adjoint")
lower_a = op.get_attr("lower")
c = op.outputs[0]
grad_b = linalg_ops.banded_triangular_solve(
a, grad, lower=lower_a, adjoint=not adjoint_a)
if adjoint_a:
grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
else:
grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
if lower_a:
grad_a = array_ops.matrix_diag_part(
grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
else:
grad_a = array_ops.matrix_diag_part(
grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
# If the static batch shapes are equal, we don't need to unbroadcast.
if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
a.shape[:-2] == b.shape[:-2]):
return grad_a, grad_b
a_shape = array_ops.shape(a)
b_shape = array_ops.shape(b)
ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
return grad_a, grad_b
@ops.RegisterGradient("MatrixTriangularSolve") @ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad): def _MatrixTriangularSolveGrad(op, grad):
"""Gradient for MatrixTriangularSolve.""" """Gradient for MatrixTriangularSolve."""

View File

@ -284,6 +284,10 @@ tf_module {
name: "AvgPoolGrad" name: "AvgPoolGrad"
argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
} }
member_method {
name: "BandedTriangularSolve"
argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
}
member_method { member_method {
name: "Barrier" name: "Barrier"
argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], " argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "

View File

@ -96,6 +96,10 @@ tf_module {
name: "band_part" name: "band_part"
argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "banded_triangular_solve"
argspec: "args=[\'bands\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
}
member_method { member_method {
name: "cholesky" name: "cholesky"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -284,6 +284,10 @@ tf_module {
name: "AvgPoolGrad" name: "AvgPoolGrad"
argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " argspec: "args=[\'orig_input_shape\', \'grad\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
} }
member_method {
name: "BandedTriangularSolve"
argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
}
member_method { member_method {
name: "Barrier" name: "Barrier"
argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], " argspec: "args=[\'component_types\', \'shapes\', \'capacity\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'-1\', \'\', \'\', \'None\'], "