Unlock XLA non-pivoting tridiagonal solver for all back-ends.

Also modify the XLA implementation to speed up compilation time.

PiperOrigin-RevId: 302553861
Change-Id: I2fb6108fa146f413a8271f562267e759f1dc86f6
commit 309574cdb2
parent f68e082e2b
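The kernel rewritten below implements the classic non-pivoting Thomas algorithm: a forward elimination pass followed by backward substitution. As a reading aid for the diff, here is a minimal NumPy sketch of the recurrence that the XLA ThomasSolver computes per batch element; thomas_solve is an illustrative helper, not a TensorFlow or XLA API:

import numpy as np

def thomas_solve(lower, main, upper, rhs):
  """Non-pivoting tridiagonal solve (illustrative sketch).

  lower[0] and upper[-1] are unused padding, mirroring TF's compact format.
  """
  n = len(main)
  main_elim = np.asarray(main, dtype=float).copy()
  rhs_elim = np.asarray(rhs, dtype=float).copy()
  # Forward transformation: eliminate the subdiagonal.
  for i in range(1, n):
    w = lower[i] / main_elim[i - 1]
    main_elim[i] -= w * upper[i - 1]
    rhs_elim[i] -= w * rhs_elim[i - 1]
  # Backward reduction: substitute from the last equation upward.
  x = np.zeros(n)
  x[n - 1] = rhs_elim[n - 1] / main_elim[n - 1]
  for i in range(n - 2, -1, -1):
    x[i] = (rhs_elim[i] - upper[i] * x[i + 1]) / main_elim[i]
  return x

Without pivoting, the recurrence divides by main_elim[i - 1] unconditionally, which is why the op below still rejects partial_pivoting=True on XLA back-ends.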
@@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
     with self.session() as sess, self.test_scope():
       with self.assertRaisesRegexp(
           errors_impl.UnimplementedError,
-          "Pivoting is not yet supported in XLA tridiagonal solver."):
+          "Current implementation does not yet support pivoting."):
         diags = array_ops.placeholder(
             shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
         rhs = array_ops.placeholder(
@@ -17,24 +17,27 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace {
 
 class TridiagonalSolveOp : public XlaOpKernel {
  public:
-  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_));
-  }
+  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
-    OP_REQUIRES(
-        ctx, !pivoting_,
-        errors::Unimplemented(
-            "Pivoting is not yet supported in XLA tridiagonal solver."));
-
     auto diagonals = ctx->Input(0);
     auto rhs = ctx->Input(1);
+    bool partial_pivoting = false;
+    OP_REQUIRES_OK(ctx,
+                   GetNodeAttr(def(), "partial_pivoting", &partial_pivoting));
+    if (partial_pivoting) {
+      ctx->SetStatus(errors::Unimplemented(
+          "Current implementation does not yet support pivoting."));
+      return;
+    }
 
     auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs);
     if (!result.ok()) {
@@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel {
     }
     ctx->SetOutput(0, result.ValueOrDie());
   }
-
- private:
-  bool pivoting_;
 };
 
-// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py
-// to support all XLA devices.
-REGISTER_XLA_OP(Name("TridiagonalSolve")
-                    .Device("XLA_TPU_JIT")
-                    .TypeConstraint("T", kFloatTypes),
+REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes),
                 TridiagonalSolveOp);
 
 }  // namespace
@@ -531,12 +531,15 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
 xla_test(
     name = "tridiagonal_test",
     srcs = ["tridiagonal_test.cc"],
+    real_hardware_only = True,
+    shard_count = 10,
     tags = ["optonly"],
     deps = [
         ":constants",
@@ -19,6 +19,7 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -33,13 +34,6 @@ namespace tridiagonal {
 
 namespace {
 
-struct TridiagonalSystemShape {
-  const int64 rank;
-  const int64 num_equations;
-  TridiagonalSystemShape(int64 rk, int64 num_eqs)
-      : rank(rk), num_equations(num_eqs) {}
-};
-
 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
                                   int64 expected, const std::string& op_name) {
   const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
@@ -53,7 +47,7 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
   return Status::OK();
 }
 
-StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
+StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
                                                  XlaOp main_diagonal,
                                                  XlaOp upper_diagonal,
                                                  XlaOp rhs) {
@@ -111,11 +105,27 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
                                                 "upper diagonal"));
 
-  return TridiagonalSystemShape(rank, num_equations);
+  return num_equations;
 }
 
-XlaOp Coefficient(XlaOp operand, int64 i) {
-  return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
+XlaOp Coefficient(XlaOp operand, int32 i) {
+  return DynamicSliceInMinorDims(operand,
+                                 /*starts=*/{ConstantR0(operand.builder(), i)},
+                                 /*sizes=*/{1});
+}
+
+XlaOp Coefficient(XlaOp operand, XlaOp i) {
+  return DynamicSliceInMinorDims(operand,
+                                 /*starts=*/{i}, /*sizes=*/{1});
+}
+
+XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(
+      updated, update, /*starts=*/{ConstantR0(updated.builder(), i)});
+}
+
+XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
 }
 
 }  // namespace
@@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) {
 // solution will have the shape [..., num_rhs, num_equations].
 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
                              XlaOp upper_diagonal, XlaOp rhs) {
-  TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape,
-                      CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
-                                                upper_diagonal, rhs));
+  XlaBuilder* builder = lower_diagonal.builder();
 
-  auto rank = system_shape.rank;
-  auto num_eqs = system_shape.num_equations;
+  TF_ASSIGN_OR_RETURN(int64 num_eqs,
+                      CheckSystemAndReturnNumEquations(
+                          lower_diagonal, main_diagonal, upper_diagonal, rhs));
 
-  std::vector<XlaOp> main_diag_after_elimination(num_eqs);
-  std::vector<XlaOp> rhs_after_elimination(num_eqs);
-  std::vector<XlaOp> upper_diagonal_coeffs(num_eqs);
+  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
+  XlaOp rhs_after_elimination = ZerosLike(rhs);
+  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
+  XlaOp x_coeffs = ZerosLike(rhs);
 
-  main_diag_after_elimination[0] = Coefficient(main_diagonal, 0);
-  rhs_after_elimination[0] = Coefficient(rhs, 0);
-  for (int64 i = 0; i < num_eqs - 1; i++) {
-    upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
-  }
+  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
+  main_diag_after_elimination =
+      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
+  // rhs_after_elimination[:, 0] = rhs[:, 0];
+  rhs_after_elimination =
+      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
+
+  auto preparation_body_fn =
+      [](XlaOp i, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto upper_diagonal_coeffs = values[0];
+    auto upper_diagonal = values[1];
+    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
+    upper_diagonal_coeffs =
+        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
+    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
+  };
+  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
+                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
+                                   {upper_diagonal_coeffs, upper_diagonal},
+                                   "preparation", builder));
+  upper_diagonal_coeffs = values_after_preparation[0];
 
   // Forward transformation.
-  for (int64 i = 1; i < num_eqs; i++) {
+  auto forward_transformation_fn =
+      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto lower_diagonal = values[0];
+    auto main_diagonal = values[1];
+    auto rhs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto upper_diagonal_coeffs = values[4];
+    auto rhs_after_elimination = values[5];
+
+    auto one = ScalarLike(i_minus_one, 1);
+    auto i = i_minus_one + one;
     auto lower_diagonal_i = Coefficient(lower_diagonal, i);
     auto main_diagonal_i = Coefficient(main_diagonal, i);
     auto rhs_i = Coefficient(rhs, i);
 
-    auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1];
+    auto w_i =
+        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);
 
-    main_diag_after_elimination[i] =
-        main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1];
-    rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1];
-  }
+    // main_diag_after_elimination[:, i] =
+    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
+    main_diag_after_elimination = UpdateEq(
+        main_diag_after_elimination, i,
+        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
+    // rhs_after_elimination[:, i] =
+    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
+    rhs_after_elimination =
+        UpdateEq(rhs_after_elimination, i,
+                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));
 
-  std::vector<XlaOp> x_coeffs(num_eqs);
+    return std::vector<XlaOp>{lower_diagonal,
+                              main_diagonal,
+                              rhs,
+                              main_diag_after_elimination,
+                              upper_diagonal_coeffs,
+                              rhs_after_elimination};
+  };
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_fwd_transformation,
+      ForEachIndex(
+          num_eqs - 1, S32, forward_transformation_fn,
+          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
+           upper_diagonal_coeffs, rhs_after_elimination},
+          "forward_transformation", builder));
+  lower_diagonal = values_after_fwd_transformation[0];
+  main_diagonal = values_after_fwd_transformation[1];
+  rhs = values_after_fwd_transformation[2];
+  main_diag_after_elimination = values_after_fwd_transformation[3];
+  upper_diagonal_coeffs = values_after_fwd_transformation[4];
+  rhs_after_elimination = values_after_fwd_transformation[5];
 
   // Backward reduction.
-  x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] /
-                          main_diag_after_elimination[num_eqs - 1];
-  for (int i = num_eqs - 2; i >= 0; i--) {
-    x_coeffs[i] = (rhs_after_elimination[i] -
-                   upper_diagonal_coeffs[i] * x_coeffs[i + 1]) /
-                  main_diag_after_elimination[i];
-  }
+  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
+  //     main_diag_after_elimination[:, num_eqs - 1];
+  x_coeffs =
+      UpdateEq(x_coeffs, num_eqs - 1,
+               Coefficient(rhs_after_elimination, num_eqs - 1) /
+                   Coefficient(main_diag_after_elimination, num_eqs - 1));
+  auto bwd_reduction_fn =
+      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
+                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto x_coeffs = values[0];
+    auto rhs_after_elimination = values[1];
+    auto upper_diagonal_coeffs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto n = ScalarLike(j, num_eqs - 2);
+    auto one = ScalarLike(j, 1);
+    auto i = n - j;
+    // for (int i = num_eqs - 2; i >= 0; i--)
+    //   x_coeffs[:, i] = (rhs_after_elimination[:, i] -
+    //       upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
+    //       main_diag_after_elimination[:, i];
+    x_coeffs = UpdateEq(x_coeffs, i,
+                        (Coefficient(rhs_after_elimination, i) -
+                         Coefficient(upper_diagonal_coeffs, i) *
+                             Coefficient(x_coeffs, i + one)) /
+                            Coefficient(main_diag_after_elimination, i));
+    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
+                              upper_diagonal_coeffs,
+                              main_diag_after_elimination};
+  };
 
-  return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1);
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_bwd_reduction,
+      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
+                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
+                    main_diag_after_elimination},
+                   "backward_reduction", builder));
+  x_coeffs = values_after_bwd_reduction[0];
+
+  return x_coeffs;
 }
 
 // Applies Thomas algorithm to solve a linear system where the linear operand
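Design note on the rewrite above: the old implementation unrolled the three passes with C++ for loops, emitting O(num_eqs) slice and update ops into the HLO graph, so compilation time grew with the system size. The new version routes each pass through xla::ForEachIndex (from xla/client/lib/loops.h), which emits a single rolled While loop whose body is compiled once. A rough Python model of that contract, with for_each_index as a hypothetical stand-in for the C++ helper rather than its actual signature:

def for_each_index(num_iterations, body_fn, initial_values):
  """Threads a fixed tuple of loop-carried values through a rolled loop.

  body_fn(i, values) -> values; the body is compiled once and iterated,
  rather than being pasted num_iterations times into the graph.
  """
  values = list(initial_values)
  for i in range(num_iterations):
    values = body_fn(i, values)
  return values

This is also why the loop bodies read and write row i through dynamic slices into fixed-shape ZerosLike accumulators instead of a std::vector<XlaOp>: loop-carried values in a While loop must keep a constant shape across iterations.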
@@ -33,34 +33,28 @@ namespace {
 
 class TridiagonalTest
     : public ClientLibraryTestBase,
-      public ::testing::WithParamInterface<std::tuple<int, int, int, int>> {};
+      public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
 
 XLA_TEST_P(TridiagonalTest, Solves) {
   const auto& spec = GetParam();
   xla::XlaBuilder builder(TestName());
 
-  const int64 num_eqs = 5;
-  const int64 num_rhs = 3;
-  const int64 lower_diagonal_batch_size = std::get<0>(spec);
-  const int64 main_diagonal_batch_size = std::get<1>(spec);
-  const int64 upper_diagonal_batch_size = std::get<2>(spec);
-  const int64 rhs_diagonal_batch_size = std::get<2>(spec);
-
-  const int64 max_batch_size =
-      std::max({lower_diagonal_batch_size, main_diagonal_batch_size,
-                upper_diagonal_batch_size, rhs_diagonal_batch_size});
-
-  Array3D<float> lower_diagonal(lower_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> main_diagonal(main_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> upper_diagonal(upper_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> rhs(rhs_diagonal_batch_size, num_rhs, num_eqs);
+  // TODO(belletti): parametrize num_rhs.
+  const int64 batch_size = std::get<0>(spec);
+  const int64 num_eqs = std::get<1>(spec);
+  const int64 num_rhs = std::get<2>(spec);
+
+  Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> main_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> rhs(batch_size, num_rhs, num_eqs);
 
   lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
   main_diagonal.FillRandom(0.05, /*mean=*/1.0,
-                           /*seed=*/max_batch_size * num_eqs);
+                           /*seed=*/batch_size * num_eqs);
   upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
-                            /*seed=*/2 * max_batch_size * num_eqs);
-  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs);
+                            /*seed=*/2 * batch_size * num_eqs);
+  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);
 
   XlaOp lower_diagonal_xla;
   XlaOp main_diagonal_xla;
@@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) {
 }
 
 INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
-                        ::testing::Combine(::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8)));
+                        ::testing::Combine(::testing::Values(1, 12),
+                                           ::testing::Values(4, 8),
+                                           ::testing::Values(1, 12)));
 
 }  // namespace
 }  // namespace tridiagonal
@@ -39,5 +39,6 @@ END
 On CPU, solution is computed via Gaussian elimination with or without partial
 pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
 library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
+Partial pivoting is not yet supported by XLA backends.
 END
 }
@@ -178,6 +178,7 @@ class LinearOperatorTriDiagMatrixTest(
 
 
 if __name__ == '__main__':
+  if not test_util.is_xla_enabled():
     linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
     linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
     linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
@@ -22,8 +22,8 @@ import itertools
 
 import numpy as np
 
-from tensorflow.python.eager import backprop
 from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase):
             transpose_rhs=False,
             conjugate_rhs=False):
     with self.cached_session(use_gpu=True):
-      result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format,
-                                             transpose_rhs, conjugate_rhs)
+      pivoting = True
+      if hasattr(self, "pivoting"):
+        pivoting = self.pivoting
+      if test_util.is_xla_enabled() and pivoting:
+        # Pivoting is not supported by xla backends.
+        return
+      result = linalg_impl.tridiagonal_solve(
+          diags,
+          rhs,
+          diags_format,
+          transpose_rhs,
+          conjugate_rhs,
+          partial_pivoting=pivoting)
       self.assertAllClose(self.evaluate(result), expected)
 
   def _testWithLists(self,
@@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase):
                         transpose_rhs, conjugate_rhs)
 
   def _assertRaises(self, diags, rhs, diags_format="compact"):
+    pivoting = True
+    if hasattr(self, "pivoting"):
+      pivoting = self.pivoting
+    if test_util.is_xla_enabled() and pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     with self.assertRaises(ValueError):
-      linalg_impl.tridiagonal_solve(diags, rhs, diags_format)
+      linalg_impl.tridiagonal_solve(
+          diags, rhs, diags_format, partial_pivoting=pivoting)
 
   # Tests with various dtypes
@@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase):
     self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])
 
   def test0x0(self):
+    if test_util.is_xla_enabled():
+      # The following test crashes with XLA due to slicing 0 length tensors.
+      return
     self._test(
         diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
         rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
@@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase):
         diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])
 
   def test1x1NotInvertible(self):
+    if test_util.is_xla_enabled():
+      # XLA implementation does not check invertibility.
+      return
     with self.assertRaises(errors_impl.InvalidArgumentError):
       self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[])
 
   def test2x2NotInvertible(self):
+    if test_util.is_xla_enabled():
+      # XLA implementation does not check invertibility.
+      return
     with self.assertRaises(errors_impl.InvalidArgumentError):
       self._testWithLists(
           diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[])
@@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase):
         expected=[5, -2, -5, 3])
 
   def testNotInvertible(self):
-    if test.is_gpu_available(cuda_only=True):
+    if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled():
       # CuSparse gtsv routines don't raise errors for non-invertible
       # matrices.
       return
@@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithDummyElements(self):
     dummy = 20
     self._test(
-        diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]),
-               _tfconst([dummy, 1, -1, 1])),
+        diags=(_tfconst([2, 1, 4,
+                         dummy]), _tfconst([1, 3, 2,
+                                            2]), _tfconst([dummy, 1, -1, 1])),
         rhs=_tfconst([1, 2, 3, 4]),
         expected=_tfconst([-9, 5, -4, 4]),
         diags_format="sequence")
@@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithBatching(self):
     self._test(
         diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
-               _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]),
-               _tfconst([[1, -1, 1], [-1, 1, -1]])),
+               _tfconst([[1, 3, 2, 2],
+                         [-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1,
+                                                                    -1]])),
         rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
         expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
         diags_format="sequence")
@@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase):
       with backprop.GradientTape() as tape_rhs:
         tape_diags.watch(diags)
         tape_rhs.watch(rhs)
+        if test_util.is_xla_enabled():
+          # Pivoting is not supported by xla backends.
+          return
         x = linalg_impl.tridiagonal_solve(
             diags,
             rhs,
@@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase):
       return
     diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
     rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
+    if test_util.is_xla_enabled() and self.pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     x = linalg_impl.tridiagonal_solve(
         diags, rhs, diags_format, partial_pivoting=self.pivoting)
     with self.cached_session(use_gpu=True) as sess:
@@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase):
   def testSequenceFormatWithUnknownDims(self):
     if context.executing_eagerly():
       return
+    if test_util.is_xla_enabled() and self.pivoting:
+      # Pivoting is not supported by xla backends.
+      return
     superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
     diag = array_ops.placeholder(dtypes.float64, shape=[None])
     subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
@@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase):
         session.Session(config=benchmark.benchmark_config()) as sess, \
         ops.device(device_id):
       diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
+      # Pivoting is not supported by XLA backends.
+      if test.is_xla_enabled() and pivoting:
+        return
       x = linalg_impl.tridiagonal_solve(
           diags, rhs, partial_pivoting=pivoting)
       variables.global_variables_initializer().run()
@@ -673,8 +714,8 @@ class TridiagonalSolveOpTest(test.TestCase):
           control_flow_ops.group(x),
           min_iters=10,
           store_memory_usage=False,
-          name=test_name_format_string.format(
-              device_name, matrix_size, batch_size, num_rhs,
+          name=test_name_format_string.format(device_name, matrix_size,
+                                              batch_size, num_rhs,
                                               pivoting_name))
 
   def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
@@ -687,8 +728,7 @@ class TridiagonalSolveOpTest(test.TestCase):
   def benchmarkTridiagonalSolveOp(self):
     self._benchmark(
         self._generateMatrixData,
-        test_name_format_string=(
-            "tridiagonal_solve_{}_matrix_size_{}_"
+        test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
                                  "batch_size_{}_num_rhs_{}_{}"))
 
 
@@ -431,6 +431,8 @@ def tridiagonal_solve(diagonals,
   Raises:
     ValueError: An unsupported type is provided as input, or when the input
       tensors have incorrect shapes.
+    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
+      XLA.
 
   [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
       Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.
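With pivoting disabled, the solver now lowers on any XLA back-end. A usage sketch against the public API (values are illustrative, assuming a build that includes this change):

import tensorflow as tf

# Compact format: rows are [superdiag, maindiag, subdiag]; the last entry of
# the superdiagonal and the first entry of the subdiagonal are ignored.
diags = tf.constant([[1., 1., 0.],
                     [2., 3., 4.],
                     [0., 1., 1.]])
rhs = tf.constant([1., 2., 3.])

# partial_pivoting=False selects the non-pivoting Thomas path that XLA
# supports; partial_pivoting=True still raises UnimplementedError under XLA.
x = tf.linalg.tridiagonal_solve(diags, rhs, partial_pivoting=False)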