Improve numerics for Sinh, Asinh and Atanh in XLA.

- Rewrite Sinh and Asinh for small-argument regions so they return non-zero values for small x.
- Use Log1p in Atanh to retrieve non-zero values for small x.

PiperOrigin-RevId: 295820343
Change-Id: Ia330e201c2fac8497f3b021290550715cf067a81
This commit is contained in:
Srinivas Vasudevan 2020-02-18 14:17:29 -08:00 committed by TensorFlower Gardener
parent 8264abb627
commit e47e4bfb9e
2 changed files with 66 additions and 7 deletions
tensorflow/compiler/xla/client/lib

View File

@ -1008,12 +1008,23 @@ XlaOp Asinh(XlaOp x) {
if (primitive_util::IsComplexType(shape.element_type())) {
return Log(x + Sqrt(x * x + one));
}
// For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
// arithmetic. However, we would like to retain the low order term of this,
// which is around 0.5 * x**2 using a binomial expansion.
// Let z = sqrt(a**2 + 1)
// log(a + sqrt(a**2 + 1)) =
// log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
// log((a + a**2 + 1 + a * z + z) / (1 + z)) =
// log(1 + a + a**2 / (1 + z)) =
// log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
// This rewrite retains the lower order term.
auto a = Abs(x);
auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
auto naive_result = Log(a + Sqrt(a * a + one));
auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
return Sign(x) *
Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
Select(Le(a, one), small_result, naive_result));
};
// These upcasts are not strictly necessary on all platforms to get within our
// error tolerances, so we could relax this if it ever mattered.
@ -1028,9 +1039,7 @@ XlaOp Atanh(XlaOp x) {
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto naive_result =
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
ScalarLike(x, 0.5);
auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
// because we don't yet have exhaustive tests for complex trig functions.
@ -1074,9 +1083,35 @@ XlaOp Cosh(XlaOp x) {
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Sinh(XlaOp x) {
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto one_half = ScalarLike(x, 0.5);
// log(1/2), added to the exponent below so that the division by two happens
// inside Exp: e^(x + log(1/2)) == e^x / 2.
auto log_one_half = Log(ScalarLike(x, 0.5));
return Exp(x + log_one_half) - Exp(-x + log_one_half);
// sinh(x) = e^x / 2 - e^(-x) / 2, with each half computed as
// Exp(+-x + log(1/2)).
auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);
// Complex inputs always use the exponential form; the small-|x| rewrite
// below is only applied to non-complex element types.
if (primitive_util::IsComplexType(shape.element_type())) {
return large_sinh_result;
}
// The exponential form above folds the factor 1/2 into the exponent,
// e^x / 2 = e^(x + log(1/2)). This avoids overflow for large values of x.
// For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
// 0.
// Rewrite this to avoid that. We use expm1(x) because that preserves the
// first order term of the taylor series of e^x.
//   (e^(x) - e^(-x)) / 2
//   = (e^(x) - 1 + 1 - e^(-x)) / 2
//   = (expm1(x) + (e^(x) - 1) / e^x) / 2
//   = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2
auto expm1 = Expm1(x);
auto one = ScalarLike(x, 1.);
auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
// Element-wise choice: the expm1-based form where |x| < 1, the exponential
// form everywhere else.
return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
};
// NOTE(review): DoWithUpcastToF32 is defined outside this view; presumably it
// evaluates the lambda in F32 when x is BF16 or F16 -- confirm.
return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
return b->ReportErrorOrReturn(do_it(x));
});
}

View File

@ -298,6 +298,30 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
// Evaluates Sinh on a handful of tiny positive floats. The expected output for
// each element is the input value itself, so one vector serves as both.
// NOTE(review): error_spec_ is defined outside this view; if it permits an
// absolute error larger than these inputs, an all-zero result would also pass
// -- confirm the tolerance makes this test meaningful.
XLA_TEST_F(MathTest, SinhSmallValues) {
  XlaBuilder builder(TestName());
  const std::vector<float> small_values = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
  Sinh(ConstantR1<float>(&builder, small_values));
  ComputeAndCompareR1<float>(&builder, small_values, {}, error_spec_);
}
// Evaluates Asinh on a handful of tiny positive floats. The expected output
// for each element is the input value itself, so one vector serves as both.
// NOTE(review): error_spec_ is defined outside this view; if it permits an
// absolute error larger than these inputs, an all-zero result would also pass
// -- confirm the tolerance makes this test meaningful.
XLA_TEST_F(MathTest, AsinhSmallValues) {
  XlaBuilder builder(TestName());
  const std::vector<float> small_values = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
  Asinh(ConstantR1<float>(&builder, small_values));
  ComputeAndCompareR1<float>(&builder, small_values, {}, error_spec_);
}
// Evaluates Atanh on a handful of tiny positive floats. The expected output
// for each element is the input value itself, so one vector serves as both.
// NOTE(review): error_spec_ is defined outside this view; if it permits an
// absolute error larger than these inputs, an all-zero result would also pass
// -- confirm the tolerance makes this test meaningful.
XLA_TEST_F(MathTest, AtanhSmallValues) {
  XlaBuilder builder(TestName());
  const std::vector<float> small_values = {1e-8, 1e-9, 1e-10, 1e-11};
  Atanh(ConstantR1<float>(&builder, small_values));
  ComputeAndCompareR1<float>(&builder, small_values, {}, error_spec_);
}
XLA_TEST_F(MathTest, Lgamma) {
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5,