From 151396d26d249110bcb36deeb954687223ca7a52 Mon Sep 17 00:00:00 2001 From: Peter Hawkins <phawkins@google.com> Date: Thu, 15 Oct 2020 11:04:30 -0700 Subject: [PATCH] [XLA] Switch implementation of erf to use the same rational polynomial approximation as Eigen. PiperOrigin-RevId: 337344225 Change-Id: I881171616bf5e9cf2ed3711e06fb28a2724d3238 --- tensorflow/compiler/xla/client/lib/math.cc | 30 +++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 410c86732d6..76cc6f0159b 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) { // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. // // This follows Cephes's f32 implementation of erf. -static XlaOp ErfImpl32(XlaOp x) { +static XlaOp ErfImpl32Cephes(XlaOp x) { // Coefficients for by erf(f32), from Cephes. // // erf(x) = x P(x^2), 0 < x < 1 @@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) { // (not surprising!), so upcast to f32 in this case. return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32(x)); + ScalarLike(x, 1) - ErfImpl32Cephes(x)); }); }); } +// Compute a polynomial approximation of the error function. +// This is the same approximation used by Eigen. +static XlaOp ErfImpl32(XlaOp x) { + static const std::array<float, 7> kAlpha{ + -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, + -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f, + -1.60960333262415e-02f, + }; + + static const std::array<float, 5> kBeta{ + -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f, + -7.37332916720468e-03f, -1.42647390514189e-02f, + }; + + x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f)); + auto x2 = x * x; + return x * EvaluatePolynomial<float>(x2, kAlpha) / + EvaluatePolynomial<float>(x2, kBeta); +} + XlaOp Erf(XlaOp x) { auto& b = *x.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { @@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) { } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x), - ScalarLike(x, 1) - ErfcImpl32(x)); - }); + return DoWithUpcastToF32(x, {BF16, F16}, + [](XlaOp x) { return ErfImpl32(x); }); }); }