[XLA] Fix edge case (abs(x) >= 1) for atanh implementation.
PiperOrigin-RevId: 246745571
This commit is contained in:
parent
fd2d8bc50e
commit
274062297a
@ -612,10 +612,28 @@ XlaOp Asinh(XlaOp x) {
|
||||
});
|
||||
}
|
||||
|
||||
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
|
||||
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
|
||||
// atanh(x) = nan otherwise
|
||||
XlaOp Atanh(XlaOp x) {
|
||||
return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
|
||||
ScalarLike(x, 0.5);
|
||||
XlaBuilder* b = x.builder();
|
||||
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
auto naive_result =
|
||||
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
|
||||
ScalarLike(x, 0.5);
|
||||
|
||||
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
|
||||
// because we don't yet have exhaustive tests for complex trig functions.
|
||||
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
return naive_result;
|
||||
}
|
||||
|
||||
auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
|
||||
return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
|
||||
};
|
||||
return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { //
|
||||
return b->ReportErrorOrReturn(do_it(x));
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
|
||||
|
@ -571,11 +571,11 @@ XLA_TEST_P(ExhaustiveOpTest, Asinh) {
|
||||
}
|
||||
Run(Asinh, std::asinh);
|
||||
}
|
||||
XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
|
||||
|
||||
// TODO(jlebar): Enable these.
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); }
|
||||
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
|
||||
|
Loading…
Reference in New Issue
Block a user