[XLA] Switch implementation of erf to use the same rational polynomial approximation as Eigen.
PiperOrigin-RevId: 337344225 Change-Id: I881171616bf5e9cf2ed3711e06fb28a2724d3238
This commit is contained in:
parent
4a1f962ca2
commit
151396d26d
@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) {
|
||||
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
|
||||
//
|
||||
// This follows Cephes's f32 implementation of erf.
|
||||
static XlaOp ErfImpl32(XlaOp x) {
|
||||
static XlaOp ErfImpl32Cephes(XlaOp x) {
|
||||
// Coefficients for by erf(f32), from Cephes.
|
||||
//
|
||||
// erf(x) = x P(x^2), 0 < x < 1
|
||||
@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) {
|
||||
// (not surprising!), so upcast to f32 in this case.
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
|
||||
ScalarLike(x, 1) - ErfImpl32(x));
|
||||
ScalarLike(x, 1) - ErfImpl32Cephes(x));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Compute a polynomial approximation of the error function.
|
||||
// This is the same approximation used by Eigen.
|
||||
static XlaOp ErfImpl32(XlaOp x) {
|
||||
static const std::array<float, 7> kAlpha{
|
||||
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
|
||||
-1.60960333262415e-02f,
|
||||
};
|
||||
|
||||
static const std::array<float, 5> kBeta{
|
||||
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
|
||||
-7.37332916720468e-03f, -1.42647390514189e-02f,
|
||||
};
|
||||
|
||||
x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
|
||||
auto x2 = x * x;
|
||||
return x * EvaluatePolynomial<float>(x2, kAlpha) /
|
||||
EvaluatePolynomial<float>(x2, kBeta);
|
||||
}
|
||||
|
||||
XlaOp Erf(XlaOp x) {
|
||||
auto& b = *x.builder();
|
||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) {
|
||||
}
|
||||
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
|
||||
// (not surprising!), so upcast to f32 in this case.
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||
return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x),
|
||||
ScalarLike(x, 1) - ErfcImpl32(x));
|
||||
});
|
||||
return DoWithUpcastToF32(x, {BF16, F16},
|
||||
[](XlaOp x) { return ErfImpl32(x); });
|
||||
});
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user