From c00907e0804478812d9193feb2c3657c288295b7 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 24 Jan 2019 19:18:24 -0800 Subject: [PATCH] [XLA] Fix implementation of acos(-1). At least on fp16, it was returning 0, when it should have returned pi. PiperOrigin-RevId: 230834226 --- tensorflow/compiler/xla/client/lib/math.cc | 9 ++++++--- .../compiler/xla/client/lib/math_exhaustive_test.cc | 6 ++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 2426a433fcc..f7e29db690f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -367,10 +367,13 @@ XlaOp RoundToEven(XlaOp x) { // Trigonometric functions. -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// pi if x == -1 XlaOp Acos(XlaOp x) { - return ScalarLike(x, 2.0) * - Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc index 123af5e51c8..0408ed1ba23 100644 --- a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc @@ -137,10 +137,7 @@ XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) { ComputeAndCompareR1(&b, expected_result, {}, tc.error); } -// TODO(jlebar): The following tests are missing. -// -// - Fails on -1 (returns 0 instead of pi). -// Testcase{"acos", Acos, std::acos}.set_skip_infs(), +// TODO(b/123355973): The following tests are missing. // // - Many failures. // Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(), @@ -170,6 +167,7 @@ INSTANTIATE_TEST_CASE_P( .set_tolerance(0.1, 0.15) .set_fewer_infs_ok(), Testcase{"asin", Asin, std::asin}.set_skip_infs(), + Testcase{"acos", Acos, std::acos}.set_skip_infs(), Testcase{"atan", Atan, std::atan}, Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05), }));