[XLA] Fix edge case (abs(x) >= 1) for atanh implementation.
PiperOrigin-RevId: 246745571
This commit is contained in:
parent
fd2d8bc50e
commit
274062297a
@ -612,10 +612,28 @@ XlaOp Asinh(XlaOp x) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
|
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
|
||||||
|
// atanh(x) = nan otherwise
|
||||||
XlaOp Atanh(XlaOp x) {
|
XlaOp Atanh(XlaOp x) {
|
||||||
return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
|
XlaBuilder* b = x.builder();
|
||||||
ScalarLike(x, 0.5);
|
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||||
|
auto naive_result =
|
||||||
|
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
|
||||||
|
ScalarLike(x, 0.5);
|
||||||
|
|
||||||
|
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
|
||||||
|
// because we don't yet have exhaustive tests for complex trig functions.
|
||||||
|
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||||
|
return naive_result;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
|
||||||
|
return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
|
||||||
|
};
|
||||||
|
return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { //
|
||||||
|
return b->ReportErrorOrReturn(do_it(x));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
|
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
|
||||||
|
@ -571,11 +571,11 @@ XLA_TEST_P(ExhaustiveOpTest, Asinh) {
|
|||||||
}
|
}
|
||||||
Run(Asinh, std::asinh);
|
Run(Asinh, std::asinh);
|
||||||
}
|
}
|
||||||
|
XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
|
||||||
|
|
||||||
// TODO(jlebar): Enable these.
|
// TODO(jlebar): Enable these.
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
|
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
|
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
|
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); }
|
// XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); }
|
||||||
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
|
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }
|
||||||
|
Loading…
Reference in New Issue
Block a user