Unlock XLA non-pivoting tridiagonal solver for all backends.
Also modify the XLA implementation to speed up compilation time.

PiperOrigin-RevId: 302553861
Change-Id: I2fb6108fa146f413a8271f562267e759f1dc86f6

commit 309574cdb2 (parent f68e082e2b)
@@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
     with self.session() as sess, self.test_scope():
       with self.assertRaisesRegexp(
           errors_impl.UnimplementedError,
-          "Pivoting is not yet supported in XLA tridiagonal solver."):
+          "Current implementation does not yet support pivoting."):
         diags = array_ops.placeholder(
             shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
         rhs = array_ops.placeholder(
@@ -17,24 +17,27 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"

 namespace tensorflow {
 namespace {

 class TridiagonalSolveOp : public XlaOpKernel {
  public:
-  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_));
-  }
+  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    OP_REQUIRES(
-        ctx, !pivoting_,
-        errors::Unimplemented(
-            "Pivoting is not yet supported in XLA tridiagonal solver."));
-
     auto diagonals = ctx->Input(0);
     auto rhs = ctx->Input(1);
+    bool partial_pivoting = false;
+    OP_REQUIRES_OK(ctx,
+                   GetNodeAttr(def(), "partial_pivoting", &partial_pivoting));
+    if (partial_pivoting) {
+      ctx->SetStatus(errors::Unimplemented(
+          "Current implementation does not yet support pivoting."));
+      return;
+    }

     auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs);
     if (!result.ok()) {

@@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel {
     }
     ctx->SetOutput(0, result.ValueOrDie());
   }
-
- private:
-  bool pivoting_;
 };

-// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py
-// to support all XLA devices.
-REGISTER_XLA_OP(Name("TridiagonalSolve")
-                    .Device("XLA_TPU_JIT")
-                    .TypeConstraint("T", kFloatTypes),
+REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes),
                 TridiagonalSolveOp);

 }  // namespace
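Not part of the diff: a minimal sketch of the user-visible behavior the kernel change above produces. The tensor values and the TF2 `jit_compile` entry point are illustrative assumptions; the error type and message come from the hunk itself.

```python
# Hedged sketch: requesting partial pivoting while the op is compiled with
# XLA should now fail with the message set in Compile() above.
import tensorflow as tf

diags = tf.constant([[1., 1., 0.], [2., 2., 2.], [0., 1., 1.]])  # compact format
rhs = tf.constant([1., 2., 3.])

@tf.function(jit_compile=True)
def solve(d, r):
  # partial_pivoting defaults to True; XLA only implements the
  # non-pivoting Thomas solver.
  return tf.linalg.tridiagonal_solve(d, r, partial_pivoting=True)

try:
  solve(diags, rhs)
except tf.errors.UnimplementedError as e:
  print(e)  # "Current implementation does not yet support pivoting."
```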
@@ -531,12 +531,15 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "@com_google_absl//absl/types:span",
     ],
 )

 xla_test(
     name = "tridiagonal_test",
     srcs = ["tridiagonal_test.cc"],
+    real_hardware_only = True,
+    shard_count = 10,
     tags = ["optonly"],
     deps = [
         ":constants",
@@ -19,6 +19,7 @@ limitations under the License.
 #include <string>
 #include <vector>

+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -33,13 +34,6 @@ namespace tridiagonal {

 namespace {

-struct TridiagonalSystemShape {
-  const int64 rank;
-  const int64 num_equations;
-  TridiagonalSystemShape(int64 rk, int64 num_eqs)
-      : rank(rk), num_equations(num_eqs) {}
-};
-
 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
                                   int64 expected, const std::string& op_name) {
   const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
@@ -53,10 +47,10 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
   return Status::OK();
 }

-StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
-                                                           XlaOp main_diagonal,
-                                                           XlaOp upper_diagonal,
-                                                           XlaOp rhs) {
+StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
+                                                 XlaOp main_diagonal,
+                                                 XlaOp upper_diagonal,
+                                                 XlaOp rhs) {
   XlaBuilder* builder = lower_diagonal.builder();

   TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
@@ -111,11 +105,27 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
                                                 "upper diagonal"));

-  return TridiagonalSystemShape(rank, num_equations);
+  return num_equations;
 }

-XlaOp Coefficient(XlaOp operand, int64 i) {
-  return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
+XlaOp Coefficient(XlaOp operand, int32 i) {
+  return DynamicSliceInMinorDims(operand,
+                                 /*starts=*/{ConstantR0(operand.builder(), i)},
+                                 /*sizes=*/{1});
+}
+
+XlaOp Coefficient(XlaOp operand, XlaOp i) {
+  return DynamicSliceInMinorDims(operand,
+                                 /*starts=*/{i}, /*sizes=*/{1});
+}
+
+XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(
+      updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
+}
+
+XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
 }

 }  // namespace
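Not part of the diff: the point of the new `XlaOp`-indexed `Coefficient`/`UpdateEq` overloads is that the solver's sweeps can run inside a single compiled `While` loop (`ForEachIndex`) instead of unrolling one slice/update pair per equation, which is where the compile-time saving comes from. A toy TF-level analogue of that trade-off, with hypothetical function names not taken from the commit:

```python
# Sketch under assumptions: the Python loop is unrolled at trace time into
# n separate HLO slice/update pairs, while tf.while_loop compiles its body
# once, like the ForEachIndex bodies in the C++ above.
import tensorflow as tf

n = 64

@tf.function(jit_compile=True)
def prefix_sums_unrolled(x):
  for i in range(1, n):  # unrolled: graph size grows linearly with n
    x = tf.tensor_scatter_nd_update(x, [[i]], [x[i] + x[i - 1]])
  return x

@tf.function(jit_compile=True)
def prefix_sums_rolled(x):
  def body(i, x):  # compiled once, regardless of n
    return i + 1, tf.tensor_scatter_nd_update(x, [[i]], [x[i] + x[i - 1]])
  _, x = tf.while_loop(lambda i, _: i < n, body, (tf.constant(1), x))
  return x

# Both return [1, 2, ..., n]; the rolled variant compiles much faster.
print(prefix_sums_rolled(tf.ones([n])))
```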
@@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) {
 // solution will have the shape [..., num_rhs, num_equations].
 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
                              XlaOp upper_diagonal, XlaOp rhs) {
-  TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape,
-                      CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
-                                                upper_diagonal, rhs));
+  XlaBuilder* builder = lower_diagonal.builder();

-  auto rank = system_shape.rank;
-  auto num_eqs = system_shape.num_equations;
+  TF_ASSIGN_OR_RETURN(int64 num_eqs,
+                      CheckSystemAndReturnNumEquations(
+                          lower_diagonal, main_diagonal, upper_diagonal, rhs));

-  std::vector<XlaOp> main_diag_after_elimination(num_eqs);
-  std::vector<XlaOp> rhs_after_elimination(num_eqs);
-  std::vector<XlaOp> upper_diagonal_coeffs(num_eqs);
+  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
+  XlaOp rhs_after_elimination = ZerosLike(rhs);
+  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
+  XlaOp x_coeffs = ZerosLike(rhs);

-  main_diag_after_elimination[0] = Coefficient(main_diagonal, 0);
-  rhs_after_elimination[0] = Coefficient(rhs, 0);
-  for (int64 i = 0; i < num_eqs - 1; i++) {
-    upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
-  }
+  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
+  main_diag_after_elimination =
+      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
+  // rhs_after_elimination[:, 0] = rhs[:, 0];
+  rhs_after_elimination =
+      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
+
+  auto preparation_body_fn =
+      [](XlaOp i, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto upper_diagonal_coeffs = values[0];
+    auto upper_diagonal = values[1];
+    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
+    upper_diagonal_coeffs =
+        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
+    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
+  };
+  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
+                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
+                                   {upper_diagonal_coeffs, upper_diagonal},
+                                   "preparation", builder));
+  upper_diagonal_coeffs = values_after_preparation[0];

   // Forward transformation.
-  for (int64 i = 1; i < num_eqs; i++) {
+  auto forward_transformation_fn =
+      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto lower_diagonal = values[0];
+    auto main_diagonal = values[1];
+    auto rhs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto upper_diagonal_coeffs = values[4];
+    auto rhs_after_elimination = values[5];
+
+    auto one = ScalarLike(i_minus_one, 1);
+    auto i = i_minus_one + one;
     auto lower_diagonal_i = Coefficient(lower_diagonal, i);
     auto main_diagonal_i = Coefficient(main_diagonal, i);
     auto rhs_i = Coefficient(rhs, i);

-    auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1];
+    auto w_i =
+        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);

-    main_diag_after_elimination[i] =
-        main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1];
-    rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1];
-  }
+    // main_diag_after_elimination[:, i] =
+    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
+    main_diag_after_elimination = UpdateEq(
+        main_diag_after_elimination, i,
+        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
+    // rhs_after_elimination[:, i] =
+    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
+    rhs_after_elimination =
+        UpdateEq(rhs_after_elimination, i,
+                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));

-  std::vector<XlaOp> x_coeffs(num_eqs);
+    return std::vector<XlaOp>{lower_diagonal,
+                              main_diagonal,
+                              rhs,
+                              main_diag_after_elimination,
+                              upper_diagonal_coeffs,
+                              rhs_after_elimination};
+  };
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_fwd_transformation,
+      ForEachIndex(
+          num_eqs - 1, S32, forward_transformation_fn,
+          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
+           upper_diagonal_coeffs, rhs_after_elimination},
+          "forward_transformation", builder));
+  lower_diagonal = values_after_fwd_transformation[0];
+  main_diagonal = values_after_fwd_transformation[1];
+  rhs = values_after_fwd_transformation[2];
+  main_diag_after_elimination = values_after_fwd_transformation[3];
+  upper_diagonal_coeffs = values_after_fwd_transformation[4];
+  rhs_after_elimination = values_after_fwd_transformation[5];

   // Backward reduction.
-  x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] /
-                          main_diag_after_elimination[num_eqs - 1];
-  for (int i = num_eqs - 2; i >= 0; i--) {
-    x_coeffs[i] = (rhs_after_elimination[i] -
-                   upper_diagonal_coeffs[i] * x_coeffs[i + 1]) /
-                  main_diag_after_elimination[i];
-  }
+  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
+  //     main_diag_after_elimination[:, num_eqs - 1];
+  x_coeffs =
+      UpdateEq(x_coeffs, num_eqs - 1,
+               Coefficient(rhs_after_elimination, num_eqs - 1) /
+                   Coefficient(main_diag_after_elimination, num_eqs - 1));
+  auto bwd_reduction_fn =
+      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
+                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto x_coeffs = values[0];
+    auto rhs_after_elimination = values[1];
+    auto upper_diagonal_coeffs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto n = ScalarLike(j, num_eqs - 2);
+    auto one = ScalarLike(j, 1);
+    auto i = n - j;
+    // for (int i = num_eqs - 2; i >= 0; i--)
+    //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
+    //       upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
+    //       main_diag_after_elimination[:, i];
+    x_coeffs = UpdateEq(x_coeffs, i,
+                        (Coefficient(rhs_after_elimination, i) -
+                         Coefficient(upper_diagonal_coeffs, i) *
+                             Coefficient(x_coeffs, i + one)) /
+                            Coefficient(main_diag_after_elimination, i));
+    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
+                              upper_diagonal_coeffs,
+                              main_diag_after_elimination};
+  };

-  return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1);
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_bwd_reduction,
+      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
+                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
+                    main_diag_after_elimination},
+                   "backward_reduction", builder));
+  x_coeffs = values_after_bwd_reduction[0];
+
+  return x_coeffs;
 }

 // Applies Thomas algorithm to solve a linear system where the linear operand
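For orientation, the recurrence that the rewritten ThomasSolver evaluates with ForEachIndex is the textbook Thomas algorithm. A plain NumPy sketch of the same forward elimination and backward substitution (not part of the diff; variable names mirror the C++ above):

```python
# NumPy sketch of the Thomas algorithm as structured in ThomasSolver;
# no pivoting, so the system must not require row exchanges.
import numpy as np

def thomas_solve(lower, main, upper, rhs):
  n = len(main)
  main_after = np.zeros(n)
  rhs_after = np.zeros(n)
  main_after[0] = main[0]
  rhs_after[0] = rhs[0]
  # Forward transformation: eliminate the lower diagonal.
  for i in range(1, n):
    w = lower[i] / main_after[i - 1]
    main_after[i] = main[i] - w * upper[i - 1]
    rhs_after[i] = rhs[i] - w * rhs_after[i - 1]
  # Backward reduction: substitute from the last equation upwards.
  x = np.zeros(n)
  x[n - 1] = rhs_after[n - 1] / main_after[n - 1]
  for i in range(n - 2, -1, -1):
    x[i] = (rhs_after[i] - upper[i] * x[i + 1]) / main_after[i]
  return x

# Example: diagonally dominant system, solvable without pivoting.
lower = np.array([0., 1., 1.])   # lower[0] unused
main = np.array([4., 4., 4.])
upper = np.array([1., 1., 0.])   # upper[-1] unused
print(thomas_solve(lower, main, upper, np.array([6., 12., 14.])))  # [1. 2. 3.]
```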
@@ -33,34 +33,28 @@ namespace {

 class TridiagonalTest
     : public ClientLibraryTestBase,
-      public ::testing::WithParamInterface<std::tuple<int, int, int, int>> {};
+      public ::testing::WithParamInterface<std::tuple<int, int, int>> {};

 XLA_TEST_P(TridiagonalTest, Solves) {
   const auto& spec = GetParam();
   xla::XlaBuilder builder(TestName());

-  const int64 num_eqs = 5;
-  const int64 num_rhs = 3;
-  const int64 lower_diagonal_batch_size = std::get<0>(spec);
-  const int64 main_diagonal_batch_size = std::get<1>(spec);
-  const int64 upper_diagonal_batch_size = std::get<2>(spec);
-  const int64 rhs_diagonal_batch_size = std::get<2>(spec);
-
-  const int64 max_batch_size =
-      std::max({lower_diagonal_batch_size, main_diagonal_batch_size,
-                upper_diagonal_batch_size, rhs_diagonal_batch_size});
+  // TODO(belletti): parametrize num_rhs.
+  const int64 batch_size = std::get<0>(spec);
+  const int64 num_eqs = std::get<1>(spec);
+  const int64 num_rhs = std::get<2>(spec);

-  Array3D<float> lower_diagonal(lower_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> main_diagonal(main_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> upper_diagonal(upper_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> rhs(rhs_diagonal_batch_size, num_rhs, num_eqs);
+  Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> main_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> rhs(batch_size, num_rhs, num_eqs);

   lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
   main_diagonal.FillRandom(0.05, /*mean=*/1.0,
-                           /*seed=*/max_batch_size * num_eqs);
+                           /*seed=*/batch_size * num_eqs);
   upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
-                            /*seed=*/2 * max_batch_size * num_eqs);
-  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs);
+                            /*seed=*/2 * batch_size * num_eqs);
+  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);

   XlaOp lower_diagonal_xla;
   XlaOp main_diagonal_xla;
@@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) {
 }

 INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
-                        ::testing::Combine(::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8)));
+                        ::testing::Combine(::testing::Values(1, 12),
+                                           ::testing::Values(4, 8),
+                                           ::testing::Values(1, 12)));

 }  // namespace
 }  // namespace tridiagonal
@@ -39,5 +39,6 @@ END
 On CPU, solution is computed via Gaussian elimination with or without partial
 pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
 library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
+Partial pivoting is not yet supported by XLA backends.
 END
 }
@@ -178,7 +178,8 @@ class LinearOperatorTriDiagMatrixTest(


 if __name__ == '__main__':
-  linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
-  linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
-  linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
+  if not test_util.is_xla_enabled():
+    linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
+    linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
+    linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
   test.main()
@@ -22,8 +22,8 @@ import itertools

 import numpy as np

-from tensorflow.python.eager import backprop
 from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase):
             transpose_rhs=False,
             conjugate_rhs=False):
     with self.cached_session(use_gpu=True):
-      result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format,
-                                             transpose_rhs, conjugate_rhs)
+      pivoting = True
+      if hasattr(self, "pivoting"):
+        pivoting = self.pivoting
+      if test_util.is_xla_enabled() and pivoting:
+        # Pivoting is not supported by xla backends.
+        return
+      result = linalg_impl.tridiagonal_solve(
+          diags,
+          rhs,
+          diags_format,
+          transpose_rhs,
+          conjugate_rhs,
+          partial_pivoting=pivoting)
       self.assertAllClose(self.evaluate(result), expected)

   def _testWithLists(self,
@@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase):
                         transpose_rhs, conjugate_rhs)

   def _assertRaises(self, diags, rhs, diags_format="compact"):
+    pivoting = True
+    if hasattr(self, "pivoting"):
+      pivoting = self.pivoting
+    if test_util.is_xla_enabled() and pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     with self.assertRaises(ValueError):
-      linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
+      linalg_impl.tridiagonal_solve(
+          diags, rhs, diags_format, partial_pivoting=pivoting)

   # Tests with various dtypes

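The hasattr probe above implies a class-level pivoting flag; presumably a variant suite pins it, along these lines (hypothetical; the subclass itself is not shown in this diff):

```python
# Hypothetical illustration only: a no-pivoting variant of the suite would
# set the class attribute the helpers above look up with hasattr().
class TridiagonalSolveOpNoPivotingTest(TridiagonalSolveOpTest):
  pivoting = False
```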
@@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase):
     self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])

   def test0x0(self):
+    if test_util.is_xla_enabled():
+      # The following test crashes with XLA due to slicing 0 length tensors.
+      return
     self._test(
         diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
         rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
@@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase):
         diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])

   def test1x1NotInvertible(self):
+    if test_util.is_xla_enabled():
+      # XLA implementation does not check invertibility.
+      return
     with self.assertRaises(errors_impl.InvalidArgumentError):
       self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[])

   def test2x2NotInvertible(self):
+    if test_util.is_xla_enabled():
+      # XLA implementation does not check invertibility.
+      return
     with self.assertRaises(errors_impl.InvalidArgumentError):
       self._testWithLists(
           diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[])
@@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase):
         expected=[5, -2, -5, 3])

   def testNotInvertible(self):
-    if test.is_gpu_available(cuda_only=True):
+    if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled():
       # CuSparse gtsv routines don't raise errors for non-invertible
       # matrices.
       return
@@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithDummyElements(self):
     dummy = 20
     self._test(
-        diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]),
-               _tfconst([dummy, 1, -1, 1])),
+        diags=(_tfconst([2, 1, 4,
+                         dummy]), _tfconst([1, 3, 2,
+                                            2]), _tfconst([dummy, 1, -1, 1])),
         rhs=_tfconst([1, 2, 3, 4]),
         expected=_tfconst([-9, 5, -4, 4]),
         diags_format="sequence")
@@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithBatching(self):
     self._test(
         diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
-               _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]),
-               _tfconst([[1, -1, 1], [-1, 1, -1]])),
+               _tfconst([[1, 3, 2, 2],
+                         [-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1,
+                                                                    -1]])),
         rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
         expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
         diags_format="sequence")
@@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase):
     with backprop.GradientTape() as tape_rhs:
       tape_diags.watch(diags)
       tape_rhs.watch(rhs)
+      if test_util.is_xla_enabled():
+        # Pivoting is not supported by xla backends.
+        return
       x = linalg_impl.tridiagonal_solve(
           diags,
           rhs,
@@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase):
       return
     diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
     rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
+    if test_util.is_xla_enabled() and self.pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     x = linalg_impl.tridiagonal_solve(
         diags, rhs, diags_format, partial_pivoting=self.pivoting)
     with self.cached_session(use_gpu=True) as sess:
@@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithUnknownDims(self):
     if context.executing_eagerly():
       return
+    if test_util.is_xla_enabled() and self.pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
     diag = array_ops.placeholder(dtypes.float64, shape=[None])
     subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
@@ -641,9 +679,9 @@ class TridiagonalSolveOpTest(test.TestCase):
     np.random.seed(seed)
     import scipy.sparse as sparse  # pylint:disable=g-import-not-at-top
     # By being strictly diagonally dominant, we guarantee invertibility.d
-    diag = 2* np.abs(np.random.randn(matrix_size)) + 4.1
-    subdiag = 2* np.abs(np.random.randn(matrix_size-1))
-    superdiag = 2* np.abs(np.random.randn(matrix_size-1))
+    diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1
+    subdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
+    superdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
     matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray()
     vector = np.random.randn(batch_size, matrix_size, num_rhs)
     return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))),
@@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase):
         session.Session(config=benchmark.benchmark_config()) as sess, \
         ops.device(device_id):
       diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
+      # Pivoting is not supported by XLA backends.
+      if test.is_xla_enabled() and pivoting:
+        return
       x = linalg_impl.tridiagonal_solve(
           diags, rhs, partial_pivoting=pivoting)
       variables.global_variables_initializer().run()
@@ -673,9 +714,9 @@ class TridiagonalSolveOpTest(test.TestCase):
           control_flow_ops.group(x),
           min_iters=10,
           store_memory_usage=False,
-          name=test_name_format_string.format(
-              device_name, matrix_size, batch_size, num_rhs,
-              pivoting_name))
+          name=test_name_format_string.format(device_name, matrix_size,
+                                              batch_size, num_rhs,
+                                              pivoting_name))

   def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
     self._benchmark(
@@ -687,9 +728,8 @@ class TridiagonalSolveOpTest(test.TestCase):
   def benchmarkTridiagonalSolveOp(self):
     self._benchmark(
         self._generateMatrixData,
-        test_name_format_string=(
-            "tridiagonal_solve_{}_matrix_size_{}_"
-            "batch_size_{}_num_rhs_{}_{}"))
+        test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
+                                 "batch_size_{}_num_rhs_{}_{}"))


 if __name__ == "__main__":
@@ -430,7 +430,9 @@ def tridiagonal_solve(diagonals,

   Raises:
     ValueError: An unsupported type is provided as input, or when the input
-      tensors have incorrect shapes.
+      tensors have incorrect shapes.
+    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
+      XLA.

   [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
   Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
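Not part of the diff: a short usage sketch of the documented contract. With `partial_pivoting=False` the same call is portable across CPU, GPU, and XLA backends, while `True` raises `UnimplementedError` under XLA.

```python
# Sketch: solving [[4, 1, 0], [1, 4, 1], [0, 1, 4]] x = [6, 12, 14] without
# pivoting, which is the XLA-compatible configuration.
import tensorflow as tf

superdiag = tf.constant([1., 1., 0.])  # last element ignored
maindiag = tf.constant([4., 4., 4.])
subdiag = tf.constant([0., 1., 1.])    # first element ignored

x = tf.linalg.tridiagonal_solve((superdiag, maindiag, subdiag),
                                tf.constant([6., 12., 14.]),
                                diagonals_format="sequence",
                                partial_pivoting=False)
print(x)  # [1. 2. 3.]
```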