diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 4a99debbe70..cd6746d997c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -472,11 +472,6 @@ cc_library( xla_test( name = "svd_test", srcs = ["svd_test.cc"], - # Blacklisted because the tests are flaky. - blacklisted_backends = [ - "cpu", - "gpu", - ], real_hardware_only = True, shard_count = 10, tags = ["optonly"], diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index 93f3d3ab131..902269d9412 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -46,23 +46,34 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, return ConvertElementType(indicator, type); } +XlaOp GetDiagonalMask(XlaOp x, int diagonal) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto n_dims = static_cast<int32>(shape.rank()); + TF_RET_CHECK(n_dims >= 2); + auto m = shape.dimensions(n_dims - 2); + auto n = shape.dimensions(n_dims - 1); + absl::Span<const int64> major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0WithType(builder, S32, diagonal); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + return mask; + }); +} + XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = shape.rank(); + auto n_dims = static_cast<int32>(shape.rank()); TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - auto offset = 
ConstantR0WithType(builder, S32, k); - - absl::Span<const int64> major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, S32, n); - auto b = Iota(builder, S32, m) + offset; - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); + auto mask = GetDiagonalMask(x, k); // TPUs don't support S64 add reduction at the moment. But fortunately // OR-reductions work just as well for integers. diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 5f1ca964a41..541ce2897f5 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -31,6 +31,10 @@ namespace xla { // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); +// Returns a mask where the 'diagonal'-th diagonal is true and everything else +// is false. +XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); + // Get the diagonals of the last two dimensions. Use k>0 for diagonals above the // main diagonal, and k<0 for diagonals below the main diagonal. // diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc index 53a23872709..646875a20a2 100644 --- a/tensorflow/compiler/xla/client/lib/svd.cc +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -75,11 +75,6 @@ struct OneSidedJacobiRotation { JacobiRotation rot_r; }; -struct FrobeniusNorms { - XlaOp off_diagonal_norm; - XlaOp total_norm; -}; - // Householder reflection on the trailing elements of a vector. 
// // H = I - beta * [1, v]' * [1, v] @@ -567,27 +562,26 @@ StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q, return svd_result; } -StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) { +StatusOr<XlaOp> ComputeToleranceComparison(XlaOp w, XlaOp epsilon) { XlaBuilder* builder = w.builder(); TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); - const int64 num_dims = shape.rank(); - auto frobenius_norm = - Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2, num_dims - 1})); - auto diag = GetMatrixDiagonal(w); - auto diag_square = - Reduce(Square(diag), ScalarLike(w, 0.0), - CreateScalarAddComputation(shape.element_type(), builder), - {num_dims - 2}); - - FrobeniusNorms frobenius_norms; - - frobenius_norms.off_diagonal_norm = - Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); - frobenius_norms.total_norm = frobenius_norm; - - return frobenius_norms; + auto num_dims = static_cast<int32>(shape.rank()); + int64 n = shape.dimensions(num_dims - 1); + shape.set_dimensions(num_dims - 2, n); + auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n}); + auto diag = GetMatrixDiagonal(w_sliced); + diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag); + std::vector<int64> broadcasted_dims(num_dims - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcast_to_rows = + BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + broadcasted_dims.back() = num_dims - 1; + auto broadcast_to_columns = + BroadcastInDim(diag, shape.dimensions(), broadcasted_dims); + // Compute w_{i,i} * w_{j,j} * epsilon^2 < (w_{i,j})^2 + return Lt( + broadcast_to_rows * broadcast_to_columns * epsilon * epsilon, + Square(Select(GetDiagonalMask(w_sliced), ZerosLike(w_sliced), w_sliced))); } // Main boby of One-sided Jacobi Method. 
@@ -603,13 +597,13 @@ StatusOr<std::vector<XlaOp>> WhileLoopFn( auto max_sweeps = ScalarLike(k, max_sweep_updates); auto sweep_update_cond = Gt(max_sweeps, k); - auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie(); - auto tol = norms.total_norm * values[4]; - auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), - xla::ConstantR0<bool>(cond_builder, false), - CreateScalarOrComputation(PRED, cond_builder)); + TF_ASSIGN_OR_RETURN(auto tolerance_comparison, + ComputeToleranceComparison(values[3], values[4])); + auto tolerance_cond = ReduceAll( + tolerance_comparison, xla::ConstantR0<bool>(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); - return And(sweep_update_cond, tol_cond); + return And(sweep_update_cond, tolerance_cond); }; auto while_body_fn = diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc index 047095deff1..a39238548fc 100644 --- a/tensorflow/compiler/xla/client/lib/svd_test.cc +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -184,12 +184,14 @@ XLA_TEST_F(SVDTest, TestSingleValuesMatchNumpy) { ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x128) { +// Too slow on the interpreter backend. 
+XLA_TEST_F(SVDTest, + DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_512x128)) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(512, 128); XlaOp a; auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, @@ -201,7 +203,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_128x256) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(128, 256); XlaOp a; auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, @@ -213,40 +215,44 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_256x128) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(256, 128); XlaOp a; auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -// TODO(b/133353535): This test seems too sensitive to the particular choice of -// random matrix. -XLA_TEST_F(SVDTest, DISABLED_Various_Size_Random_Matrix_128x512) { +// Too slow on the interpreter backend. 
+XLA_TEST_F(SVDTest, + DISABLED_ON_INTERPRETER(Various_Size_Random_Matrix_128x512)) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(128, 512); XlaOp a; auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x256) { +// Too slow on the interpreter and CPU backends. +XLA_TEST_F(SVDTest, DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( Various_Size_Random_Matrix_512x256))) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(512, 256); XlaOp a; auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a); - auto result = SVD(a, 100, 1e-6); + auto result = SVD(a, 100, 1e-4); GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder); ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()}, ErrorSpec(1e-3, 1e-3)); } -XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) { +// Too slow on the CPU, GPU and interpreter backends. +XLA_TEST_F(SVDTest, DISABLED_ON_GPU(DISABLED_ON_CPU(DISABLED_ON_INTERPRETER( Various_Size_Random_Matrix_512x512)))) { XlaBuilder builder(TestName()); Array2D<float> a_val = GenerateRandomMatrix(512, 512); XlaOp a;