[XLA] Fix numerical stability of asinh implementation.
PiperOrigin-RevId: 246745284
This commit is contained in:
parent
46ccc4e6dd
commit
fd2d8bc50e
@ -570,7 +570,47 @@ XlaOp Acosh(XlaOp x) {
|
||||
}
|
||||
|
||||
// asinh(x) = log(x + sqrt(x^2 + 1))
|
||||
XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); }
|
||||
//
|
||||
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
|
||||
// as 2*x and return log(2) + log(x).
|
||||
//
|
||||
// If x is negative, the above would give us some trouble, because we'd need to
|
||||
// approximate x + sqrt(sqrt(x^2 + 1) - abs(x). But we're saved
|
||||
// by the fact that asinh(-x) = -asinh(x).
|
||||
XlaOp Asinh(XlaOp x) {
|
||||
XlaBuilder* b = x.builder();
|
||||
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
auto one = ScalarLike(x, 1);
|
||||
|
||||
// Let a = abs(x). Compute
|
||||
//
|
||||
// y = log(a + sqrt(a*a + 1)) if a < sqrt_max_value, or
|
||||
// y = log(a) + log(2) otherwise
|
||||
//
|
||||
// and then return
|
||||
//
|
||||
// y * sign(x).
|
||||
//
|
||||
// TODO(jlebar): For now, we ignore the question of overflow if x is a
|
||||
// complex type, because we don't yet have exhaustive tests for complex trig
|
||||
// functions.
|
||||
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
return Log(x + Sqrt(x * x + one));
|
||||
}
|
||||
auto a = Abs(x);
|
||||
auto naive_result = Log(a + Sqrt(a * a + one));
|
||||
auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
|
||||
auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
|
||||
return Sign(x) *
|
||||
Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
|
||||
};
|
||||
// These upcasts are not strictly necessary on all platforms to get within our
|
||||
// error tolerances, so we could relax this if it ever mattered.
|
||||
return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
|
||||
return b->ReportErrorOrReturn(do_it(x));
|
||||
});
|
||||
}
|
||||
|
||||
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
|
||||
XlaOp Atanh(XlaOp x) {
|
||||
|
@ -215,7 +215,7 @@ class ExhaustiveOpTest
|
||||
RunImpl<half, uint16>(enqueue_op, evaluate_op);
|
||||
break;
|
||||
case BF16:
|
||||
SetDefaultErrSpec(0.001, 0.02);
|
||||
SetDefaultErrSpec(0.002, 0.02);
|
||||
RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
|
||||
break;
|
||||
default:
|
||||
@ -563,10 +563,17 @@ XLA_TEST_P(ExhaustiveOpTest, Acosh) {
|
||||
}
|
||||
Run(Acosh, std::acosh);
|
||||
}
|
||||
XLA_TEST_P(ExhaustiveOpTest, Asinh) {
|
||||
// Error inherited from Log, which our implementation of Asinh uses.
|
||||
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
|
||||
abs_err_ = 0.001;
|
||||
rel_err_ = 0.001;
|
||||
}
|
||||
Run(Asinh, std::asinh);
|
||||
}
|
||||
|
||||
// TODO(jlebar): Enable these.
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Asinh) { Run(Asinh, std::asinh); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
||||
|
Loading…
Reference in New Issue
Block a user