Improve numerics of Log1p for XLA:CPU.

PiperOrigin-RevId: 316714497
Change-Id: I20bf0148a850451077e435190034418e7eb9b38c
This commit is contained in:
Srinivas Vasudevan 2020-06-16 10:55:37 -07:00 committed by TensorFlower Gardener
parent 08e445e37f
commit e6e8d48f8b
4 changed files with 132 additions and 8 deletions

View File

@ -61,6 +61,81 @@ def implicit_reparameterization_grad(a, x):
return -gen_math_ops.igamma_grad_a(a, x) / prob
@def_function.function(experimental_compile=True)
def _log1p(x):
  """Applies math_ops.log1p to `x` inside a compiled tf function.

  The function is built with experimental_compile=True; the result of
  math_ops.log1p is returned unchanged.
  """
  log1p_of_x = math_ops.log1p(x)
  return log1p_of_x
# Checks math_ops.log1p, compiled through `_log1p`, against np.log1p on randomly
# sampled inputs, with one test method per magnitude band of x.
class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
# When the `vary_seed` flag is set, reseeds NumPy's global RNG from 64 bytes of
# OS entropy so each run samples different inputs; then runs the base setUp.
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
# Python 2 has no int.from_bytes, so the bytes are converted via a hex string.
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
# Reduce the 512-bit integer so it fits the range np.random.seed accepts.
np.random.seed(answer % (2**32 - 1))
super(Log1pTest, self).setUp()
# Returns the (rtol, atol) pair to use for `dtype` on the current device: the
# caller's values unchanged unless self.device is 'TPU', where they are replaced
# by (4e-4, 0.) for float32 and (1e-10, 0.) for any other dtype.
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 4e-4, 0.
return 1e-10, 0.
# Samples NUM_SAMPLES values x = exp(u), with u drawn by np.random.uniform
# between `low` and `high` (so `low`/`high` are natural-log exponents of |x|),
# negates them when `is_negative` is set, and asserts that `_log1p(x)` matches
# np.log1p(x) within the given tolerances (after any TPU adjustment).
def _test_range(self, low, high, dtype, rtol, atol, is_negative=False):
# Test values near zero.
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
x = np.exp(np.random.uniform(
low=low, high=high, size=[NUM_SAMPLES])).astype(dtype)
if is_negative:
x = -x
# NumPy's log1p on the same inputs serves as the reference result.
expected_values = np.log1p(x)
with self.session() as sess:
with self.test_scope():
actual = _log1p(x)
actual = sess.run(actual)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
# |x| between exp(-40) and exp(-20), both signs.
@parameterized.parameters((np.float32, 1e-7, 0.),
(np.float64, 1e-15, 0.))
def testSmallX(self, dtype, rtol, atol):
self._test_range(-40., -20., dtype, rtol, atol, is_negative=False)
self._test_range(-40., -20., dtype, rtol, atol, is_negative=True)
# |x| between exp(-20) and exp(-10), both signs.
@parameterized.parameters((np.float32, 1e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol):
self._test_range(-20., -10., dtype, rtol, atol, is_negative=False)
self._test_range(-20., -10., dtype, rtol, atol, is_negative=True)
# |x| between exp(-10) and exp(-5), both signs.
@parameterized.parameters((np.float32, 1e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol):
self._test_range(-10., -5., dtype, rtol, atol, is_negative=False)
self._test_range(-10., -5., dtype, rtol, atol, is_negative=True)
# |x| between exp(-5) and exp(-1), both signs.
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol):
self._test_range(-5., -1., dtype, rtol, atol, is_negative=False)
self._test_range(-5., -1., dtype, rtol, atol, is_negative=True)
# |x| between exp(-1) and exp(0) = 1, both signs. This band uses the loosest
# tolerances of all the cases in this class.
@parameterized.parameters((np.float32, 4e-7, 0.),
(np.float64, 3e-14, 0.))
def testXGreaterThanOneTenth(self, dtype, rtol, atol):
self._test_range(-1., 0., dtype, rtol, atol, is_negative=False)
self._test_range(-1., 0., dtype, rtol, atol, is_negative=True)
# x between exp(0) = 1 and exp(3), positive only: negating these samples would
# give x <= -1, where log1p has no finite value.
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 2e-15, 0.))
def testXGreaterThanOne(self, dtype, rtol, atol):
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):

View File

@ -292,13 +292,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
# Confirm that log1p will remain precise across a range of small values.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
dtype=dtype)).astype(dtype),
rtol=1e-4,
atol=1e-6)
np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
dtype=dtype),
expected=np.log1p(
np.array(
[[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
dtype=dtype)).astype(dtype),
rtol=1e-15 if dtype == np.float64 else 1e-4,
atol=1e-15 if dtype == np.float64 else 1e-4)
self._assertOpOutputMatchesExpected(
math_ops.rint,

View File

@ -1336,9 +1336,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
// The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
// When x is small (magnitude below the sqrt(2) - 1 threshold defined below),
// use a rational approximation. The approximation below is based on one from
// the Cephes Mathematical Library.
//
// sqrt(2) - 1.
const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
static const std::array<double, 7> kDenominatorCoeffs{
1.,
1.5062909083469192043167E1,
8.3047565967967209469434E1,
2.2176239823732856465394E2,
3.0909872225312059774938E2,
2.1642788614495947685003E2,
6.0118660497603843919306E1,
};
static const std::array<double, 7> kNumeratorCoeffs{
4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
6.5787325942061044846969E0, 2.9911919328553073277375E1,
6.0949667980987787057556E1, 5.7112963590585538103336E1,
2.0039553499201281259648E1,
};
auto x_squared = FMul(x, x);
TF_ASSIGN_OR_RETURN(auto denominator,
EvaluatePolynomial(type, x, kDenominatorCoeffs));
TF_ASSIGN_OR_RETURN(auto numerator,
EvaluatePolynomial(type, x, kNumeratorCoeffs));
auto for_small_x = FDiv(numerator, denominator);
for_small_x = FMul(FMul(x, x_squared), for_small_x);
for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
for_small_x = FAdd(x, for_small_x);
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
auto x_is_small = FCmpOLT(
@ -2699,4 +2730,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
}
}
// Builds the value of a polynomial in `x` with Horner's scheme.
//
// `coefficients` are consumed in iteration order, so the first entry belongs
// to the highest power of `x` and the last entry is the constant term. Each
// coefficient becomes a floating-point constant of `type`. With an empty span
// the result is the constant 0.0.
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
    llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
  llvm::Value* result = llvm::ConstantFP::get(type, 0.0);
  for (const double coefficient : coefficients) {
    // One Horner step: result = result * x + coefficient.
    llvm::Value* scaled = FMul(result, x);
    llvm::Value* term = llvm::ConstantFP::get(type, coefficient);
    result = FAdd(scaled, term);
  }
  return result;
}
} // namespace xla

View File

@ -258,6 +258,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
llvm::Value* a, llvm::Value* b,
llvm::Value* c, llvm::Value* d);
// Evaluates a polynomial in `x` using Horner's method. `coefficients` are
// ordered from the highest-degree term down to the constant term, and each is
// materialized as a floating-point constant of `type`.
StatusOr<llvm::Value*> EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
};
} // namespace xla