From 151396d26d249110bcb36deeb954687223ca7a52 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Thu, 15 Oct 2020 11:04:30 -0700
Subject: [PATCH] [XLA] Switch implementation of erf to use the same rational
 polynomial approximation as Eigen.

PiperOrigin-RevId: 337344225
Change-Id: I881171616bf5e9cf2ed3711e06fb28a2724d3238
---
 tensorflow/compiler/xla/client/lib/math.cc | 30 +++++++++++++++++-----
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 410c86732d6..76cc6f0159b 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -203,7 +203,7 @@ static XlaOp ErfcImpl32(XlaOp x) {
 // Precondition: abs(x) <= 1.  Otherwise, use ErfcImpl.
 //
 // This follows Cephes's f32 implementation of erf.
-static XlaOp ErfImpl32(XlaOp x) {
+static XlaOp ErfImpl32Cephes(XlaOp x) {
   // Coefficients for by erf(f32), from Cephes.
   //
   // erf(x) = x P(x^2), 0 < x < 1
@@ -291,11 +291,31 @@ XlaOp Erfc(XlaOp x) {
     // (not surprising!), so upcast to f32 in this case.
     return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
       return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
-                    ScalarLike(x, 1) - ErfImpl32(x));
+                    ScalarLike(x, 1) - ErfImpl32Cephes(x));
     });
   });
 }
 
+// Compute a polynomial approximation of the error function.
+// This is the same approximation used by Eigen.
+static XlaOp ErfImpl32(XlaOp x) {
+  static const std::array<float, 7> kAlpha{
+      -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
+      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
+      -1.60960333262415e-02f,
+  };
+
+  static const std::array<float, 5> kBeta{
+      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
+      -7.37332916720468e-03f, -1.42647390514189e-02f,
+  };
+
+  x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
+  auto x2 = x * x;
+  return x * EvaluatePolynomial<float>(x2, kAlpha) /
+         EvaluatePolynomial<float>(x2, kBeta);
+}
+
 XlaOp Erf(XlaOp x) {
   auto& b = *x.builder();
   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -310,10 +330,8 @@ XlaOp Erf(XlaOp x) {
     }
     // Erf(c)Impl don't have enough precision when run with bf16 intermediates
     // (not surprising!), so upcast to f32 in this case.
-    return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
-      return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x),
-                    ScalarLike(x, 1) - ErfcImpl32(x));
-    });
+    return DoWithUpcastToF32(x, {BF16, F16},
+                             [](XlaOp x) { return ErfImpl32(x); });
   });
 }