Change the condition related to the accepted error tolerance.

Use the condition from the paper:
http://www.netlib.org/lapack/lawnspdf/lawn15.pdf
(however, we apply the square to both sides of the comparison)

With this change, we can enable the tests on CPU and GPU backends.

PiperOrigin-RevId: 249792805
This commit is contained in:
Adrian Kuegel 2019-05-24 01:19:11 -07:00 committed by TensorFlower Gardener
parent 43c7131efc
commit 87aaed125a
5 changed files with 65 additions and 55 deletions

View File

@ -472,11 +472,6 @@ cc_library(
xla_test(
name = "svd_test",
srcs = ["svd_test.cc"],
# Blacklisted because the tests are flaky.
blacklisted_backends = [
"cpu",
"gpu",
],
real_hardware_only = True,
shard_count = 10,
tags = ["optonly"],

View File

@ -46,23 +46,34 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
return ConvertElementType(indicator, type);
}
// Builds a boolean (PRED) mask with the same shape as `x` that is true
// exactly on the `diagonal`-th diagonal of the last two dimensions and
// false everywhere else. Use diagonal > 0 for diagonals above the main
// diagonal and diagonal < 0 for diagonals below it.
XlaOp GetDiagonalMask(XlaOp x, int diagonal) {
XlaBuilder* builder = x.builder();
// Wrap the body so any status error surfaces through the builder.
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
auto n_dims = static_cast<int32>(shape.rank());
// The mask is defined over the trailing two (matrix) dimensions.
TF_RET_CHECK(n_dims >= 2);
auto m = shape.dimensions(n_dims - 2);
auto n = shape.dimensions(n_dims - 1);
// All leading (batch) dimensions; the per-matrix mask is broadcast
// across these at the end.
absl::Span<const int64> major_dims =
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
// a[j] = j (column index), b[i] = i + diagonal (shifted row index).
auto a = Iota(builder, S32, n);
auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal);
// indicator[i, j] = (i + diagonal == j), i.e. true iff element (i, j)
// lies on the requested diagonal.
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
// Replicate the per-matrix indicator across all batch dimensions.
auto mask = Broadcast(indicator, major_dims);
return mask;
});
}
XlaOp GetMatrixDiagonal(XlaOp x, int k) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
const int64 n_dims = shape.rank();
auto n_dims = static_cast<int32>(shape.rank());
TF_RET_CHECK(n_dims >= 2);
const int64 m = shape.dimensions(n_dims - 2);
const int64 n = shape.dimensions(n_dims - 1);
auto offset = ConstantR0WithType(builder, S32, k);
absl::Span<const int64> major_dims =
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
auto a = Iota(builder, S32, n);
auto b = Iota(builder, S32, m) + offset;
auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
auto mask = Broadcast(indicator, major_dims);
auto mask = GetDiagonalMask(x, k);
// TPUs don't support S64 add reduction at the moment. But fortunately
// OR-reductions work just as well for integers.

View File

@ -31,6 +31,10 @@ namespace xla {
// else.
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
// Returns a mask where the 'diagonal'-th diagonal is true and everything else
// is false.
XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the
// main diagonal, and k<0 for diagonals below the main diagonal.
//

View File

@ -75,11 +75,6 @@ struct OneSidedJacobiRotation {
JacobiRotation rot_r;
};
struct FrobeniusNorms {
XlaOp off_diagonal_norm;
XlaOp total_norm;
};
// Householder reflection on the trailing elements of a vector.
//
// H = I - beta * [1, v]' * [1, v]
@ -567,27 +562,26 @@ StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
return svd_result;
}
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
StatusOr<XlaOp> ComputeToleranceComparison(XlaOp w, XlaOp epsilon) {
XlaBuilder* builder = w.builder();
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
const int64 num_dims = shape.rank();
auto frobenius_norm =
Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
CreateScalarAddComputation(shape.element_type(), builder),
{num_dims - 2, num_dims - 1}));
auto diag = GetMatrixDiagonal(w);
auto diag_square =
Reduce(Square(diag), ScalarLike(w, 0.0),
CreateScalarAddComputation(shape.element_type(), builder),
{num_dims - 2});
FrobeniusNorms frobenius_norms;
frobenius_norms.off_diagonal_norm =
Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
frobenius_norms.total_norm = frobenius_norm;
return frobenius_norms;
auto num_dims = static_cast<int32>(shape.rank());
int64 n = shape.dimensions(num_dims - 1);
shape.set_dimensions(num_dims - 2, n);
auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n});
auto diag = GetMatrixDiagonal(w_sliced);
diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag);
std::vector<int64> broadcasted_dims(num_dims - 1);
std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
auto broadcast_to_rows =
BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
broadcasted_dims.back() = num_dims - 1;
auto broadcast_to_columns =
BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
// Compute w_{i,i} * w_{j,j} * epsilon^2 < (w_{i,j})^2
return Lt(
broadcast_to_rows * broadcast_to_columns * epsilon * epsilon,
Square(Select(GetDiagonalMask(w_sliced), ZerosLike(w_sliced), w_sliced)));
}
// Main body of One-sided Jacobi Method.
@ -603,13 +597,13 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn(
auto max_sweeps = ScalarLike(k, max_sweep_updates);
auto sweep_update_cond = Gt(max_sweeps, k);
auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie();
auto tol = norms.total_norm * values[4];
auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
xla::ConstantR0<bool>(cond_builder, false),
CreateScalarOrComputation(PRED, cond_builder));
TF_ASSIGN_OR_RETURN(auto tolerance_comparison,
ComputeToleranceComparison(values[3], values[4]));
auto tolerance_cond = ReduceAll(
tolerance_comparison, xla::ConstantR0<bool>(cond_builder, false),
CreateScalarOrComputation(PRED, cond_builder));
return And(sweep_update_cond, tol_cond);
return And(sweep_update_cond, tolerance_cond);
};
auto while_body_fn =

View File

@ -184,12 +184,14 @@ XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) {
ErrorSpec(1e-3, 1e-3));
}
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) {
// Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,
DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) {
XlaBuilder builder(TestName());
Array2D<float> a_val = GenerateRandomMatrix(512, 128);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
@ -201,7 +203,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) {
Array2D<float> a_val = GenerateRandomMatrix(128, 256);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
@ -213,40 +215,44 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) {
Array2D<float> a_val = GenerateRandomMatrix(256, 128);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-3, 1e-3));
}
// TODO(b/133353535): This test seems too sensitive to the particular choice of
// random matrix.
XLA_TEST_F(SVDTest, DISABLED_Various_Size_Random_Matrix_128x512) {
// Too slow on the interpreter backend.
XLA_TEST_F(SVDTest,
DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) {
XlaBuilder builder(TestName());
Array2D<float> a_val = GenerateRandomMatrix(128, 512);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-3, 1e-3));
}
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) {
// Too slow on the interpreter and CPU backends.
XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
Various_Size_Random_Matrix_512x256))) {
XlaBuilder builder(TestName());
Array2D<float> a_val = GenerateRandomMatrix(512, 256);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);
ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
ErrorSpec(1e-3, 1e-3));
}
XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) {
// Too slow on the CPU, GPU and interpreter backends.
XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER(
Various_Size_Random_Matrix_512x512)))) {
XlaBuilder builder(TestName());
Array2D<float> a_val = GenerateRandomMatrix(512, 512);
XlaOp a;