[XLA] Fix edge cases in erfinv implementation.
We were not handling the inputs +/-1 correctly: some backends returned the infinity with the wrong sign (i.e. -/+inf rather than +/-inf), while others returned nan. Also clean up the tests a tad.
PiperOrigin-RevId: 235271238
This commit is contained in:
parent
932c34f8d5
commit
8be0d24cc9
@ -273,7 +273,18 @@ XlaOp ErfInv(XlaOp x) {
|
||||
for (int i = 1; i < kDegree; ++i) {
|
||||
p = coefficient(i) + p * w;
|
||||
}
|
||||
return p * x;
|
||||
|
||||
// Result modulo edge cases.
|
||||
XlaOp result = p * x;
|
||||
|
||||
// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
|
||||
// indeterminate, and can give nan or -/+inf.)
|
||||
auto& b = *x.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
|
||||
return Select(Eq(Abs(x), ScalarLike(x, 1)),
|
||||
x * MaxValue(&b, shape.element_type()), result);
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -171,9 +171,6 @@ INSTANTIATE_TEST_CASE_P(
|
||||
::testing::ValuesIn(std::vector<Testcase>{
|
||||
Testcase{"square", Square, [](float x) { return x * x; }},
|
||||
Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }},
|
||||
Testcase{"lgamma", Lgamma, std::lgamma}
|
||||
.set_tolerance(0.1, 0.15)
|
||||
.set_fewer_infs_ok(),
|
||||
Testcase{"asin", Asin, std::asin}.set_skip_infs(),
|
||||
Testcase{"acos", Acos, std::acos}.set_skip_infs(),
|
||||
Testcase{"atan", Atan, std::atan},
|
||||
|
@ -129,13 +129,25 @@ class MathTypedTest : public MathTest {
|
||||
|
||||
XlaBuilder b(TestName());
|
||||
auto x = AddParam(LiteralUtil::CreateR1<T>({-inf}), &b);
|
||||
ConstantR1<T>(&b, {-inf});
|
||||
ConcatInDim(
|
||||
&b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))},
|
||||
0);
|
||||
std::vector<T> expected = {nan, inf, inf};
|
||||
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
void TestErfEdgeCases() {
|
||||
SetFastMathDisabled(true);
|
||||
|
||||
XlaBuilder b(TestName());
|
||||
auto x = AddParam(LiteralUtil::CreateR1<T>({T{-1}, T{1}, T{0}}), &b);
|
||||
ErfInv(x);
|
||||
|
||||
const T inf(std::numeric_limits<float>::infinity());
|
||||
std::vector<T> expected = {-inf, inf, T{0}};
|
||||
|
||||
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
|
||||
@ -154,6 +166,7 @@ XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); }
|
||||
XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) {
|
||||
this->TestSqrtPowInequivalence();
|
||||
}
|
||||
XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfEdgeCases(); }
|
||||
|
||||
// Check that certain ops only support real, floating-point inputs.
|
||||
//
|
||||
|
@ -27,6 +27,82 @@ namespace {
|
||||
|
||||
using Eigen::half;
|
||||
|
||||
// Evaluates a polynomial at `x` using Horner's scheme.
//
// `coeffs` lists the coefficients from the highest-degree term down to the
// constant term, i.e. the value returned is
//   coeffs[0] * x^(N-1) + coeffs[1] * x^(N-2) + ... + coeffs[N-1].
// An empty coefficient table yields a value-initialized (zero) T.
template <typename T, size_t N>
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
  T result{};
  for (const T& c : coeffs) {
    result = result * x + c;
  }
  return result;
}
|
||||
|
||||
// There's no std::erfinv, so we have to implement it ourselves. This follows
|
||||
// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a
|
||||
// different implementation from that in math.cc.
|
||||
float HostErfInv(float x) {
|
||||
std::array<double, 8> kPolyA = {
|
||||
8.8709406962545514830200e2, 1.1819493347062294404278e4,
|
||||
2.3782041382114385731252e4, 1.6235862515167575384252e4,
|
||||
4.8548868893843886794648e3, 6.9706266534389598238465e2,
|
||||
4.7072688112383978012285e1, 1.1975323115670912564578e0,
|
||||
};
|
||||
std::array<double, 8> kPolyB = {
|
||||
5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
|
||||
2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
|
||||
4.2313330701600911252e1, 1.0000000000000000000e0,
|
||||
};
|
||||
std::array<double, 8> kPolyC = {
|
||||
7.74545014278341407640e-4, 2.27238449892691845833e-2,
|
||||
2.41780725177450611770e-1, 1.27045825245236838258e0,
|
||||
3.64784832476320460504e0, 5.76949722146069140550e0,
|
||||
4.63033784615654529590e0, 1.42343711074968357734e0,
|
||||
};
|
||||
std::array<double, 8> kPolyD = {
|
||||
1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
|
||||
2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
|
||||
9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
|
||||
2.9036514445419946173133295e0, 1.4142135623730950488016887e0,
|
||||
};
|
||||
std::array<double, 8> kPolyE = {
|
||||
2.01033439929228813265e-7, 2.71155556874348757815e-5,
|
||||
1.24266094738807843860e-3, 2.65321895265761230930e-2,
|
||||
2.96560571828504891230e-1, 1.78482653991729133580e0,
|
||||
5.46378491116411436990e0, 6.65790464350110377720e0,
|
||||
};
|
||||
std::array<double, 8> kPolyF = {
|
||||
2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
|
||||
2.611088405080593625138020e-5, 1.112800997078859844711555e-3,
|
||||
2.103693768272068968719679e-2, 1.936480946950659106176712e-1,
|
||||
8.482908416595164588112026e-1, 1.414213562373095048801689e0,
|
||||
};
|
||||
|
||||
if (std::abs(x) > 1 || std::isnan(x)) {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
if (std::abs(x) == 1) {
|
||||
return std::copysign(std::numeric_limits<float>::infinity(), x);
|
||||
}
|
||||
|
||||
float unsigned_result = [&] {
|
||||
float y = std::abs(x);
|
||||
if (y <= 0.85) {
|
||||
double r = 0.180625 - 0.25 * y * y;
|
||||
return (y * EvaluatePolynomial(r, kPolyA)) /
|
||||
EvaluatePolynomial(r, kPolyB);
|
||||
} else {
|
||||
double r = std::sqrt(std::log(2.0) - std::log1p(-y));
|
||||
if (r <= 5.0) {
|
||||
r -= 1.6;
|
||||
return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
|
||||
} else {
|
||||
r -= 5;
|
||||
return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
|
||||
}
|
||||
}
|
||||
}();
|
||||
return std::copysign(unsigned_result, x);
|
||||
}
|
||||
|
||||
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
|
||||
// guaranteed that we're printing the full number.
|
||||
//
|
||||
@ -397,6 +473,7 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
|
||||
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
||||
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
||||
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
|
||||
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
|
||||
XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
|
||||
// Our implementation gets within 0.0001 rel error except for ~20 denormal
|
||||
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
|
||||
|
Loading…
Reference in New Issue
Block a user