Change the condition related to the accepted error tolerance.
Use the condition from the paper http://www.netlib.org/lapack/lawnspdf/lawn15.pdf (but with both sides of the comparison squared). With this change, we can enable the tests on the CPU and GPU backends. PiperOrigin-RevId: 249792805
This commit is contained in:
parent
43c7131efc
commit
87aaed125a
@ -472,11 +472,6 @@ cc_library(
|
||||
xla_test(
|
||||
name = "svd_test",
|
||||
srcs = ["svd_test.cc"],
|
||||
# Blacklisted because the tests are flaky.
|
||||
blacklisted_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
real_hardware_only = True,
|
||||
shard_count = 10,
|
||||
tags = ["optonly"],
|
||||
|
@ -46,23 +46,34 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
|
||||
return ConvertElementType(indicator, type);
|
||||
}
|
||||
|
||||
// Returns a boolean mask that is true on the `diagonal`-th diagonal of the
// two minor-most dimensions of `x` and false everywhere else.
//
// `x` must have rank >= 2 (enforced with TF_RET_CHECK below). Any failure is
// routed through the builder's ReportErrorOrReturn.
XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
|
||||
auto n_dims = static_cast<int32>(shape.rank());
|
||||
TF_RET_CHECK(n_dims >= 2);
|
||||
// m and n are the sizes of the last two dimensions of x.
auto m = shape.dimensions(n_dims - 2);
|
||||
auto n = shape.dimensions(n_dims - 1);
|
||||
// Every dimension except the last two.
absl::Span<const int64> major_dims =
|
||||
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
|
||||
// a counts 0..n-1; b counts 0..m-1 shifted by `diagonal`.
auto a = Iota(builder, S32, n);
|
||||
auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
|
||||
// NOTE(review): presumably Broadcast(a, {m}) yields an [m, n] array, so
// indicator[i, j] is true exactly when i + diagonal == j -- confirm against
// the XLA Broadcast/Eq semantics.
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
|
||||
// Replicate the 2-D indicator across the major dimensions of x.
auto mask = Broadcast(indicator, major_dims);
|
||||
return mask;
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp GetMatrixDiagonal(XlaOp x, int k) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
|
||||
const int64 n_dims = shape.rank();
|
||||
auto n_dims = static_cast<int32>(shape.rank());
|
||||
TF_RET_CHECK(n_dims >= 2);
|
||||
const int64 m = shape.dimensions(n_dims - 2);
|
||||
const int64 n = shape.dimensions(n_dims - 1);
|
||||
|
||||
auto offset = ConstantR0WithType(builder, S32, k);
|
||||
|
||||
absl::Span<const int64> major_dims =
|
||||
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
|
||||
auto a = Iota(builder, S32, n);
|
||||
auto b = Iota(builder, S32, m) + offset;
|
||||
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
|
||||
auto mask = Broadcast(indicator, major_dims);
|
||||
auto mask = GetDiagonalMask(x, k);
|
||||
|
||||
// TPUs don't support S64 add reduction at the moment. But fortunately
|
||||
// OR-reductions work just as well for integers.
|
||||
|
@ -31,6 +31,10 @@ namespace xla {
|
||||
// else.
|
||||
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
|
||||
|
||||
// Returns a mask where the 'diagonal'-th diagonal is true and everything else
|
||||
// is false.
|
||||
XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
|
||||
|
||||
// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the
|
||||
// main diagonal, and k<0 for diagonals below the main diagonal.
|
||||
//
|
||||
|
@ -75,11 +75,6 @@ struct OneSidedJacobiRotation {
|
||||
JacobiRotation rot_r;
|
||||
};
|
||||
|
||||
// Pair of norms produced by ComputeFrobeniusNorms for a matrix w.
struct FrobeniusNorms {
|
||||
// sqrt(max(total_norm^2 - sum of squared diagonal entries, 0)).
XlaOp off_diagonal_norm;
|
||||
// Square root of the sum of squares over the last two dimensions of w.
XlaOp total_norm;
|
||||
};
|
||||
|
||||
// Householder reflection on the trailing elements of a vector.
|
||||
//
|
||||
// H = I - beta * [1, v]' * [1, v]
|
||||
@ -567,27 +562,26 @@ StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
|
||||
return svd_result;
|
||||
}
|
||||
|
||||
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
|
||||
StatusOr<XlaOp> ComputeToleranceComparison(XlaOp w, XlaOp epsilon) {
|
||||
XlaBuilder* builder = w.builder();
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
|
||||
const int64 num_dims = shape.rank();
|
||||
auto frobenius_norm =
|
||||
Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
|
||||
CreateScalarAddComputation(shape.element_type(), builder),
|
||||
{num_dims - 2, num_dims - 1}));
|
||||
auto diag = GetMatrixDiagonal(w);
|
||||
auto diag_square =
|
||||
Reduce(Square(diag), ScalarLike(w, 0.0),
|
||||
CreateScalarAddComputation(shape.element_type(), builder),
|
||||
{num_dims - 2});
|
||||
|
||||
FrobeniusNorms frobenius_norms;
|
||||
|
||||
frobenius_norms.off_diagonal_norm =
|
||||
Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
|
||||
frobenius_norms.total_norm = frobenius_norm;
|
||||
|
||||
return frobenius_norms;
|
||||
auto num_dims = static_cast<int32>(shape.rank());
|
||||
int64 n = shape.dimensions(num_dims - 1);
|
||||
shape.set_dimensions(num_dims - 2, n);
|
||||
auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n});
|
||||
auto diag = GetMatrixDiagonal(w_sliced);
|
||||
diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag);
|
||||
std::vector<int64> broadcasted_dims(num_dims - 1);
|
||||
std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
|
||||
auto broadcast_to_rows =
|
||||
BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
|
||||
broadcasted_dims.back() = num_dims - 1;
|
||||
auto broadcast_to_columns =
|
||||
BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
|
||||
// Compute w_{i,i} * w_{j,j} * epsilon^2 < (w_{i,j})^2
|
||||
return Lt(
|
||||
broadcast_to_rows * broadcast_to_columns * epsilon * epsilon,
|
||||
Square(Select(GetDiagonalMask(w_sliced), ZerosLike(w_sliced), w_sliced)));
|
||||
}
|
||||
|
||||
// Main body of the one-sided Jacobi method.
|
||||
@ -603,13 +597,13 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
|
||||
auto max_sweeps = ScalarLike(k, max_sweep_updates);
|
||||
auto sweep_update_cond = Gt(max_sweeps, k);
|
||||
|
||||
auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie();
|
||||
auto tol = norms.total_norm * values[4];
|
||||
auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
|
||||
xla::ConstantR0<bool>(cond_builder, false),
|
||||
CreateScalarOrComputation(PRED, cond_builder));
|
||||
TF_ASSIGN_OR_RETURN(auto tolerance_comparison,
|
||||
ComputeToleranceComparison(values[3], values[4]));
|
||||
auto tolerance_cond = ReduceAll(
|
||||
tolerance_comparison, xla::ConstantR0<bool>(cond_builder, false),
|
||||
CreateScalarOrComputation(PRED, cond_builder));
|
||||
|
||||
return And(sweep_update_cond, tol_cond);
|
||||
return And(sweep_update_cond, tolerance_cond);
|
||||
};
|
||||
|
||||
auto while_body_fn =
|
||||
|
@ -184,12 +184,14 @@ XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
|
||||
ErrorSpec(1e-3, 1e-3));
|
||||
}
|
||||
|
||||
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) {
|
||||
// Too slow on the interpreter backend.
|
||||
XLA_TEST_F(SVDTest,
|
||||
DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) {
|
||||
XlaBuilder builder(TestName());
|
||||
Array2D<float> a_val = GenerateRandomMatrix(512, 128);
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
auto result = SVD(a, 100, 1e-4);
|
||||
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||
@ -201,7 +203,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
|
||||
Array2D<float> a_val = GenerateRandomMatrix(128, 256);
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
auto result = SVD(a, 100, 1e-4);
|
||||
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||
@ -213,40 +215,44 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
|
||||
Array2D<float> a_val = GenerateRandomMatrix(256, 128);
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
auto result = SVD(a, 100, 1e-4);
|
||||
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||
ErrorSpec(1e-3, 1e-3));
|
||||
}
|
||||
|
||||
// TODO(b/133353535): This test seems too sensitive to the particular choice of
|
||||
// random matrix.
|
||||
XLA_TEST_F(SVDTest, DISABLED_Various_Size_Random_Matrix_128x512) {
|
||||
// Too slow on the interpreter backend.
|
||||
XLA_TEST_F(SVDTest,
|
||||
DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) {
|
||||
XlaBuilder builder(TestName());
|
||||
Array2D<float> a_val = GenerateRandomMatrix(128, 512);
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
auto result = SVD(a, 100, 1e-4);
|
||||
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||
ErrorSpec(1e-3, 1e-3));
|
||||
}
|
||||
|
||||
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) {
|
||||
// Too slow on the interpreter and CPU backends.
|
||||
XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
|
||||
Various_Size_Random_Matrix_512x256))) {
|
||||
XlaBuilder builder(TestName());
|
||||
Array2D<float> a_val = GenerateRandomMatrix(512, 256);
|
||||
XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
|
||||
auto result = SVD(a, 100, 1e-6);
|
||||
auto result = SVD(a, 100, 1e-4);
|
||||
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
|
||||
ErrorSpec(1e-3, 1e-3));
|
||||
}
|
||||
|
||||
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) {
|
||||
// Too slow on the CPU, GPU and interpreter backends.
|
||||
XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
|
||||
Various_Size_Random_Matrix_512x512)))) {
|
||||
XlaBuilder builder(TestName());
|
||||
Array2D<float> a_val = GenerateRandomMatrix(512, 512);
|
||||
XlaOp a;
|
||||
|
Loading…
Reference in New Issue
Block a user