Improve numerics for Sinh, Asinh and Atanh in XLA.
- Rewrite Sinh,Asinh for smaller parameter regions so they return non-zero values for small x. - Use Log1p in Atanh to retrieve non-zero values for small x. PiperOrigin-RevId: 295820343 Change-Id: Ia330e201c2fac8497f3b021290550715cf067a81
This commit is contained in:
parent
8264abb627
commit
e47e4bfb9e
tensorflow/compiler/xla/client/lib
@ -1008,12 +1008,23 @@ XlaOp Asinh(XlaOp x) {
|
||||
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
return Log(x + Sqrt(x * x + one));
|
||||
}
|
||||
// For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
|
||||
// arithmetic. However, we would like to retain the low order term of this,
|
||||
// which is around 0.5 * x**2 using a binomial expansion.
|
||||
// Let z = sqrt(a**2 + 1)
|
||||
// log(a + sqrt(a**2 + 1)) =
|
||||
// log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
|
||||
// log((a + a**2 + 1 + a * z + z) / (1 + z)) =
|
||||
// log(1 + a + a**2 / (1 + z)) =
|
||||
// log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
|
||||
// This rewrite retains the lower order term.
|
||||
auto a = Abs(x);
|
||||
auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
|
||||
auto naive_result = Log(a + Sqrt(a * a + one));
|
||||
auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
|
||||
auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
|
||||
return Sign(x) *
|
||||
Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
|
||||
return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
|
||||
Select(Le(a, one), small_result, naive_result));
|
||||
};
|
||||
// These upcasts are not strictly necessary on all platforms to get within our
|
||||
// error tolerances, so we could relax this if it ever mattered.
|
||||
@ -1028,9 +1039,7 @@ XlaOp Atanh(XlaOp x) {
|
||||
XlaBuilder* b = x.builder();
|
||||
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
auto naive_result =
|
||||
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
|
||||
ScalarLike(x, 0.5);
|
||||
auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
|
||||
|
||||
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
|
||||
// because we don't yet have exhaustive tests for complex trig functions.
|
||||
@ -1074,9 +1083,35 @@ XlaOp Cosh(XlaOp x) {
|
||||
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
|
||||
// we deem this acceptable.
|
||||
XlaOp Sinh(XlaOp x) {
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
|
||||
XlaBuilder* b = x.builder();
|
||||
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
auto one_half = ScalarLike(x, 0.5);
|
||||
auto log_one_half = Log(ScalarLike(x, 0.5));
|
||||
return Exp(x + log_one_half) - Exp(-x + log_one_half);
|
||||
auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);
|
||||
|
||||
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
return large_sinh_result;
|
||||
}
|
||||
|
||||
// Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large
|
||||
// values of x.
|
||||
|
||||
// For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
|
||||
// 0.
|
||||
// Rewrite this to avoid that. We use expm1(x) because that preserves the
|
||||
// first order term of the taylor series of e^x.
|
||||
// (e^(x) - e^(-x)) / 2. =
|
||||
// (e^(x) - 1 + 1 - e^(-x)) / 2.
|
||||
// (expm1(x) + (e^(x) - 1) / e^x) / 2.
|
||||
// (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
|
||||
auto expm1 = Expm1(x);
|
||||
auto one = ScalarLike(x, 1.);
|
||||
auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
|
||||
return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
|
||||
};
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
|
||||
return b->ReportErrorOrReturn(do_it(x));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -298,6 +298,30 @@ XLA_TEST_F(MathTest, SqrtSixValues) {
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, SinhSmallValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
|
||||
Sinh(x);
|
||||
std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, AsinhSmallValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11});
|
||||
Asinh(x);
|
||||
std::vector<float> expected = {1e-3, 1e-5, 1e-7, 1e-9, 1e-11};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, AtanhSmallValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {1e-8, 1e-9, 1e-10, 1e-11});
|
||||
Atanh(x);
|
||||
std::vector<float> expected = {1e-8, 1e-9, 1e-10, 1e-11};
|
||||
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, Lgamma) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5,
|
||||
|
Loading…
Reference in New Issue
Block a user