Unlock XLA non pivoting tridiagonal solver for all back-ends.

Also modify XLA implementation to speed up compilation time.

PiperOrigin-RevId: 302553861
Change-Id: I2fb6108fa146f413a8271f562267e759f1dc86f6
This commit is contained in:
A. Unique TensorFlower 2020-03-23 17:19:13 -07:00 committed by TensorFlower Gardener
parent f68e082e2b
commit 309574cdb2
9 changed files with 234 additions and 103 deletions

View File

@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
with self.session() as sess, self.test_scope(): with self.session() as sess, self.test_scope():
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
errors_impl.UnimplementedError, errors_impl.UnimplementedError,
"Pivoting is not yet supported in XLA tridiagonal solver."): "Current implementation does not yet support pivoting."):
diags = array_ops.placeholder( diags = array_ops.placeholder(
shape=(batch_size, 3, num_dims), dtype=dtypes.float32) shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
rhs = array_ops.placeholder( rhs = array_ops.placeholder(

View File

@ -17,24 +17,27 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/lib/tridiagonal.h" #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
class TridiagonalSolveOp : public XlaOpKernel { class TridiagonalSolveOp : public XlaOpKernel {
public: public:
explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_));
}
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
OP_REQUIRES(
ctx, !pivoting_,
errors::Unimplemented(
"Pivoting is not yet supported in XLA tridiagonal solver."));
auto diagonals = ctx->Input(0); auto diagonals = ctx->Input(0);
auto rhs = ctx->Input(1); auto rhs = ctx->Input(1);
bool partial_pivoting = false;
OP_REQUIRES_OK(ctx,
GetNodeAttr(def(), "partial_pivoting", &partial_pivoting));
if (partial_pivoting) {
ctx->SetStatus(errors::Unimplemented(
"Current implementation does not yet support pivoting."));
return;
}
auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs); auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs);
if (!result.ok()) { if (!result.ok()) {
@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel {
} }
ctx->SetOutput(0, result.ValueOrDie()); ctx->SetOutput(0, result.ValueOrDie());
} }
private:
bool pivoting_;
}; };
// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes),
// to support all XLA devices.
REGISTER_XLA_OP(Name("TridiagonalSolve")
.Device("XLA_TPU_JIT")
.TypeConstraint("T", kFloatTypes),
TridiagonalSolveOp); TridiagonalSolveOp);
} // namespace } // namespace

View File

@ -531,12 +531,15 @@ cc_library(
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_builder",
"@com_google_absl//absl/types:span",
], ],
) )
xla_test( xla_test(
name = "tridiagonal_test", name = "tridiagonal_test",
srcs = ["tridiagonal_test.cc"], srcs = ["tridiagonal_test.cc"],
real_hardware_only = True,
shard_count = 10,
tags = ["optonly"], tags = ["optonly"],
deps = [ deps = [
":constants", ":constants",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/lib/slicing.h"
@ -33,13 +34,6 @@ namespace tridiagonal {
namespace { namespace {
struct TridiagonalSystemShape {
const int64 rank;
const int64 num_equations;
TridiagonalSystemShape(int64 rk, int64 num_eqs)
: rank(rk), num_equations(num_eqs) {}
};
Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank, Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
int64 expected, const std::string& op_name) { int64 expected, const std::string& op_name) {
const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2); const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
@ -53,7 +47,7 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
return Status::OK(); return Status::OK();
} }
StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal, StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
XlaOp main_diagonal, XlaOp main_diagonal,
XlaOp upper_diagonal, XlaOp upper_diagonal,
XlaOp rhs) { XlaOp rhs) {
@ -111,11 +105,27 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1, TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
"upper diagonal")); "upper diagonal"));
return TridiagonalSystemShape(rank, num_equations); return num_equations;
} }
XlaOp Coefficient(XlaOp operand, int64 i) { XlaOp Coefficient(XlaOp operand, int32 i) {
return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1}); return DynamicSliceInMinorDims(operand,
/*starts=*/{ConstantR0(operand.builder(), i)},
/*sizes=*/{1});
}
XlaOp Coefficient(XlaOp operand, XlaOp i) {
return DynamicSliceInMinorDims(operand,
/*starts=*/{i}, /*sizes=*/{1});
}
XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
return DynamicUpdateSliceInMinorDims(
updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
}
XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
} }
} // namespace } // namespace
@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) {
// solution will have the shape [..., num_rhs, num_equations]. // solution will have the shape [..., num_rhs, num_equations].
StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal, StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
XlaOp upper_diagonal, XlaOp rhs) { XlaOp upper_diagonal, XlaOp rhs) {
TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape, XlaBuilder* builder = lower_diagonal.builder();
CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
upper_diagonal, rhs));
auto rank = system_shape.rank; TF_ASSIGN_OR_RETURN(int64 num_eqs,
auto num_eqs = system_shape.num_equations; CheckSystemAndReturnNumEquations(
lower_diagonal, main_diagonal, upper_diagonal, rhs));
std::vector<XlaOp> main_diag_after_elimination(num_eqs); XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
std::vector<XlaOp> rhs_after_elimination(num_eqs); XlaOp rhs_after_elimination = ZerosLike(rhs);
std::vector<XlaOp> upper_diagonal_coeffs(num_eqs); XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
XlaOp x_coeffs = ZerosLike(rhs);
main_diag_after_elimination[0] = Coefficient(main_diagonal, 0); // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
rhs_after_elimination[0] = Coefficient(rhs, 0); main_diag_after_elimination =
for (int64 i = 0; i < num_eqs - 1; i++) { UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
} // rhs_after_elimination[:, 0] = rhs[:, 0];
rhs_after_elimination =
UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
auto preparation_body_fn =
[](XlaOp i, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto upper_diagonal_coeffs = values[0];
auto upper_diagonal = values[1];
// upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
upper_diagonal_coeffs =
UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
};
TF_ASSIGN_OR_RETURN(auto values_after_preparation,
ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
{upper_diagonal_coeffs, upper_diagonal},
"preparation", builder));
upper_diagonal_coeffs = values_after_preparation[0];
// Forward transformation. // Forward transformation.
for (int64 i = 1; i < num_eqs; i++) { auto forward_transformation_fn =
[](XlaOp i_minus_one, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto lower_diagonal = values[0];
auto main_diagonal = values[1];
auto rhs = values[2];
auto main_diag_after_elimination = values[3];
auto upper_diagonal_coeffs = values[4];
auto rhs_after_elimination = values[5];
auto one = ScalarLike(i_minus_one, 1);
auto i = i_minus_one + one;
auto lower_diagonal_i = Coefficient(lower_diagonal, i); auto lower_diagonal_i = Coefficient(lower_diagonal, i);
auto main_diagonal_i = Coefficient(main_diagonal, i); auto main_diagonal_i = Coefficient(main_diagonal, i);
auto rhs_i = Coefficient(rhs, i); auto rhs_i = Coefficient(rhs, i);
auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1]; auto w_i =
lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
main_diag_after_elimination[i] = // main_diag_after_elimination[:, i] =
main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1]; // main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1]; main_diag_after_elimination = UpdateEq(
} main_diag_after_elimination, i,
main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
// rhs_after_elimination[:, i] =
// rhs_i - w_i * rhs_after_elimination[:, i - 1];
rhs_after_elimination =
UpdateEq(rhs_after_elimination, i,
rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
std::vector<XlaOp> x_coeffs(num_eqs); return std::vector<XlaOp>{lower_diagonal,
main_diagonal,
rhs,
main_diag_after_elimination,
upper_diagonal_coeffs,
rhs_after_elimination};
};
TF_ASSIGN_OR_RETURN(
auto values_after_fwd_transformation,
ForEachIndex(
num_eqs - 1, S32, forward_transformation_fn,
{lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
upper_diagonal_coeffs, rhs_after_elimination},
"forward_transformation", builder));
lower_diagonal = values_after_fwd_transformation[0];
main_diagonal = values_after_fwd_transformation[1];
rhs = values_after_fwd_transformation[2];
main_diag_after_elimination = values_after_fwd_transformation[3];
upper_diagonal_coeffs = values_after_fwd_transformation[4];
rhs_after_elimination = values_after_fwd_transformation[5];
// Backward reduction. // Backward reduction.
x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] / // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
main_diag_after_elimination[num_eqs - 1]; // main_diag_after_elimination[:, num_eqs - 1];
for (int i = num_eqs - 2; i >= 0; i--) { x_coeffs =
x_coeffs[i] = (rhs_after_elimination[i] - UpdateEq(x_coeffs, num_eqs - 1,
upper_diagonal_coeffs[i] * x_coeffs[i + 1]) / Coefficient(rhs_after_elimination, num_eqs - 1) /
main_diag_after_elimination[i]; Coefficient(main_diag_after_elimination, num_eqs - 1));
} auto bwd_reduction_fn =
[num_eqs](XlaOp j, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
auto x_coeffs = values[0];
auto rhs_after_elimination = values[1];
auto upper_diagonal_coeffs = values[2];
auto main_diag_after_elimination = values[3];
auto n = ScalarLike(j, num_eqs - 2);
auto one = ScalarLike(j, 1);
auto i = n - j;
// for (int i = num_eqs - 2; i >= 0; i--)
// x_coeffs[:, i] = (rhs_after_elimination[:, i] -
// upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
// main_diag_after_elimination[:, i];
x_coeffs = UpdateEq(x_coeffs, i,
(Coefficient(rhs_after_elimination, i) -
Coefficient(upper_diagonal_coeffs, i) *
Coefficient(x_coeffs, i + one)) /
Coefficient(main_diag_after_elimination, i));
return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
upper_diagonal_coeffs,
main_diag_after_elimination};
};
return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1); TF_ASSIGN_OR_RETURN(
auto values_after_bwd_reduction,
ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
{x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
main_diag_after_elimination},
"backward_reduction", builder));
x_coeffs = values_after_bwd_reduction[0];
return x_coeffs;
} }
// Applies Thomas algorithm to solve a linear system where the linear operand // Applies Thomas algorithm to solve a linear system where the linear operand

View File

@ -33,34 +33,28 @@ namespace {
class TridiagonalTest class TridiagonalTest
: public ClientLibraryTestBase, : public ClientLibraryTestBase,
public ::testing::WithParamInterface<std::tuple<int, int, int, int>> {}; public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
XLA_TEST_P(TridiagonalTest, Solves) { XLA_TEST_P(TridiagonalTest, Solves) {
const auto& spec = GetParam(); const auto& spec = GetParam();
xla::XlaBuilder builder(TestName()); xla::XlaBuilder builder(TestName());
const int64 num_eqs = 5; // TODO(belletti): parametrize num_rhs.
const int64 num_rhs = 3; const int64 batch_size = std::get<0>(spec);
const int64 lower_diagonal_batch_size = std::get<0>(spec); const int64 num_eqs = std::get<1>(spec);
const int64 main_diagonal_batch_size = std::get<1>(spec); const int64 num_rhs = std::get<2>(spec);
const int64 upper_diagonal_batch_size = std::get<2>(spec);
const int64 rhs_diagonal_batch_size = std::get<2>(spec);
const int64 max_batch_size = Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
std::max({lower_diagonal_batch_size, main_diagonal_batch_size, Array3D<float> main_diagonal(batch_size, 1, num_eqs);
upper_diagonal_batch_size, rhs_diagonal_batch_size}); Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
Array3D<float> rhs(batch_size, num_rhs, num_eqs);
Array3D<float> lower_diagonal(lower_diagonal_batch_size, 1, num_eqs);
Array3D<float> main_diagonal(main_diagonal_batch_size, 1, num_eqs);
Array3D<float> upper_diagonal(upper_diagonal_batch_size, 1, num_eqs);
Array3D<float> rhs(rhs_diagonal_batch_size, num_rhs, num_eqs);
lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0); lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
main_diagonal.FillRandom(0.05, /*mean=*/1.0, main_diagonal.FillRandom(0.05, /*mean=*/1.0,
/*seed=*/max_batch_size * num_eqs); /*seed=*/batch_size * num_eqs);
upper_diagonal.FillRandom(1.0, /*mean=*/0.0, upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
/*seed=*/2 * max_batch_size * num_eqs); /*seed=*/2 * batch_size * num_eqs);
rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs); rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);
XlaOp lower_diagonal_xla; XlaOp lower_diagonal_xla;
XlaOp main_diagonal_xla; XlaOp main_diagonal_xla;
@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) {
} }
INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest, INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
::testing::Combine(::testing::Values(1, 8), ::testing::Combine(::testing::Values(1, 12),
::testing::Values(1, 8), ::testing::Values(4, 8),
::testing::Values(1, 8), ::testing::Values(1, 12)));
::testing::Values(1, 8)));
} // namespace } // namespace
} // namespace tridiagonal } // namespace tridiagonal

View File

@ -39,5 +39,6 @@ END
On CPU, solution is computed via Gaussian elimination with or without partial On CPU, solution is computed via Gaussian elimination with or without partial
pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
Partial pivoting is not yet supported by XLA backends.
END END
} }

View File

@ -178,6 +178,7 @@ class LinearOperatorTriDiagMatrixTest(
if __name__ == '__main__': if __name__ == '__main__':
if not test_util.is_xla_enabled():
linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest) linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest) linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest) linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)

View File

@ -22,8 +22,8 @@ import itertools
import numpy as np import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase):
transpose_rhs=False, transpose_rhs=False,
conjugate_rhs=False): conjugate_rhs=False):
with self.cached_session(use_gpu=True): with self.cached_session(use_gpu=True):
result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, pivoting = True
transpose_rhs, conjugate_rhs) if hasattr(self, "pivoting"):
pivoting = self.pivoting
if test_util.is_xla_enabled() and pivoting:
# Pivoting is not supported by xla backends.
return
result = linalg_impl.tridiagonal_solve(
diags,
rhs,
diags_format,
transpose_rhs,
conjugate_rhs,
partial_pivoting=pivoting)
self.assertAllClose(self.evaluate(result), expected) self.assertAllClose(self.evaluate(result), expected)
def _testWithLists(self, def _testWithLists(self,
@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase):
transpose_rhs, conjugate_rhs) transpose_rhs, conjugate_rhs)
def _assertRaises(self, diags, rhs, diags_format="compact"): def _assertRaises(self, diags, rhs, diags_format="compact"):
pivoting = True
if hasattr(self, "pivoting"):
pivoting = self.pivoting
if test_util.is_xla_enabled() and pivoting:
# Pivoting is not supported by xla backends.
return
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
linalg_impl.tridiagonal_solve(diags, rhs, diags_format) linalg_impl.tridiagonal_solve(
diags, rhs, diags_format, partial_pivoting=pivoting)
# Tests with various dtypes # Tests with various dtypes
@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase):
self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2]) self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])
def test0x0(self): def test0x0(self):
if test_util.is_xla_enabled():
# The following test crashes with XLA due to slicing 0 length tensors.
return
self._test( self._test(
diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32), diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32), rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase):
diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]]) diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])
def test1x1NotInvertible(self): def test1x1NotInvertible(self):
if test_util.is_xla_enabled():
# XLA implementation does not check invertibility.
return
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[]) self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[])
def test2x2NotInvertible(self): def test2x2NotInvertible(self):
if test_util.is_xla_enabled():
# XLA implementation does not check invertibility.
return
with self.assertRaises(errors_impl.InvalidArgumentError): with self.assertRaises(errors_impl.InvalidArgumentError):
self._testWithLists( self._testWithLists(
diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[]) diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[])
@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase):
expected=[5, -2, -5, 3]) expected=[5, -2, -5, 3])
def testNotInvertible(self): def testNotInvertible(self):
if test.is_gpu_available(cuda_only=True): if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled():
# CuSparse gtsv routines don't raise errors for non-invertible # CuSparse gtsv routines don't raise errors for non-invertible
# matrices. # matrices.
return return
@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithDummyElements(self): def testSequenceFormatWithDummyElements(self):
dummy = 20 dummy = 20
self._test( self._test(
diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]), diags=(_tfconst([2, 1, 4,
_tfconst([dummy, 1, -1, 1])), dummy]), _tfconst([1, 3, 2,
2]), _tfconst([dummy, 1, -1, 1])),
rhs=_tfconst([1, 2, 3, 4]), rhs=_tfconst([1, 2, 3, 4]),
expected=_tfconst([-9, 5, -4, 4]), expected=_tfconst([-9, 5, -4, 4]),
diags_format="sequence") diags_format="sequence")
@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithBatching(self): def testSequenceFormatWithBatching(self):
self._test( self._test(
diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]), diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
_tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]), _tfconst([[1, 3, 2, 2],
_tfconst([[1, -1, 1], [-1, 1, -1]])), [-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1,
-1]])),
rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]), rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]), expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
diags_format="sequence") diags_format="sequence")
@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase):
with backprop.GradientTape() as tape_rhs: with backprop.GradientTape() as tape_rhs:
tape_diags.watch(diags) tape_diags.watch(diags)
tape_rhs.watch(rhs) tape_rhs.watch(rhs)
if test_util.is_xla_enabled():
# Pivoting is not supported by xla backends.
return
x = linalg_impl.tridiagonal_solve( x = linalg_impl.tridiagonal_solve(
diags, diags,
rhs, rhs,
@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase):
return return
diags = array_ops.placeholder(dtypes.float64, shape=diags_shape) diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape) rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
if test_util.is_xla_enabled() and self.pivoting:
# Pivoting is not supported by xla backends.
return
x = linalg_impl.tridiagonal_solve( x = linalg_impl.tridiagonal_solve(
diags, rhs, diags_format, partial_pivoting=self.pivoting) diags, rhs, diags_format, partial_pivoting=self.pivoting)
with self.cached_session(use_gpu=True) as sess: with self.cached_session(use_gpu=True) as sess:
@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase):
def testSequenceFormatWithUnknownDims(self): def testSequenceFormatWithUnknownDims(self):
if context.executing_eagerly(): if context.executing_eagerly():
return return
if test_util.is_xla_enabled() and self.pivoting:
# Pivoting is not supported by xla backends.
return
superdiag = array_ops.placeholder(dtypes.float64, shape=[None]) superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
diag = array_ops.placeholder(dtypes.float64, shape=[None]) diag = array_ops.placeholder(dtypes.float64, shape=[None])
subdiag = array_ops.placeholder(dtypes.float64, shape=[None]) subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase):
session.Session(config=benchmark.benchmark_config()) as sess, \ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device(device_id): ops.device(device_id):
diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs) diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
# Pivoting is not supported by XLA backends.
if test.is_xla_enabled() and pivoting:
return
x = linalg_impl.tridiagonal_solve( x = linalg_impl.tridiagonal_solve(
diags, rhs, partial_pivoting=pivoting) diags, rhs, partial_pivoting=pivoting)
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
@ -673,8 +714,8 @@ class TridiagonalSolveOpTest(test.TestCase):
control_flow_ops.group(x), control_flow_ops.group(x),
min_iters=10, min_iters=10,
store_memory_usage=False, store_memory_usage=False,
name=test_name_format_string.format( name=test_name_format_string.format(device_name, matrix_size,
device_name, matrix_size, batch_size, num_rhs, batch_size, num_rhs,
pivoting_name)) pivoting_name))
def benchmarkTridiagonalSolveOp_WithMatrixInput(self): def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
@ -687,8 +728,7 @@ class TridiagonalSolveOpTest(test.TestCase):
def benchmarkTridiagonalSolveOp(self): def benchmarkTridiagonalSolveOp(self):
self._benchmark( self._benchmark(
self._generateMatrixData, self._generateMatrixData,
test_name_format_string=( test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
"tridiagonal_solve_{}_matrix_size_{}_"
"batch_size_{}_num_rhs_{}_{}")) "batch_size_{}_num_rhs_{}_{}"))

View File

@ -431,6 +431,8 @@ def tridiagonal_solve(diagonals,
Raises: Raises:
ValueError: An unsupported type is provided as input, or when the input ValueError: An unsupported type is provided as input, or when the input
tensors have incorrect shapes. tensors have incorrect shapes.
UnimplementedError: Whenever `partial_pivoting` is true and the backend is
XLA.
[1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms: [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7. Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.