[XLA] Switch implementation of erf to use the same rational polynomial approximation as Eigen.
PiperOrigin-RevId: 337344225 Change-Id: I881171616bf5e9cf2ed3711e06fb28a2724d3238
This commit is contained in:
parent
4a1f962ca2
commit
151396d26d
@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) {
|
|||||||
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
|
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
|
||||||
//
|
//
|
||||||
// This follows Cephes's f32 implementation of erf.
|
// This follows Cephes's f32 implementation of erf.
|
||||||
static XlaOp ErfImpl32(XlaOp x) {
|
static XlaOp ErfImpl32Cephes(XlaOp x) {
|
||||||
// Coefficients for by erf(f32), from Cephes.
|
// Coefficients for by erf(f32), from Cephes.
|
||||||
//
|
//
|
||||||
// erf(x) = x P(x^2), 0 < x < 1
|
// erf(x) = x P(x^2), 0 < x < 1
|
||||||
@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) {
|
|||||||
// (not surprising!), so upcast to f32 in this case.
|
// (not surprising!), so upcast to f32 in this case.
|
||||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||||
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
|
return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
|
||||||
ScalarLike(x, 1) - ErfImpl32(x));
|
ScalarLike(x, 1) - ErfImpl32Cephes(x));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute a polynomial approximation of the error function.
|
||||||
|
// This is the same approximation used by Eigen.
|
||||||
|
static XlaOp ErfImpl32(XlaOp x) {
|
||||||
|
static const std::array<float, 7> kAlpha{
|
||||||
|
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||||
|
-5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
|
||||||
|
-1.60960333262415e-02f,
|
||||||
|
};
|
||||||
|
|
||||||
|
static const std::array<float, 5> kBeta{
|
||||||
|
-1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
|
||||||
|
-7.37332916720468e-03f, -1.42647390514189e-02f,
|
||||||
|
};
|
||||||
|
|
||||||
|
x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
|
||||||
|
auto x2 = x * x;
|
||||||
|
return x * EvaluatePolynomial<float>(x2, kAlpha) /
|
||||||
|
EvaluatePolynomial<float>(x2, kBeta);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Erf(XlaOp x) {
|
XlaOp Erf(XlaOp x) {
|
||||||
auto& b = *x.builder();
|
auto& b = *x.builder();
|
||||||
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) {
|
|||||||
}
|
}
|
||||||
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
|
// Erf(c)Impl don't have enough precision when run with bf16 intermediates
|
||||||
// (not surprising!), so upcast to f32 in this case.
|
// (not surprising!), so upcast to f32 in this case.
|
||||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
return DoWithUpcastToF32(x, {BF16, F16},
|
||||||
return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x),
|
[](XlaOp x) { return ErfImpl32(x); });
|
||||||
ScalarLike(x, 1) - ErfcImpl32(x));
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user