[XLA:CPU/GPU] Fix codegen for log(complex).
Previously, we calculated .5*log(a^2+b^2), which can overflow when the magnitudes of a and b are large. Replace this formula with log(abs(a+bi)) to fix the problem. Limit exhaustive_unary_test_complex to the CPU and GPU backends. Implement StringifyNum for double. Change the EvaluateOp for complex to use a const reference input. PiperOrigin-RevId: 260998149
This commit is contained in:
parent
fb7da355b0
commit
620fbe292f
@ -515,15 +515,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
|
|||||||
: input_type;
|
: input_type;
|
||||||
switch (op->opcode()) {
|
switch (op->opcode()) {
|
||||||
case HloOpcode::kLog: {
|
case HloOpcode::kLog: {
|
||||||
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
// log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
|
||||||
auto a = EmitExtractReal(operand_value);
|
auto a = EmitExtractReal(operand_value);
|
||||||
auto b = EmitExtractImag(operand_value);
|
auto b = EmitExtractImag(operand_value);
|
||||||
llvm::Type* llvm_ty = a->getType();
|
TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
|
||||||
auto sum_sq = FAdd(FMul(a, a), FMul(b, b));
|
TF_ASSIGN_OR_RETURN(llvm::Value * abs,
|
||||||
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
|
EmitComplexAbs(component_type, operand_value));
|
||||||
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
|
TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
|
||||||
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
|
return EmitComposeComplex(op, log_abs, angle);
|
||||||
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
|
|
||||||
}
|
}
|
||||||
case HloOpcode::kLog1p: {
|
case HloOpcode::kLog1p: {
|
||||||
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
|
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
|
||||||
|
@ -748,6 +748,10 @@ xla_test(
|
|||||||
xla_test(
|
xla_test(
|
||||||
name = "exhaustive_unary_test_complex",
|
name = "exhaustive_unary_test_complex",
|
||||||
srcs = ["exhaustive_unary_test.cc"],
|
srcs = ["exhaustive_unary_test.cc"],
|
||||||
|
backends = [
|
||||||
|
"gpu",
|
||||||
|
"cpu",
|
||||||
|
],
|
||||||
copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
|
copts = ["-DUNARY_TEST_TARGET_COMPLEX"],
|
||||||
real_hardware_only = True, # Very slow on the interpreter.
|
real_hardware_only = True, # Very slow on the interpreter.
|
||||||
shard_count = 48,
|
shard_count = 48,
|
||||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be
|
// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of
|
||||||
// guaranteed that we're printing the full number.
|
// precision to be guaranteed that we're printing the full number.
|
||||||
//
|
//
|
||||||
// (The general formula is, given a floating-point number with S significand
|
// (The general formula is, given a floating-point number with S significand
|
||||||
// bits, the number of decimal digits needed to print it to full precision is
|
// bits, the number of decimal digits needed to print it to full precision is
|
||||||
@ -26,6 +26,11 @@ namespace xla {
|
|||||||
// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
|
// ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
|
||||||
//
|
//
|
||||||
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
|
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
|
||||||
|
/*static*/
|
||||||
|
string ExhaustiveOpTestBase::StringifyNum(double x) {
|
||||||
|
return absl::StrFormat("%0.17g (0x%016x)", x, BitCast<uint64>(x));
|
||||||
|
}
|
||||||
|
|
||||||
/*static*/
|
/*static*/
|
||||||
string ExhaustiveOpTestBase::StringifyNum(float x) {
|
string ExhaustiveOpTestBase::StringifyNum(float x) {
|
||||||
return absl::StrFormat("%0.9g (0x%08x)", x, BitCast<uint32>(x));
|
return absl::StrFormat("%0.9g (0x%08x)", x, BitCast<uint32>(x));
|
||||||
|
@ -198,6 +198,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
return ConvertValue<T>(bits);
|
return ConvertValue<T>(bits);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static string StringifyNum(double x);
|
||||||
|
|
||||||
static string StringifyNum(float x);
|
static string StringifyNum(float x);
|
||||||
|
|
||||||
static string StringifyNum(half x);
|
static string StringifyNum(half x);
|
||||||
|
@ -835,7 +835,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
|
|||||||
// T is the component type of the complex number.
|
// T is the component type of the complex number.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void Run(std::function<XlaOp(XlaOp)> enqueue_op,
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op,
|
||||||
std::complex<T> (*evaluate_op)(std::complex<T>),
|
std::complex<T> (*evaluate_op)(const std::complex<T>&),
|
||||||
FpValues* values_real, FpValues* values_imag,
|
FpValues* values_real, FpValues* values_imag,
|
||||||
std::function<ErrorSpec(float)> error_spec_gen) {
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
Literal input_literal = CreateInputLiteral();
|
Literal input_literal = CreateInputLiteral();
|
||||||
@ -883,7 +883,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
void ExpectNearComplex(const Literal& input_literal,
|
void ExpectNearComplex(const Literal& input_literal,
|
||||||
const Literal& result_literal,
|
const Literal& result_literal,
|
||||||
std::complex<T> (*evaluate_op)(std::complex<T>),
|
std::complex<T> (*evaluate_op)(const std::complex<T>&),
|
||||||
std::function<ErrorSpec(float)> error_spec_gen) {
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
absl::Span<const std::complex<T>> input_arr =
|
absl::Span<const std::complex<T>> input_arr =
|
||||||
input_literal.data<std::complex<T>>();
|
input_literal.data<std::complex<T>>();
|
||||||
@ -938,7 +938,7 @@ class ExhaustiveC64UnaryTest
|
|||||||
public ::testing::WithParamInterface<
|
public ::testing::WithParamInterface<
|
||||||
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||||
public:
|
public:
|
||||||
typedef complex64 (*C64EvaluateOp)(complex64);
|
typedef complex64 (*C64EvaluateOp)(const complex64&);
|
||||||
|
|
||||||
ExhaustiveC64UnaryTest()
|
ExhaustiveC64UnaryTest()
|
||||||
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||||
@ -962,6 +962,11 @@ class ExhaustiveC64UnaryTest
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug.
|
||||||
|
XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) {
|
||||||
|
Run(Log, std::log<float>);
|
||||||
|
}
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
F32SpecialValues, ExhaustiveC64UnaryTest,
|
F32SpecialValues, ExhaustiveC64UnaryTest,
|
||||||
@ -969,7 +974,6 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
::testing::Values(C64),
|
::testing::Values(C64),
|
||||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
|
||||||
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
|
F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
@ -1013,7 +1017,7 @@ class ExhaustiveC128UnaryTest
|
|||||||
public ::testing::WithParamInterface<
|
public ::testing::WithParamInterface<
|
||||||
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||||
public:
|
public:
|
||||||
typedef complex128 (*C128EvaluateOp)(complex128);
|
typedef complex128 (*C128EvaluateOp)(const complex128&);
|
||||||
|
|
||||||
ExhaustiveC128UnaryTest()
|
ExhaustiveC128UnaryTest()
|
||||||
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||||
@ -1038,14 +1042,13 @@ class ExhaustiveC128UnaryTest
|
|||||||
};
|
};
|
||||||
|
|
||||||
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
|
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
|
||||||
// TODO(bixia): only test values that are not too big and not too small
|
// TODO(b/138578313): Enable the test for all values after fixing the bug.
|
||||||
// for now and will work on fixing the implementation of XLA
|
|
||||||
// operations to enable test for other values.
|
|
||||||
known_incorrect_fn_ = [&](int64 v) {
|
known_incorrect_fn_ = [&](int64 v) {
|
||||||
double f = ConvertValue<double>(v);
|
double f = ConvertValue<double>(v);
|
||||||
return std::fpclassify(f) == FP_NAN || std::abs(f) > 5 || std::abs(f) < 1;
|
return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 ||
|
||||||
|
std::abs(f) < 1.0e-300;
|
||||||
};
|
};
|
||||||
Run(Log, [](complex128 x) { return std::log(x); });
|
Run(Log, std::log<double>);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
#if defined(UNARY_TEST_TARGET_COMPLEX)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user