Unlock the XLA non-pivoting tridiagonal solver for all back-ends.

Also modify the XLA implementation to reduce compilation time.
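For illustration, a minimal sketch of the now-unlocked path (assuming a TF 2.x build where `tf.function`'s `experimental_compile` flag is available; values and shapes are illustrative):

import tensorflow as tf

@tf.function(experimental_compile=True)  # route the op through XLA
def solve(diagonals, rhs):
  # The XLA kernel only supports partial_pivoting=False.
  return tf.linalg.tridiagonal_solve(
      diagonals, rhs, diagonals_format="compact", partial_pivoting=False)

# Compact format: rows are [superdiag, maindiag, subdiag].
diagonals = tf.constant([[2., 1., 0.],
                         [4., 5., 6.],
                         [0., 3., 2.]])
rhs = tf.constant([7., 8., 9.])
x = solve(diagonals, rhs)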

PiperOrigin-RevId: 302553861
Change-Id: I2fb6108fa146f413a8271f562267e759f1dc86f6
A. Unique TensorFlower 2020-03-23 17:19:13 -07:00 committed by TensorFlower Gardener
parent f68e082e2b
commit 309574cdb2
9 changed files with 234 additions and 103 deletions

View File

@@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
with self.session() as sess, self.test_scope():
with self.assertRaisesRegexp(
errors_impl.UnimplementedError,
"Pivoting is not yet supported in XLA tridiagonal solver."):
"Current implementation does not yet support pivoting."):
diags = array_ops.placeholder(
shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
rhs = array_ops.placeholder(

View File

@@ -17,24 +17,27 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace {
class TridiagonalSolveOp : public XlaOpKernel {
public:
explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_));
}
explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
OP_REQUIRES(
ctx, !pivoting_,
errors::Unimplemented(
"Pivoting is not yet supported in XLA tridiagonal solver."));
auto diagonals = ctx->Input(0);
auto rhs = ctx->Input(1);
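// The partial_pivoting attribute is read at compile time below; when it is
// set, Compile reports an Unimplemented status, since the Thomas solver used
// here does not pivot.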
bool partial_pivoting = false;
OP_REQUIRES_OK(ctx,
GetNodeAttr(def(), "partial_pivoting", &partial_pivoting));
if (partial_pivoting) {
ctx->SetStatus(errors::Unimplemented(
"Current implementation does not yet support pivoting."));
return;
}
auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs);
if (!result.ok()) {
@@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel {
}
ctx->SetOutput(0, result.ValueOrDie());
}
private:
bool pivoting_;
};
// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py
// to support all XLA devices.
REGISTER_XLA_OP(Name("TridiagonalSolve")
.Device("XLA_TPU_JIT")
.TypeConstraint("T", kFloatTypes),
REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes),
TridiagonalSolveOp);
} // namespace

View File

@@ -531,12 +531,15 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"@com_google_absl//absl/types:span",
],
)
xla_test(
name = "tridiagonal_test",
srcs = ["tridiagonal_test.cc"],
real_hardware_only = True,
shard_count = 10,
tags = ["optonly"],
deps = [
":constants",

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -33,13 +34,6 @@ namespace tridiagonal {
namespace {
struct TridiagonalSystemShape {
const int64 rank;
const int64 num_equations;
TridiagonalSystemShape(int64 rk, int64 num_eqs)
: rank(rk), num_equations(num_eqs) {}
};
Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
int64 expected, const std::string& op_name) {
const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
@@ -53,10 +47,10 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
return Status::OK();
}
StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
XlaOp main_diagonal,
XlaOp upper_diagonal,
XlaOp rhs) {
StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
XlaOp main_diagonal,
XlaOp upper_diagonal,
XlaOp rhs) {
XlaBuilder* builder = lower_diagonal.builder();
TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
@@ -111,11 +105,27 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
"upper diagonal"));
return TridiagonalSystemShape(rank, num_equations);
return num_equations;
}
XlaOp Coefficient(XlaOp operand, int64 i) {
return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
XlaOp Coefficient(XlaOp operand, int32 i) {
return DynamicSliceInMinorDims(operand,
/*starts=*/{ConstantR0(operand.builder(), i)},
/*sizes=*/{1});
}
XlaOp Coefficient(XlaOp operand, XlaOp i) {
return DynamicSliceInMinorDims(operand,
/*starts=*/{i}, /*sizes=*/{1});
}
XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
return DynamicUpdateSliceInMinorDims(
updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
}
XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
}
} // namespace
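For intuition, a NumPy sketch of what these helpers compute (hypothetical Python names mirroring the C++ ones; the real helpers build XLA ops):

import numpy as np

def coefficient(operand, i):
  # Mirrors DynamicSliceInMinorDims: a length-1 slice at index i of the last dim.
  return operand[..., i:i + 1]

def update_eq(updated, i, update):
  # Mirrors DynamicUpdateSliceInMinorDims: write `update` at index i of the last dim.
  result = np.copy(updated)
  result[..., i:i + 1] = update
  return result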
@@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) {
// solution will have the shape [..., num_rhs, num_equations].
StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
XlaOp upper_diagonal, XlaOp rhs) {
TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape,
CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
upper_diagonal, rhs));
XlaBuilder* builder = lower_diagonal.builder();
auto rank = system_shape.rank;
auto num_eqs = system_shape.num_equations;
TF_ASSIGN_OR_RETURN(int64 num_eqs,
CheckSystemAndReturnNumEquations(
lower_diagonal, main_diagonal, upper_diagonal, rhs));
std::vector<XlaOp> main_diag_after_elimination(num_eqs);
std::vector<XlaOp> rhs_after_elimination(num_eqs);
std::vector<XlaOp> upper_diagonal_coeffs(num_eqs);
XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
XlaOp rhs_after_elimination = ZerosLike(rhs);
XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
XlaOp x_coeffs = ZerosLike(rhs);
main_diag_after_elimination[0] = Coefficient(main_diagonal, 0);
rhs_after_elimination[0] = Coefficient(rhs, 0);
for (int64 i = 0; i < num_eqs - 1; i++) {
upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
}
// main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
main_diag_after_elimination =
UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
// rhs_after_elimination[:, 0] = rhs[:, 0];
rhs_after_elimination =
UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
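// Each stage below runs as a ForEachIndex loop (a single XLA While loop), so
// the emitted HLO no longer grows with num_eqs as the previous unrolled C++
// loops did; this is the compilation-time improvement noted in the commit
// message.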
auto preparation_body_fn =
[](XlaOp i, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto upper_diagonal_coeffs = values[0];
auto upper_diagonal = values[1];
// upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
upper_diagonal_coeffs =
UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
};
TF_ASSIGN_OR_RETURN(auto values_after_preparation,
ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
{upper_diagonal_coeffs, upper_diagonal},
"preparation", builder));
upper_diagonal_coeffs = values_after_preparation[0];
// Forward transformation.
for (int64 i = 1; i < num_eqs; i++) {
auto forward_transformation_fn =
[](XlaOp i_minus_one, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto lower_diagonal = values[0];
auto main_diagonal = values[1];
auto rhs = values[2];
auto main_diag_after_elimination = values[3];
auto upper_diagonal_coeffs = values[4];
auto rhs_after_elimination = values[5];
auto one = ScalarLike(i_minus_one, 1);
auto i = i_minus_one + one;
auto lower_diagonal_i = Coefficient(lower_diagonal, i);
auto main_diagonal_i = Coefficient(main_diagonal, i);
auto rhs_i = Coefficient(rhs, i);
auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1];
auto w_i =
lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
main_diag_after_elimination[i] =
main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1];
rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1];
}
// main_diag_after_elimination[:, i] =
// main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
main_diag_after_elimination = UpdateEq(
main_diag_after_elimination, i,
main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
// rhs_after_elimination[:, i] =
// rhs_i - w_i * rhs_after_elimination[:, i - 1];
rhs_after_elimination =
UpdateEq(rhs_after_elimination, i,
rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
std::vector<XlaOp> x_coeffs(num_eqs);
return std::vector<XlaOp>{lower_diagonal,
main_diagonal,
rhs,
main_diag_after_elimination,
upper_diagonal_coeffs,
rhs_after_elimination};
};
TF_ASSIGN_OR_RETURN(
auto values_after_fwd_transformation,
ForEachIndex(
num_eqs - 1, S32, forward_transformation_fn,
{lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
upper_diagonal_coeffs, rhs_after_elimination},
"forward_transformation", builder));
lower_diagonal = values_after_fwd_transformation[0];
main_diagonal = values_after_fwd_transformation[1];
rhs = values_after_fwd_transformation[2];
main_diag_after_elimination = values_after_fwd_transformation[3];
upper_diagonal_coeffs = values_after_fwd_transformation[4];
rhs_after_elimination = values_after_fwd_transformation[5];
// Backward reduction.
x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] /
main_diag_after_elimination[num_eqs - 1];
for (int i = num_eqs - 2; i >= 0; i--) {
x_coeffs[i] = (rhs_after_elimination[i] -
upper_diagonal_coeffs[i] * x_coeffs[i + 1]) /
main_diag_after_elimination[i];
}
// x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
// main_diag_after_elimination[:, num_eqs - 1];
x_coeffs =
UpdateEq(x_coeffs, num_eqs - 1,
Coefficient(rhs_after_elimination, num_eqs - 1) /
Coefficient(main_diag_after_elimination, num_eqs - 1));
auto bwd_reduction_fn =
[num_eqs](XlaOp j, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto x_coeffs = values[0];
auto rhs_after_elimination = values[1];
auto upper_diagonal_coeffs = values[2];
auto main_diag_after_elimination = values[3];
auto n = ScalarLike(j, num_eqs - 2);
auto one = ScalarLike(j, 1);
auto i = n - j;
// for (int i = num_eqs - 2; i >= 0; i--)
// x_coeffs[:, i] = (rhs_after_elimination[:, i] -
// upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
// main_diag_after_elimination[:, i];
x_coeffs = UpdateEq(x_coeffs, i,
(Coefficient(rhs_after_elimination, i) -
Coefficient(upper_diagonal_coeffs, i) *
Coefficient(x_coeffs, i + one)) /
Coefficient(main_diag_after_elimination, i));
return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
upper_diagonal_coeffs,
main_diag_after_elimination};
};
return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1);
TF_ASSIGN_OR_RETURN(
auto values_after_bwd_reduction,
ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
{x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
main_diag_after_elimination},
"backward_reduction", builder));
x_coeffs = values_after_bwd_reduction[0];
return x_coeffs;
}
// Applies the Thomas algorithm to solve a linear system where the linear operand
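For reference, a NumPy sketch of the recurrences that the three ForEachIndex loops above implement (illustrative only; names follow the C++ variables):

import numpy as np

def thomas_reference(lower, main, upper, rhs):
  # lower/main/upper: [..., 1, n]; rhs: [..., num_rhs, n].
  n = main.shape[-1]
  d = np.copy(main)  # main_diag_after_elimination
  b = np.copy(rhs)   # rhs_after_elimination
  # Forward transformation (forward_transformation_fn).
  for i in range(1, n):
    w = lower[..., i:i + 1] / d[..., i - 1:i]
    d[..., i:i + 1] = main[..., i:i + 1] - w * upper[..., i - 1:i]
    b[..., i:i + 1] = rhs[..., i:i + 1] - w * b[..., i - 1:i]
  # Backward reduction (bwd_reduction_fn).
  x = np.zeros_like(b)
  x[..., n - 1:n] = b[..., n - 1:n] / d[..., n - 1:n]
  for i in range(n - 2, -1, -1):
    x[..., i:i + 1] = (b[..., i:i + 1]
                       - upper[..., i:i + 1] * x[..., i + 1:i + 2]) / d[..., i:i + 1]
  return x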

View File

@@ -33,34 +33,28 @@ namespace {
class TridiagonalTest
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<std::tuple<int, int, int, int>> {};
public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
XLA_TEST_P(TridiagonalTest, Solves) {
const auto& spec = GetParam();
xla::XlaBuilder builder(TestName());
const int64 num_eqs = 5;
const int64 num_rhs = 3;
const int64 lower_diagonal_batch_size = std::get<0>(spec);
const int64 main_diagonal_batch_size = std::get<1>(spec);
const int64 upper_diagonal_batch_size = std::get<2>(spec);
const int64 rhs_diagonal_batch_size = std::get<2>(spec);
// TODO(belletti): parametrize num_rhs.
const int64 batch_size = std::get<0>(spec);
const int64 num_eqs = std::get<1>(spec);
const int64 num_rhs = std::get<2>(spec);
const int64 max_batch_size =
std::max({lower_diagonal_batch_size, main_diagonal_batch_size,
upper_diagonal_batch_size, rhs_diagonal_batch_size});
Array3D<float> lower_diagonal(lower_diagonal_batch_size, 1, num_eqs);
Array3D<float> main_diagonal(main_diagonal_batch_size, 1, num_eqs);
Array3D<float> upper_diagonal(upper_diagonal_batch_size, 1, num_eqs);
Array3D<float> rhs(rhs_diagonal_batch_size, num_rhs, num_eqs);
Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
Array3D<float> main_diagonal(batch_size, 1, num_eqs);
Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
Array3D<float> rhs(batch_size, num_rhs, num_eqs);
lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
main_diagonal.FillRandom(0.05, /*mean=*/1.0,
/*seed=*/max_batch_size * num_eqs);
/*seed=*/batch_size * num_eqs);
upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
/*seed=*/2 * max_batch_size * num_eqs);
rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs);
/*seed=*/2 * batch_size * num_eqs);
rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);
XlaOp lower_diagonal_xla;
XlaOp main_diagonal_xla;
@@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) {
}
INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
::testing::Combine(::testing::Values(1, 8),
::testing::Values(1, 8),
::testing::Values(1, 8),
::testing::Values(1, 8)));
::testing::Combine(::testing::Values(1, 12),
::testing::Values(4, 8),
::testing::Values(1, 12)));
} // namespace
} // namespace tridiagonal

View File

@@ -39,5 +39,6 @@ END
On CPU, the solution is computed via Gaussian elimination, with or without
partial pivoting depending on the `partial_pivoting` attribute. On GPU, Nvidia's
cuSPARSE library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
Partial pivoting is not yet supported by XLA backends.
END
}
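A usage sketch of the behavior described above (illustrative values; `diagonals_format` defaults to "compact"):

import tensorflow as tf

diags = tf.constant([[2., 1., 0.],   # superdiagonal (last entry ignored)
                     [4., 5., 6.],   # main diagonal
                     [0., 3., 2.]])  # subdiagonal (first entry ignored)
rhs = tf.constant([7., 8., 9.])

# CPU: Gaussian elimination with partial pivoting (the default).
x_pivot = tf.linalg.tridiagonal_solve(diags, rhs, partial_pivoting=True)
# Pivoting disabled: the only mode accepted by the XLA backends.
x_thomas = tf.linalg.tridiagonal_solve(diags, rhs, partial_pivoting=False)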

View File

@@ -178,7 +178,8 @@ class LinearOperatorTriDiagMatrixTest(
if __name__ == '__main__':
linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
if not test_util.is_xla_enabled():
linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
test.main()

View File

@@ -22,8 +22,8 @@ import itertools
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase):
transpose_rhs=False,
conjugate_rhs=False):
with self.cached_session(use_gpu=True):
result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format,
transpose_rhs, conjugate_rhs)
pivoting = True
if hasattr(self, "pivoting"):
pivoting = self.pivoting
if test_util.is_xla_enabled() and pivoting:
# Pivoting is not supported by XLA backends.
return
result = linalg_impl.tridiagonal_solve(
diags,
rhs,
diags_format,
transpose_rhs,
conjugate_rhs,
partial_pivoting=pivoting)
self.assertAllClose(self.evaluate(result), expected)
def _testWithLists(self,
@@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase):
transpose_rhs, conjugate_rhs)
def _assertRaises(self, diags, rhs, diags_format="compact"):
pivoting = True
if hasattr(self, "pivoting"):
pivoting = self.pivoting
if test_util.is_xla_enabled() and pivoting:
# Pivoting is not supported by XLA backends.
return
with self.assertRaises(ValueError):
linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
linalg_impl.tridiagonal_solve(
diags, rhs, diags_format, partial_pivoting=pivoting)
# Tests with various dtypes
@@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase):
self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])
def test0x0(self):
if test_util.is_xla_enabled():
# The following test crashes with XLA due to slicing 0-length tensors.
return
self._test(
diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
@@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase):
diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])
def test1x1NotInvertible(self):
if test_util.is_xla_enabled():
# XLA implementation does not check invertibility.
return
with self.assertRaises(errors_impl.InvalidArgumentError):
self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[])
def test2x2NotInvertible(self):
if test_util.is_xla_enabled():
# XLA implementation does not check invertibility.
return
with self.assertRaises(errors_impl.InvalidArgumentError):
self._testWithLists(
diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[])
@@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase):
expected=[5, -2, -5, 3])
def testNotInvertible(self):
if test.is_gpu_available(cuda_only=True):
if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled():
# CuSparse gtsv routines don't raise errors for non-invertible
# matrices.
return
@@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithDummyElements(self):
dummy = 20
self._test(
diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]),
_tfconst([dummy, 1, -1, 1])),
diags=(_tfconst([2, 1, 4,
dummy]), _tfconst([1, 3, 2,
2]), _tfconst([dummy, 1, -1, 1])),
rhs=_tfconst([1, 2, 3, 4]),
expected=_tfconst([-9, 5, -4, 4]),
diags_format="sequence")
@@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithBatching(self):
self._test(
diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
_tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]),
_tfconst([[1, -1, 1], [-1, 1, -1]])),
_tfconst([[1, 3, 2, 2],
[-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1,
-1]])),
rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
diags_format="sequence")
@@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase):
with backprop.GradientTape() as tape_rhs:
tape_diags.watch(diags)
tape_rhs.watch(rhs)
if test_util.is_xla_enabled():
# Pivoting is not supported by XLA backends.
return
x = linalg_impl.tridiagonal_solve(
diags,
rhs,
@@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase):
return
diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
if test_util.is_xla_enabled() and self.pivoting:
# Pivoting is not supported by XLA backends.
return
x = linalg_impl.tridiagonal_solve(
diags, rhs, diags_format, partial_pivoting=self.pivoting)
with self.cached_session(use_gpu=True) as sess:
@@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithUnknownDims(self):
if context.executing_eagerly():
return
if test_util.is_xla_enabled() and self.pivoting:
# Pivoting is not supported by XLA backends.
return
superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
diag = array_ops.placeholder(dtypes.float64, shape=[None])
subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
@@ -641,9 +679,9 @@ class TridiagonalSolveOpTest(test.TestCase):
np.random.seed(seed)
import scipy.sparse as sparse # pylint:disable=g-import-not-at-top
# By being strictly diagonally dominant, we guarantee invertibility.
diag = 2* np.abs(np.random.randn(matrix_size)) + 4.1
subdiag = 2* np.abs(np.random.randn(matrix_size-1))
superdiag = 2* np.abs(np.random.randn(matrix_size-1))
diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1
subdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
superdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray()
vector = np.random.randn(batch_size, matrix_size, num_rhs)
return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))),
@@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase):
session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device(device_id):
diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
# Pivoting is not supported by XLA backends.
if test_util.is_xla_enabled() and pivoting:
return
x = linalg_impl.tridiagonal_solve(
diags, rhs, partial_pivoting=pivoting)
variables.global_variables_initializer().run()
@@ -673,9 +714,9 @@ class TridiagonalSolveOpTest(test.TestCase):
control_flow_ops.group(x),
min_iters=10,
store_memory_usage=False,
name=test_name_format_string.format(
device_name, matrix_size, batch_size, num_rhs,
pivoting_name))
name=test_name_format_string.format(device_name, matrix_size,
batch_size, num_rhs,
pivoting_name))
def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
self._benchmark(
@@ -687,9 +728,8 @@ class TridiagonalSolveOpTest(test.TestCase):
def benchmarkTridiagonalSolveOp(self):
self._benchmark(
self._generateMatrixData,
test_name_format_string=(
"tridiagonal_solve_{}_matrix_size_{}_"
"batch_size_{}_num_rhs_{}_{}"))
test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
"batch_size_{}_num_rhs_{}_{}"))
if __name__ == "__main__":

View File

@@ -430,7 +430,9 @@ def tridiagonal_solve(diagonals,
Raises:
ValueError: An unsupported type is provided as input, or when the input
tensors have incorrect shapes.
UnimplementedError: Whenever `partial_pivoting` is true and the backend is
XLA.
[1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
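A sketch of the new Raises contract (assuming XLA compilation is forced via `experimental_compile`; shapes are illustrative):

import tensorflow as tf

@tf.function(experimental_compile=True)
def solve_with_pivoting(diags, rhs):
  return tf.linalg.tridiagonal_solve(diags, rhs, partial_pivoting=True)

diags = tf.random.normal([3, 4])
rhs = tf.random.normal([4])
try:
  solve_with_pivoting(diags, rhs)
except tf.errors.UnimplementedError as e:
  # Expected message: "Current implementation does not yet support pivoting."
  print(e.message)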