[XLA:CPU/GPU] Fix codegen for log(complex).
Previous, we calculate .5*log(a^2+b^2) which can overfloat when the magnitude of a and b are large. Replace this formula with log(abs(a+bi)) to fix the problem. Limit exhaustive_unary_test_complex for the CPU and GPU backends. Implement StringifyNum for double. Change the EvaluateOp for complex to use a const reference input. PiperOrigin-RevId: 260998149
This commit is contained in:
parent
fb7da355b0
commit
620fbe292f
@ -515,15 +515,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
|
||||
: input_type;
|
||||
switch (op->opcode()) {
|
||||
case HloOpcode::kLog: {
|
||||
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
||||
// log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
|
||||
auto a = EmitExtractReal(operand_value);
|
||||
auto b = EmitExtractImag(operand_value);
|
||||
llvm::Type* llvm_ty = a->getType();
|
||||
auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
|
||||
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
|
||||
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
|
||||
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
|
||||
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * abs,
|
||||
EmitComplexAbs(component_type, operand_value));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
|
||||
return EmitComposeComplex(op, log_abs, angle);
|
||||
}
|
||||
case HloOpcode::kLog1p: {
|
||||
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
|
||||
|
@ -748,6 +748,10 @@ xla_test(
|
||||
xla_test(
|
||||
name = "exhaustive_unary_test_complex",
|
||||
srcs = ["exhaustive_unary_test.cc"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
|
||||
real_hardware_only = True, # Very slow on the interpreter.
|
||||
shard_count = 48,
|
||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
|
||||
// guaranteed that we're printing the full number.
|
||||
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
|
||||
// precision to be guaranteed that we're printing the full number.
|
||||
//
|
||||
// (The general formula is, given a floating-point number with S significand
|
||||
// bits, the number of decimal digits needed to print it to full precision is
|
||||
@ -26,6 +26,11 @@ namespace xla {
|
||||
// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
|
||||
//
|
||||
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
|
||||
/*static*/
|
||||
string ExhaustiveOpTestBase::StringifyNum(double x) {
|
||||
return absl::StrFormat("%0.17g (0x%016x)", x, BitCast<uint64>(x));
|
||||
}
|
||||
|
||||
/*static*/
|
||||
string ExhaustiveOpTestBase::StringifyNum(float x) {
|
||||
return absl::StrFormat("%0.9g (0x%08x)", x, BitCast<uint32>(x));
|
||||
|
@ -198,6 +198,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
return ConvertValue<T>(bits);
|
||||
}
|
||||
|
||||
static string StringifyNum(double x);
|
||||
|
||||
static string StringifyNum(float x);
|
||||
|
||||
static string StringifyNum(half x);
|
||||
|
@ -835,7 +835,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
|
||||
// T is the component type of the complex number.
|
||||
template <typename T>
|
||||
void Run(std::function<XlaOp(XlaOp)> enqueue_op,
|
||||
std::complex<T> (*evaluate_op)(std::complex<T>),
|
||||
std::complex<T> (*evaluate_op)(const std::complex<T>&),
|
||||
FpValues* values_real, FpValues* values_imag,
|
||||
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||
Literal input_literal = CreateInputLiteral();
|
||||
@ -883,7 +883,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
|
||||
template <typename T>
|
||||
void ExpectNearComplex(const Literal& input_literal,
|
||||
const Literal& result_literal,
|
||||
std::complex<T> (*evaluate_op)(std::complex<T>),
|
||||
std::complex<T> (*evaluate_op)(const std::complex<T>&),
|
||||
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||
absl::Span<const std::complex<T>> input_arr =
|
||||
input_literal.data<std::complex<T>>();
|
||||
@ -938,7 +938,7 @@ class ExhaustiveC64UnaryTest
|
||||
public ::testing::WithParamInterface<
|
||||
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||
public:
|
||||
typedef complex64 (*C64EvaluateOp)(complex64);
|
||||
typedef complex64 (*C64EvaluateOp)(const complex64&);
|
||||
|
||||
ExhaustiveC64UnaryTest()
|
||||
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||
@ -962,6 +962,11 @@ class ExhaustiveC64UnaryTest
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
||||
XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) {
|
||||
Run(Log, std::log<float>);
|
||||
}
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialValues, ExhaustiveC64UnaryTest,
|
||||
@ -969,7 +974,6 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::Values(C64),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
|
||||
::testing::Combine(
|
||||
@ -1013,7 +1017,7 @@ class ExhaustiveC128UnaryTest
|
||||
public ::testing::WithParamInterface<
|
||||
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||
public:
|
||||
typedef complex128 (*C128EvaluateOp)(complex128);
|
||||
typedef complex128 (*C128EvaluateOp)(const complex128&);
|
||||
|
||||
ExhaustiveC128UnaryTest()
|
||||
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||
@ -1038,14 +1042,13 @@ class ExhaustiveC128UnaryTest
|
||||
};
|
||||
|
||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
|
||||
// TODO(bixia): only test values that are not too big and not too small
|
||||
// for now and will work on fixing the implementation of XLA
|
||||
// operations to enable test for other values.
|
||||
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
||||
known_incorrect_fn_ = [&](int64 v) {
|
||||
double f = ConvertValue<double>(v);
|
||||
return std::fpclassify(f) == FP_NAN || std::abs(f) > 5 || std::abs(f) < 1;
|
||||
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
|
||||
std::abs(f) < 1.0e-300;
|
||||
};
|
||||
Run(Log, [](complex128 x) { return std::log(x); });
|
||||
Run(Log, std::log<double>);
|
||||
}
|
||||
|
||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||
|
Loading…
x
Reference in New Issue
Block a user