[XLA] Fix numeric stability of acosh implementation.

PiperOrigin-RevId: 246745091
This commit is contained in:
Justin Lebar 2019-05-05 15:24:55 -07:00 committed by TensorFlower Gardener
parent f98c967062
commit 46ccc4e6dd
2 changed files with 57 additions and 5 deletions

View File

@ -532,10 +532,41 @@ XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
// Hyperbolic trigonometric functions.
// acosh(x) = log(x + sqrt(x^2 - 1))
// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
// = log(x + sqrt((x+1)*(x-1)))
// acosh(x) = nan if x < -1
//
// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
// log(2*x) = log(2) + log(x). (Note this works because negative x never
// overflows; x < -1 simply yields nan. This is quite different than asinh!)
XlaOp Acosh(XlaOp x) {
return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0))));
XlaBuilder* b = x.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto one = ScalarLike(x, 1);
auto neg_one = ScalarLike(x, -1);
auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
// return
//
// nan if x < -1
// log(x) + log(2) if x >= sqrt_max_value
// log(x + sqrt((x+1)*(x-1))) otherwise
//
// TODO(jlebar): For now, we ignore the question of overflow if x is a
// complex type, because we don't yet have exhaustive tests for complex trig
// functions.
auto naive_result = Log(x + Sqrt((x + one) * (x - one)));
if (primitive_util::IsComplexType(shape.element_type())) {
return naive_result;
}
auto overflow_result = Log(x) + Log(ScalarLike(x, 2));
auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
return Select(Lt(x, neg_one), nan,
Select(Ge(x, sqrt_max_value), overflow_result, naive_result));
});
}
// asinh(x) = log(x + sqrt(x^2 + 1))

View File

@ -215,7 +215,7 @@ class ExhaustiveOpTest
RunImpl<half, uint16>(enqueue_op, evaluate_op);
break;
case BF16:
SetDefaultErrSpec(0.001, 0.01);
SetDefaultErrSpec(0.001, 0.02);
RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
break;
default:
@ -553,9 +553,30 @@ XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
Run(Sqrt, std::sqrt);
}
// TODO(jlebar): Add remaining trig functions. Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
XLA_TEST_P(ExhaustiveOpTest, Acosh) {
// Error inherited from Log, which our implementation of Acosh uses.
if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
abs_err_ = 0.001;
rel_err_ = 0.001;
}
Run(Acosh, std::acosh);
}
// TODO(jlebar): Enable these.
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
// XLA_TEST_P(ExhaustiveOpTest, Asinh) { Run(Asinh, std::asinh); }
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
// XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); }
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
// XLA_TEST_P(ExhaustiveOpTest, Sinh) { Run(Sinh, std::sinh); }
// XLA_TEST_P(ExhaustiveOpTest, Sin) { Run(Sin, std::sin); }
// XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }
// XLA_TEST_P(ExhaustiveOpTest, Tan) { Run(Tan, std::tan); }
// XLA_TEST_P(ExhaustiveOpTest, Atan2) { Run(Atan2, std::atan2); }
XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }