From fd2d8bc50e9b3143544819bf505326e4ed6db2a5 Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Sun, 5 May 2019 15:30:55 -0700
Subject: [PATCH] [XLA] Fix numerical stability of asinh implementation.

PiperOrigin-RevId: 246745284
---
 tensorflow/compiler/xla/client/lib/math.cc    | 42 ++++++++++++++++++-
 .../compiler/xla/tests/exhaustive_op_test.cc  | 11 ++++-
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index e151477b5c3..f0b4057e039 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -570,7 +570,47 @@ XlaOp Acosh(XlaOp x) {
 }
 
 // asinh(x) = log(x + sqrt(x^2 + 1))
-XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); }
+//
+// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
+// as 2*x and return log(2) + log(x).
+//
+// If x is negative, the above would give us some trouble, because we'd need to
+// approximate x + sqrt(x^2 + 1) as -abs(x) + abs(x) = 0.  But we're saved
+// by the fact that asinh(-x) = -asinh(x).
+XlaOp Asinh(XlaOp x) {
+  XlaBuilder* b = x.builder();
+  auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
+    auto one = ScalarLike(x, 1);
+
+    // Let a = abs(x).  Compute
+    //
+    //   y = log(a + sqrt(a*a + 1))  if a < sqrt_max_value, or
+    //   y = log(a) + log(2)         otherwise
+    //
+    // and then return
+    //
+    //   y * sign(x).
+    //
+    // TODO(jlebar): For now, we ignore the question of overflow if x is a
+    // complex type, because we don't yet have exhaustive tests for complex trig
+    // functions.
+    if (primitive_util::IsComplexType(shape.element_type())) {
+      return Log(x + Sqrt(x * x + one));
+    }
+    auto a = Abs(x);
+    auto naive_result = Log(a + Sqrt(a * a + one));
+    auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
+    auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
+    return Sign(x) *
+           Select(Ge(a, sqrt_max_value), overflow_result, naive_result);
+  };
+  // These upcasts are not strictly necessary on all platforms to get within our
+  // error tolerances, so we could relax this if it ever mattered.
+  return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
+    return b->ReportErrorOrReturn(do_it(x));
+  });
+}
 
 // atanh(x) = 0.5 * log((1 + x) / (1 - x))
 XlaOp Atanh(XlaOp x) {
diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc
index bfc3b0743ce..9f0d6053da4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc
@@ -215,7 +215,7 @@ class ExhaustiveOpTest
         RunImpl(enqueue_op, evaluate_op);
         break;
       case BF16:
-        SetDefaultErrSpec(0.001, 0.02);
+        SetDefaultErrSpec(0.002, 0.02);
         RunImpl(enqueue_op, evaluate_op);
         break;
       default:
@@ -563,10 +563,17 @@ XLA_TEST_P(ExhaustiveOpTest, Acosh) {
   }
   Run(Acosh, std::acosh);
 }
+XLA_TEST_P(ExhaustiveOpTest, Asinh) {
+  // Error inherited from Log, which our implementation of Asinh uses.
+  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
+    abs_err_ = 0.001;
+    rel_err_ = 0.001;
+  }
+  Run(Asinh, std::asinh);
+}
 
 // TODO(jlebar): Enable these.
 // XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
-// XLA_TEST_P(ExhaustiveOpTest, Asinh) { Run(Asinh, std::asinh); }
 // XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
 // XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
 // XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
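
Illustrative note (not part of the patch): the sketch below is a minimal standalone C++ demonstration of the instability the patch addresses, written against <cmath> rather than the XLA builder API. The NaiveAsinh/StableAsinh names and the test inputs are invented for illustration only; the stable variant mirrors the patch's approach (work on a = abs(x), switch to log(2) + log(a) once a*a would overflow, restore the sign at the end) but is not the patch's code.

    #include <cmath>
    #include <cstdio>
    #include <limits>

    // Naive formula: x*x overflows for large |x|, and for large-magnitude
    // negative x the sum x + sqrt(x*x + 1) cancels catastrophically.
    float NaiveAsinh(float x) { return std::log(x + std::sqrt(x * x + 1.0f)); }

    // Stable variant: compute on a = |x|, use log(2) + log(a) when a*a would
    // overflow, and restore the sign at the end (asinh(-x) = -asinh(x)).
    float StableAsinh(float x) {
      float a = std::fabs(x);
      float sqrt_max = std::sqrt(std::numeric_limits<float>::max());
      float y = a >= sqrt_max ? std::log(2.0f) + std::log(a)
                              : std::log(a + std::sqrt(a * a + 1.0f));
      return std::copysign(y, x);
    }

    int main() {
      // 1e20f squared overflows float: naive gives inf, stable is ~46.74,
      // matching std::asinh.
      std::printf("%g %g %g\n", NaiveAsinh(1e20f), StableAsinh(1e20f),
                  std::asinh(1e20f));
      // Large-magnitude negative input: cancellation makes the naive result
      // wildly wrong (it can even hit log(0) = -inf); the stable form is fine.
      std::printf("%g %g %g\n", NaiveAsinh(-1e8f), StableAsinh(-1e8f),
                  std::asinh(-1e8f));
    }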