Add complex data type support for tf.math.acos in XLA
This PR tries to address the issur raised in 41370 where tf.math.acos throws out error with complex input data. The issue was that in XLA the `Acos` op does not capture the complex data types. This PR adds complex support for tf.math.acos in XLA This PR fixes 41370. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
585d841061
commit
3acb88ed27
@ -1112,10 +1112,26 @@ XlaOp RoundToEven(XlaOp x) {
|
||||
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
|
||||
// pi if x == -1
|
||||
XlaOp Acos(XlaOp x) {
|
||||
return Select(Ne(x, FullLike(x, -1)),
|
||||
ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
|
||||
ScalarLike(x, 1.0) + x),
|
||||
FullLike(x, M_PI));
|
||||
XlaBuilder* b = x.builder();
|
||||
return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
|
||||
// complex: acos(x) = -i * log(x + sqrt(-(x+1)*(x-1)))
|
||||
if (primitive_util::IsComplexType(shape.element_type())) {
|
||||
auto one = ScalarLike(x, 1);
|
||||
auto imag_one = Complex(
|
||||
Zero(b, primitive_util::ComplexComponentType(shape.element_type())),
|
||||
One(b, primitive_util::ComplexComponentType(shape.element_type())));
|
||||
|
||||
auto result = Neg(
|
||||
imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x))));
|
||||
return result;
|
||||
}
|
||||
return Select(Ne(x, FullLike(x, -1)),
|
||||
ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
|
||||
ScalarLike(x, 1.0) + x),
|
||||
FullLike(x, M_PI));
|
||||
});
|
||||
}
|
||||
|
||||
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
||||
|
@ -660,5 +660,19 @@ XLA_TEST_F(MathTest, BesselI1eDouble) {
|
||||
ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(MathTest, AcosComplexValues) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR1<std::complex<float>>(
|
||||
&builder, {{0, 0}, {0, 1}, {1, 1}, {0.8, 0.2}});
|
||||
|
||||
Acos(x);
|
||||
std::vector<std::complex<float>> expected = {
|
||||
{1.5707963267948966, 0},
|
||||
{1.5707963267948966, -0.881373587019543},
|
||||
{0.9045568943023814, -1.0612750619050357},
|
||||
{0.7011246914497526, -0.30527648462436596}};
|
||||
ComputeAndCompareR1<std::complex<float>>(&builder, expected, {}, error_spec_);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user