From 8be0d24cc9d842fd06e25b628ab0aabbd614b1a2 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 22 Feb 2019 15:35:10 -0800 Subject: [PATCH] [XLA] Fix edge cases in erfinv implementation. We were not handling +/-1 correctly. Some backends would return the wrong infinity (i.e. -/+inf rather than +/-inf), while others were returning nan. Also clean up the tests a tad. PiperOrigin-RevId: 235271238 --- tensorflow/compiler/xla/client/lib/math.cc | 13 +++- .../xla/client/lib/math_exhaustive_test.cc | 3 - .../compiler/xla/client/lib/math_test.cc | 15 +++- .../compiler/xla/tests/exhaustive_op_test.cc | 77 +++++++++++++++++++ 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index adcaac98616..381da66a71a 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -273,7 +273,18 @@ XlaOp ErfInv(XlaOp x) { for (int i = 1; i < kDegree; ++i) { p = coefficient(i) + p * w; } - return p * x; + + // Result modulo edge cases. + XlaOp result = p * x; + + // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is + // indeterminate, and can give nan or -/+inf.) + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x)); + return Select(Eq(Abs(x), ScalarLike(x, 1)), + x * MaxValue(&b, shape.element_type()), result); + }); } namespace { diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc index 2d6c8797d16..09a7e295b57 100644 --- a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc @@ -171,9 +171,6 @@ INSTANTIATE_TEST_CASE_P( ::testing::ValuesIn(std::vector{ Testcase{"square", Square, [](float x) { return x * x; }}, Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }}, - Testcase{"lgamma", Lgamma, std::lgamma} - .set_tolerance(0.1, 0.15) - .set_fewer_infs_ok(), Testcase{"asin", Asin, std::asin}.set_skip_infs(), Testcase{"acos", Acos, std::acos}.set_skip_infs(), Testcase{"atan", Atan, std::atan}, diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index f5ba3e78056..50613ce5025 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -129,13 +129,25 @@ class MathTypedTest : public MathTest { XlaBuilder b(TestName()); auto x = AddParam(LiteralUtil::CreateR1({-inf}), &b); - ConstantR1(&b, {-inf}); ConcatInDim( &b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))}, 0); std::vector expected = {nan, inf, inf}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } + + void TestErfEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + auto x = AddParam(LiteralUtil::CreateR1({T{-1}, T{1}, T{0}}), &b); + ErfInv(x); + + const T inf(std::numeric_limits::infinity()); + std::vector expected = {-inf, inf, T{0}}; + + ComputeAndCompareR1(&b, expected, {}, error_spec_); + } }; // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. @@ -154,6 +166,7 @@ XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); } XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) { this->TestSqrtPowInequivalence(); } +XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfEdgeCases(); } // Check that certain ops only support real, floating-point inputs. // diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc index c405f5580ea..e143a1c3d11 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -27,6 +27,82 @@ namespace { using Eigen::half; +template +T EvaluatePolynomial(T x, const std::array& coeffs) { + T result = 0; + for (T c : coeffs) { + result = result * x + c; + } + return result; +} + +// There's no std::erfinv, so we have to implement it ourselves. This follows +// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a +// different implementation from that in math.cc. +float HostErfInv(float x) { + std::array kPolyA = { + 8.8709406962545514830200e2, 1.1819493347062294404278e4, + 2.3782041382114385731252e4, 1.6235862515167575384252e4, + 4.8548868893843886794648e3, 6.9706266534389598238465e2, + 4.7072688112383978012285e1, 1.1975323115670912564578e0, + }; + std::array kPolyB = { + 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, + 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, + 4.2313330701600911252e1, 1.0000000000000000000e0, + }; + std::array kPolyC = { + 7.74545014278341407640e-4, 2.27238449892691845833e-2, + 2.41780725177450611770e-1, 1.27045825245236838258e0, + 3.64784832476320460504e0, 5.76949722146069140550e0, + 4.63033784615654529590e0, 1.42343711074968357734e0, + }; + std::array kPolyD = { + 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, + 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, + 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, + 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, + }; + std::array kPolyE = { + 2.01033439929228813265e-7, 2.71155556874348757815e-5, + 1.24266094738807843860e-3, 2.65321895265761230930e-2, + 2.96560571828504891230e-1, 1.78482653991729133580e0, + 5.46378491116411436990e0, 6.65790464350110377720e0, + }; + std::array kPolyF = { + 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, + 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, + 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, + 8.482908416595164588112026e-1, 1.414213562373095048801689e0, + }; + + if (std::abs(x) > 1 || std::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + if (std::abs(x) == 1) { + return std::copysign(std::numeric_limits::infinity(), x); + } + + float unsigned_result = [&] { + float y = std::abs(x); + if (y <= 0.85) { + double r = 0.180625 - 0.25 * y * y; + return (y * EvaluatePolynomial(r, kPolyA)) / + EvaluatePolynomial(r, kPolyB); + } else { + double r = std::sqrt(std::log(2.0) - std::log1p(-y)); + if (r <= 5.0) { + r -= 1.6; + return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD); + } else { + r -= 5; + return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF); + } + } + }(); + return std::copysign(unsigned_result, x); +} + // For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be // guaranteed that we're printing the full number. // @@ -397,6 +473,7 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) { XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); } XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); } XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); } +XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); } XLA_TEST_P(ExhaustiveOpTest, Lgamma) { // Our implementation gets within 0.0001 rel error except for ~20 denormal // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.