diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index f0b4057e039..ee70adc6e7d 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -612,10 +612,28 @@ XlaOp Asinh(XlaOp x) { }); } -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 +// atanh(x) = nan otherwise XlaOp Atanh(XlaOp x) { - return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * - ScalarLike(x, 0.5); + XlaBuilder* b = x.builder(); + auto do_it = [&](XlaOp x) -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x)); + auto naive_result = + Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * + ScalarLike(x, 0.5); + + // TODO(jlebar): For now, we ignore the nan edge case for complex inputs, + // because we don't yet have exhaustive tests for complex trig functions. + if (primitive_util::IsComplexType(shape.element_type())) { + return naive_result; + } + + auto nan = FullLike(x, std::numeric_limits::quiet_NaN()); + return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result); + }; + return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { // + return b->ReportErrorOrReturn(do_it(x)); + }); } XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc index 9f0d6053da4..7f35a61ba33 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test.cc @@ -571,11 +571,11 @@ XLA_TEST_P(ExhaustiveOpTest, Asinh) { } Run(Asinh, std::asinh); } +XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); } // TODO(jlebar): Enable these. // XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); } // XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); } -// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); } // XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); } // XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); } // XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }