[XLA] Switch implementation of erf to use the same rational polynomial approximation as Eigen.

PiperOrigin-RevId: 337344225
Change-Id: I881171616bf5e9cf2ed3711e06fb28a2724d3238
This commit is contained in:
Peter Hawkins 2020-10-15 11:04:30 -07:00 committed by TensorFlower Gardener
parent 4a1f962ca2
commit 151396d26d

View File

@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) {
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
//
// This follows Cephes's f32 implementation of erf.
static XlaOp ErfImpl32(XlaOp x) {
static XlaOp ErfImpl32Cephes(XlaOp x) {
// Coefficients for by erf(f32), from Cephes.
//
// erf(x) = x P(x^2), 0 < x < 1
@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) {
// (not surprising!), so upcast to f32 in this case.
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
ScalarLike(x, 1) - ErfImpl32(x));
ScalarLike(x, 1) - ErfImpl32Cephes(x));
});
});
}
// Compute a polynomial approximation of the error function.
// This is the same approximation used by Eigen.
static XlaOp ErfImpl32(XlaOp x) {
static const std::array<float, 7> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
-1.60960333262415e-02f,
};
static const std::array<float, 5> kBeta{
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
-7.37332916720468e-03f, -1.42647390514189e-02f,
};
x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
auto x2 = x * x;
return x * EvaluatePolynomial<float>(x2, kAlpha) /
EvaluatePolynomial<float>(x2, kBeta);
}
XlaOp Erf(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) {
}
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
// (not surprising!), so upcast to f32 in this case.
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x),
ScalarLike(x, 1) - ErfcImpl32(x));
});
return DoWithUpcastToF32(x, {BF16, F16},
[](XlaOp x) { return ErfImpl32(x); });
});
}