Add complex data type support for tf.math.acos in XLA

This PR tries to address the issur raised in 41370
where tf.math.acos throws out error with complex input data.
The issue was that in XLA the `Acos` op does not capture
the complex data types.

This PR adds complex support for tf.math.acos in XLA

This PR fixes 41370.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-07-14 18:47:03 +00:00
parent 585d841061
commit 3acb88ed27
2 changed files with 34 additions and 4 deletions

View File

@ -1112,10 +1112,26 @@ XlaOp RoundToEven(XlaOp x) {
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
// pi if x == -1
XlaOp Acos(XlaOp x) {
return Select(Ne(x, FullLike(x, -1)),
ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
ScalarLike(x, 1.0) + x),
FullLike(x, M_PI));
XlaBuilder* b = x.builder();
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
// complex: acos(x) = -i * log(x + sqrt(-(x+1)*(x-1)))
if (primitive_util::IsComplexType(shape.element_type())) {
auto one = ScalarLike(x, 1);
auto imag_one = Complex(
Zero(b, primitive_util::ComplexComponentType(shape.element_type())),
One(b, primitive_util::ComplexComponentType(shape.element_type())));
auto result = Neg(
imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x))));
return result;
}
return Select(Ne(x, FullLike(x, -1)),
ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
ScalarLike(x, 1.0) + x),
FullLike(x, M_PI));
});
}
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))

View File

@ -660,5 +660,19 @@ XLA_TEST_F(MathTest, BesselI1eDouble) {
ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(MathTest, AcosComplexValues) {
XlaBuilder builder(TestName());
auto x = ConstantR1<std::complex<float>>(
&builder, {{0, 0}, {0, 1}, {1, 1}, {0.8, 0.2}});
Acos(x);
std::vector<std::complex<float>> expected = {
{1.5707963267948966, 0},
{1.5707963267948966, -0.881373587019543},
{0.9045568943023814, -1.0612750619050357},
{0.7011246914497526, -0.30527648462436596}};
ComputeAndCompareR1<std::complex<float>>(&builder, expected, {}, error_spec_);
}
} // namespace
} // namespace xla