[XLA] Fix edge case (abs(x) >= 1) for atanh implementation.

PiperOrigin-RevId: 246745571
This commit is contained in:
Justin Lebar 2019-05-05 15:36:57 -07:00 committed by TensorFlower Gardener
parent fd2d8bc50e
commit 274062297a
2 changed files with 22 additions and 4 deletions

View File

@ -612,10 +612,28 @@ XlaOp Asinh(XlaOp x) {
});
}
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
// atanh(x) = nan otherwise
XlaOp Atanh(XlaOp x) {
return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
ScalarLike(x, 0.5);
XlaBuilder* b = x.builder();
auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
auto naive_result =
Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
ScalarLike(x, 0.5);
// TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
// because we don't yet have exhaustive tests for complex trig functions.
if (primitive_util::IsComplexType(shape.element_type())) {
return naive_result;
}
auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
};
return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { //
return b->ReportErrorOrReturn(do_it(x));
});
}
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }

View File

@ -571,11 +571,11 @@ XLA_TEST_P(ExhaustiveOpTest, Asinh) {
}
Run(Asinh, std::asinh);
}
XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
// TODO(jlebar): Enable these.
// XLA_TEST_P(ExhaustiveOpTest, Acos) { Run(Acos, std::acos); }
// XLA_TEST_P(ExhaustiveOpTest, Asin) { Run(Asin, std::asin); }
// XLA_TEST_P(ExhaustiveOpTest, Atanh) { Run(Atanh, std::atanh); }
// XLA_TEST_P(ExhaustiveOpTest, Atan) { Run(Atan, std::atan); }
// XLA_TEST_P(ExhaustiveOpTest, Cosh) { Run(Cosh, std::cosh); }
// XLA_TEST_P(ExhaustiveOpTest, Cos) { Run(Cos, std::cos); }