diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 517d15f2c34..f78750d34ca 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -515,15 +515,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( : input_type; switch (op->opcode()) { case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a) auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); - TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); - TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); + TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a)); + TF_ASSIGN_OR_RETURN(llvm::Value * abs, + EmitComplexAbs(component_type, operand_value)); + TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs)); + return EmitComposeComplex(op, log_abs, angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index a2fe86fc360..c93b8b366ce 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -748,6 +748,10 @@ xla_test( xla_test( name = "exhaustive_unary_test_complex", srcs = ["exhaustive_unary_test.cc"], + backends = [ + "gpu", + "cpu", + ], copts = ["-DUNARY_TEST_TARGET_COMPLEX"], real_hardware_only = True, # Very slow on the interpreter. shard_count = 48, diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc index 02273d7debd..8792d27440f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc @@ -17,8 +17,8 @@ limitations under the License. namespace xla { -// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be -// guaranteed that we're printing the full number. +// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of +// precision to be guaranteed that we're printing the full number. // // (The general formula is, given a floating-point number with S significand // bits, the number of decimal digits needed to print it to full precision is @@ -26,6 +26,11 @@ namespace xla { // ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103). // // See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.) +/*static*/ +string ExhaustiveOpTestBase::StringifyNum(double x) { + return absl::StrFormat("%0.17g (0x%016x)", x, BitCast(x)); +} + /*static*/ string ExhaustiveOpTestBase::StringifyNum(float x) { return absl::StrFormat("%0.9g (0x%08x)", x, BitCast(x)); diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h index be16fddc756..2696231c00b 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h @@ -198,6 +198,8 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { return ConvertValue(bits); } + static string StringifyNum(double x); + static string StringifyNum(float x); static string StringifyNum(half x); diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc index b80c16ca2a6..b3f363618cd 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc @@ -835,7 +835,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase { // T is the component type of the complex number. template void Run(std::function enqueue_op, - std::complex (*evaluate_op)(std::complex), + std::complex (*evaluate_op)(const std::complex&), FpValues* values_real, FpValues* values_imag, std::function error_spec_gen) { Literal input_literal = CreateInputLiteral(); @@ -883,7 +883,7 @@ class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase { template void ExpectNearComplex(const Literal& input_literal, const Literal& result_literal, - std::complex (*evaluate_op)(std::complex), + std::complex (*evaluate_op)(const std::complex&), std::function error_spec_gen) { absl::Span> input_arr = input_literal.data>(); @@ -938,7 +938,7 @@ class ExhaustiveC64UnaryTest public ::testing::WithParamInterface< std::tuple> { public: - typedef complex64 (*C64EvaluateOp)(complex64); + typedef complex64 (*C64EvaluateOp)(const complex64&); ExhaustiveC64UnaryTest() : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {} @@ -962,6 +962,11 @@ class ExhaustiveC64UnaryTest } }; +// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. +XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) { + Run(Log, std::log); +} + #if defined(UNARY_TEST_TARGET_COMPLEX) INSTANTIATE_TEST_SUITE_P( F32SpecialValues, ExhaustiveC64UnaryTest, @@ -969,7 +974,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(C64), ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - INSTANTIATE_TEST_SUITE_P( F32SpecialAndNormalValues, ExhaustiveC64UnaryTest, ::testing::Combine( @@ -1013,7 +1017,7 @@ class ExhaustiveC128UnaryTest public ::testing::WithParamInterface< std::tuple> { public: - typedef complex128 (*C128EvaluateOp)(complex128); + typedef complex128 (*C128EvaluateOp)(const complex128&); ExhaustiveC128UnaryTest() : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {} @@ -1038,14 +1042,13 @@ class ExhaustiveC128UnaryTest }; XLA_TEST_P(ExhaustiveC128UnaryTest, Log) { - // TODO(bixia): only test values that are not too big and not too small - // for now and will work on fixing the implementation of XLA - // operations to enable test for other values. + // TODO(b/138578313): Enable the test for all values after fixing the bug. known_incorrect_fn_ = [&](int64 v) { double f = ConvertValue(v); - return std::fpclassify(f) == FP_NAN || std::abs(f) > 5 || std::abs(f) < 1; + return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 || + std::abs(f) < 1.0e-300; }; - Run(Log, [](complex128 x) { return std::log(x); }); + Run(Log, std::log); } #if defined(UNARY_TEST_TARGET_COMPLEX)