Improve numerics of Log1p for XLA:CPU.
PiperOrigin-RevId: 316714497 Change-Id: I20bf0148a850451077e435190034418e7eb9b38c
This commit is contained in:
parent
08e445e37f
commit
e6e8d48f8b
@ -61,6 +61,81 @@ def implicit_reparameterization_grad(a, x):
|
||||
return -gen_math_ops.igamma_grad_a(a, x) / prob
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def _log1p(x):
|
||||
return math_ops.log1p(x)
|
||||
|
||||
|
||||
class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if flags.FLAGS.vary_seed:
|
||||
entropy = os.urandom(64)
|
||||
if six.PY2:
|
||||
answer = int(entropy.encode('hex'), 16)
|
||||
else:
|
||||
answer = int.from_bytes(entropy, 'big')
|
||||
np.random.seed(answer % (2**32 - 1))
|
||||
super(Log1pTest, self).setUp()
|
||||
|
||||
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
|
||||
if self.device not in ['TPU']:
|
||||
return rtol, atol
|
||||
|
||||
if dtype == np.float32:
|
||||
return 4e-4, 0.
|
||||
return 1e-10, 0.
|
||||
|
||||
def _test_range(self, low, high, dtype, rtol, atol, is_negative=False):
|
||||
# Test values near zero.
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
x = np.exp(np.random.uniform(
|
||||
low=low, high=high, size=[NUM_SAMPLES])).astype(dtype)
|
||||
if is_negative:
|
||||
x = -x
|
||||
expected_values = np.log1p(x)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = _log1p(x)
|
||||
actual = sess.run(actual)
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testSmallX(self, dtype, rtol, atol):
|
||||
self._test_range(-40., -20., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-40., -20., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-20., -10., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-20., -10., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-10., -5., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-10., -5., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-5., -1., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-5., -1., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 4e-7, 0.),
|
||||
(np.float64, 3e-14, 0.))
|
||||
def testXGreaterThanOneTenth(self, dtype, rtol, atol):
|
||||
self._test_range(-1., 0., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-1., 0., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 2e-15, 0.))
|
||||
def testXGreaterThanOne(self, dtype, rtol, atol):
|
||||
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
|
||||
|
||||
|
||||
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -292,13 +292,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
|
||||
|
||||
# Confirm that log1p will remain precise across a range of small values.
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.log1p,
|
||||
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
|
||||
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
|
||||
dtype=dtype)).astype(dtype),
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
|
||||
dtype=dtype),
|
||||
expected=np.log1p(
|
||||
np.array(
|
||||
[[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
|
||||
dtype=dtype)).astype(dtype),
|
||||
rtol=1e-15 if dtype == np.float64 else 1e-4,
|
||||
atol=1e-15 if dtype == np.float64 else 1e-4)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.rint,
|
||||
|
@ -1336,9 +1336,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
|
||||
// When x is large, the naive evaluation of ln(x + 1) is more
|
||||
// accurate than the Taylor series.
|
||||
TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
|
||||
// The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + ….
|
||||
auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
|
||||
const auto kAntilogarithmIsSmallThreshold = 1e-4;
|
||||
// When x is small, (defined to be less than sqrt(2) / 2), use a rational
|
||||
// approximation. The approximation below is based on one from the Cephes
|
||||
// Mathematical Library.
|
||||
//
|
||||
// sqrt(2) - 1.
|
||||
const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
|
||||
|
||||
static const std::array<double, 7> kDenominatorCoeffs{
|
||||
1.,
|
||||
1.5062909083469192043167E1,
|
||||
8.3047565967967209469434E1,
|
||||
2.2176239823732856465394E2,
|
||||
3.0909872225312059774938E2,
|
||||
2.1642788614495947685003E2,
|
||||
6.0118660497603843919306E1,
|
||||
};
|
||||
|
||||
static const std::array<double, 7> kNumeratorCoeffs{
|
||||
4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
|
||||
6.5787325942061044846969E0, 2.9911919328553073277375E1,
|
||||
6.0949667980987787057556E1, 5.7112963590585538103336E1,
|
||||
2.0039553499201281259648E1,
|
||||
};
|
||||
|
||||
auto x_squared = FMul(x, x);
|
||||
TF_ASSIGN_OR_RETURN(auto denominator,
|
||||
EvaluatePolynomial(type, x, kDenominatorCoeffs));
|
||||
TF_ASSIGN_OR_RETURN(auto numerator,
|
||||
EvaluatePolynomial(type, x, kNumeratorCoeffs));
|
||||
auto for_small_x = FDiv(numerator, denominator);
|
||||
for_small_x = FMul(FMul(x, x_squared), for_small_x);
|
||||
for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
|
||||
for_small_x = FAdd(x, for_small_x);
|
||||
|
||||
auto abs_x =
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
|
||||
auto x_is_small = FCmpOLT(
|
||||
@ -2699,4 +2730,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
|
||||
llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
|
||||
for (const double c : coefficients) {
|
||||
poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
|
||||
}
|
||||
return poly;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -258,6 +258,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
|
||||
llvm::Value* a, llvm::Value* b,
|
||||
llvm::Value* c, llvm::Value* d);
|
||||
|
||||
// Evaluates a polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user