[XLA] Fix edge cases in erfinv implementation.
We were not handling +/-1 correctly. Some backends would return the wrong infinity (i.e. -/+inf rather than +/-inf), while others were returning nan. Also clean up the tests a tad. PiperOrigin-RevId: 235271238
This commit is contained in:
parent
932c34f8d5
commit
8be0d24cc9
@ -273,7 +273,18 @@ XlaOp ErfInv(XlaOp x) {
|
|||||||
for (int i = 1; i < kDegree; ++i) {
|
for (int i = 1; i < kDegree; ++i) {
|
||||||
p = coefficient(i) + p * w;
|
p = coefficient(i) + p * w;
|
||||||
}
|
}
|
||||||
return p * x;
|
|
||||||
|
// Result modulo edge cases.
|
||||||
|
XlaOp result = p * x;
|
||||||
|
|
||||||
|
// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
|
||||||
|
// indeterminate, and can give nan or -/+inf.)
|
||||||
|
auto& b = *x.builder();
|
||||||
|
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
|
TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
|
||||||
|
return Select(Eq(Abs(x), ScalarLike(x, 1)),
|
||||||
|
x * MaxValue(&b, shape.element_type()), result);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -171,9 +171,6 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
::testing::ValuesIn(std::vector<Testcase>{
|
::testing::ValuesIn(std::vector<Testcase>{
|
||||||
Testcase{"square", Square, [](float x) { return x * x; }},
|
Testcase{"square", Square, [](float x) { return x * x; }},
|
||||||
Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }},
|
Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }},
|
||||||
Testcase{"lgamma", Lgamma, std::lgamma}
|
|
||||||
.set_tolerance(0.1, 0.15)
|
|
||||||
.set_fewer_infs_ok(),
|
|
||||||
Testcase{"asin", Asin, std::asin}.set_skip_infs(),
|
Testcase{"asin", Asin, std::asin}.set_skip_infs(),
|
||||||
Testcase{"acos", Acos, std::acos}.set_skip_infs(),
|
Testcase{"acos", Acos, std::acos}.set_skip_infs(),
|
||||||
Testcase{"atan", Atan, std::atan},
|
Testcase{"atan", Atan, std::atan},
|
||||||
|
@ -129,13 +129,25 @@ class MathTypedTest : public MathTest {
|
|||||||
|
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
auto x = AddParam(LiteralUtil::CreateR1<T>({-inf}), &b);
|
auto x = AddParam(LiteralUtil::CreateR1<T>({-inf}), &b);
|
||||||
ConstantR1<T>(&b, {-inf});
|
|
||||||
ConcatInDim(
|
ConcatInDim(
|
||||||
&b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))},
|
&b, {Sqrt(x), Pow(x, ScalarLike(x, 0.5)), Pow(x, ScalarLike(x, 0.3))},
|
||||||
0);
|
0);
|
||||||
std::vector<T> expected = {nan, inf, inf};
|
std::vector<T> expected = {nan, inf, inf};
|
||||||
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
|
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestErfEdgeCases() {
|
||||||
|
SetFastMathDisabled(true);
|
||||||
|
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
auto x = AddParam(LiteralUtil::CreateR1<T>({T{-1}, T{1}, T{0}}), &b);
|
||||||
|
ErfInv(x);
|
||||||
|
|
||||||
|
const T inf(std::numeric_limits<float>::infinity());
|
||||||
|
std::vector<T> expected = {-inf, inf, T{0}};
|
||||||
|
|
||||||
|
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
|
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
|
||||||
@ -154,6 +166,7 @@ XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); }
|
|||||||
XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) {
|
XLA_TYPED_TEST(MathTypedTest, SqrtPowInequivalence) {
|
||||||
this->TestSqrtPowInequivalence();
|
this->TestSqrtPowInequivalence();
|
||||||
}
|
}
|
||||||
|
XLA_TYPED_TEST(MathTypedTest, ErfInvEdgeCases) { this->TestErfEdgeCases(); }
|
||||||
|
|
||||||
// Check that certain ops only support real, floating-point inputs.
|
// Check that certain ops only support real, floating-point inputs.
|
||||||
//
|
//
|
||||||
|
@ -27,6 +27,82 @@ namespace {
|
|||||||
|
|
||||||
using Eigen::half;
|
using Eigen::half;
|
||||||
|
|
||||||
|
template <typename T, size_t N>
|
||||||
|
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
|
||||||
|
T result = 0;
|
||||||
|
for (T c : coeffs) {
|
||||||
|
result = result * x + c;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// There's no std::erfinv, so we have to implement it ourselves. This follows
|
||||||
|
// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a
|
||||||
|
// different implementation from that in math.cc.
|
||||||
|
float HostErfInv(float x) {
|
||||||
|
std::array<double, 8> kPolyA = {
|
||||||
|
8.8709406962545514830200e2, 1.1819493347062294404278e4,
|
||||||
|
2.3782041382114385731252e4, 1.6235862515167575384252e4,
|
||||||
|
4.8548868893843886794648e3, 6.9706266534389598238465e2,
|
||||||
|
4.7072688112383978012285e1, 1.1975323115670912564578e0,
|
||||||
|
};
|
||||||
|
std::array<double, 8> kPolyB = {
|
||||||
|
5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
|
||||||
|
2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
|
||||||
|
4.2313330701600911252e1, 1.0000000000000000000e0,
|
||||||
|
};
|
||||||
|
std::array<double, 8> kPolyC = {
|
||||||
|
7.74545014278341407640e-4, 2.27238449892691845833e-2,
|
||||||
|
2.41780725177450611770e-1, 1.27045825245236838258e0,
|
||||||
|
3.64784832476320460504e0, 5.76949722146069140550e0,
|
||||||
|
4.63033784615654529590e0, 1.42343711074968357734e0,
|
||||||
|
};
|
||||||
|
std::array<double, 8> kPolyD = {
|
||||||
|
1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
|
||||||
|
2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
|
||||||
|
9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
|
||||||
|
2.9036514445419946173133295e0, 1.4142135623730950488016887e0,
|
||||||
|
};
|
||||||
|
std::array<double, 8> kPolyE = {
|
||||||
|
2.01033439929228813265e-7, 2.71155556874348757815e-5,
|
||||||
|
1.24266094738807843860e-3, 2.65321895265761230930e-2,
|
||||||
|
2.96560571828504891230e-1, 1.78482653991729133580e0,
|
||||||
|
5.46378491116411436990e0, 6.65790464350110377720e0,
|
||||||
|
};
|
||||||
|
std::array<double, 8> kPolyF = {
|
||||||
|
2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
|
||||||
|
2.611088405080593625138020e-5, 1.112800997078859844711555e-3,
|
||||||
|
2.103693768272068968719679e-2, 1.936480946950659106176712e-1,
|
||||||
|
8.482908416595164588112026e-1, 1.414213562373095048801689e0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (std::abs(x) > 1 || std::isnan(x)) {
|
||||||
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
if (std::abs(x) == 1) {
|
||||||
|
return std::copysign(std::numeric_limits<float>::infinity(), x);
|
||||||
|
}
|
||||||
|
|
||||||
|
float unsigned_result = [&] {
|
||||||
|
float y = std::abs(x);
|
||||||
|
if (y <= 0.85) {
|
||||||
|
double r = 0.180625 - 0.25 * y * y;
|
||||||
|
return (y * EvaluatePolynomial(r, kPolyA)) /
|
||||||
|
EvaluatePolynomial(r, kPolyB);
|
||||||
|
} else {
|
||||||
|
double r = std::sqrt(std::log(2.0) - std::log1p(-y));
|
||||||
|
if (r <= 5.0) {
|
||||||
|
r -= 1.6;
|
||||||
|
return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
|
||||||
|
} else {
|
||||||
|
r -= 5;
|
||||||
|
return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
return std::copysign(unsigned_result, x);
|
||||||
|
}
|
||||||
|
|
||||||
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
|
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
|
||||||
// guaranteed that we're printing the full number.
|
// guaranteed that we're printing the full number.
|
||||||
//
|
//
|
||||||
@ -397,6 +473,7 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
|
|||||||
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
|
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
|
||||||
|
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
|
||||||
XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
|
XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
|
||||||
// Our implementation gets within 0.0001 rel error except for ~20 denormal
|
// Our implementation gets within 0.0001 rel error except for ~20 denormal
|
||||||
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
|
// inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
|
||||||
|
Loading…
Reference in New Issue
Block a user