diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h index 3df4de295e3..956e1694fb7 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h @@ -45,7 +45,13 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // `ty` is the primitive type being tested. explicit ExhaustiveOpTestBase(PrimitiveType ty) - : ty_(ty), platform_(client_->platform()->Name()) {} + : ty_(ty), platform_(client_->platform()->Name()) { + SetFastMathDisabled(true); + + // Run all HLO passes. In particular, constant folding is disabled by + // default for tests, but we need to run it in order to tickle some bugs. + mutable_debug_options()->clear_xla_disable_hlo_passes(); + } // Builds and runs the computation using the LocalClient API, rather than the // plain Client API, which is used by ClientLibraryTestBase. This is because @@ -227,5 +233,410 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { bool relaxed_denormal_signs_ = platform_ != "CUDA"; }; +// Represents a set of 64 bit chunks by representing the starting bit chunk, +// the last bit chunk, and the spacing between two adjacent bit chunks, without +// actually storing all the bit chunks being generated. The bit chunk iterator +// is provided to retrieve all the bit chunks. +// +// This data structure is used to generate the bit representation to test +// operations that requires more than 64 bit input data. In this case, +// truly exhaustive testing is not possible and we want to test a value every +// n values, where n == spacing_. +// +// Currently, the iterator of BitChunks adds the `spacing_` to a bit chunk to +// compute the next bit chunk. We can change this to use values generated +// by a random number generator that can achieve the average spacing +// statistically, if we will find this is necessary. +class BitChunks { + public: + class iterator + : public std::iterator { + public: + iterator() {} + + explicit iterator(const BitChunks* bit_chunks) + : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {} + + iterator& operator++() { + Next(); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + Next(); + return retval; + } + + bool operator==(iterator other) const { + return bit_chunks_ == other.bit_chunks_ && + next_bit_chunk_ == other.next_bit_chunk_; + } + + bool operator!=(iterator other) const { return !(*this == other); } + + iterator MoveToEnd() { + MoveNextBitChunkToOnePassEnd(); + return *this; + } + + reference operator*() const { + CHECK(*this != this->bit_chunks_->end()); + return next_bit_chunk_; + } + + const BitChunks* GetBitChunks() const { return bit_chunks_; } + + void Reset() { next_bit_chunk_ = bit_chunks_->start_; } + + void Next() { + CHECK(*this != this->bit_chunks_->end()); + if (next_bit_chunk_ == bit_chunks_->end_) { + MoveNextBitChunkToOnePassEnd(); + } else { + next_bit_chunk_ += bit_chunks_->spacing_; + if (next_bit_chunk_ > bit_chunks_->end_) { + next_bit_chunk_ = bit_chunks_->end_; + } + } + } + + std::string ToString() const { + return absl::StrFormat("0x%08x", next_bit_chunk_); + } + + private: + // Move next_bit_chunk_ to 1 pass the bit_chunks_->end, to mark that the + // iterator has reached the end. When spacing_ is not one, or if we will + // change to use a random value instead of spacing_ in function Next(), + // normalizing the representation of the iterator ending this way can + // can simplify the checking for iterator ending. + void MoveNextBitChunkToOnePassEnd() { + next_bit_chunk_ = bit_chunks_->end_ + 1; + } + + const BitChunks* bit_chunks_; + uint64 next_bit_chunk_; + }; + + iterator begin() const { return iterator(this); } + iterator end() const { + iterator end(this); + return end.MoveToEnd(); + } + + explicit BitChunks(uint64 start = 0, uint64 end = 0, uint64 spacing = 1) + : start_(start), end_(end), spacing_(spacing) { + CHECK_GE(end_, start_); + CHECK_NE(spacing, 0) << ToString(); + } + + int64 GetTotalBitChunks() const { + if (start_ == end_) { + return 1; + } + + return 1 + (end_ - start_ + spacing_ - 1) / spacing_; + } + + std::string ToString() const { + return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_); + } + + uint64 start_; + uint64 end_; + uint64 spacing_; +}; + +inline string StringifyNum(BitChunks c) { return c.ToString(); } + +inline string StringifyNum(BitChunks::iterator c) { return c.ToString(); } + +template +void AppendStringifyNum(std::string* s, T x) { + absl::StrAppend(s, StringifyNum(x)); +} + +// Represents a set of floating point values through the possible values for +// the three components: mantissa, exponent, and sign. Also implements an +// iterator for retrieving all the represented floating point values. +class FpValues { + public: + static constexpr uint kTotalBitChunks = 3; + + class iterator + : public std::iterator { + public: + explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i)); + } + } + + iterator& operator++() { + Next(); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + Next(); + return retval; + } + + bool operator==(iterator other) const { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + if (iters_[i] != other.GetBitChunksIter(i)) { + return false; + } + } + return true; + } + + bool operator!=(iterator other) const { return !(*this == other); } + + iterator MoveToEnd() { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + iters_[i].MoveToEnd(); + } + return *this; + } + + uint64 operator*() const { + uint64 value = 0; + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + value = value | (*iters_[i]) << fp_values_->offsets_[i]; + } + return value; + } + + const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; } + + std::string ToString() const { + return absl::StrJoin(iters_, ",", + AppendStringifyNum); + } + + private: + // Moves the iterator for the ith BitChunks to the next value, and + // returns true if the new state is not the end of the iterator. + bool Next(int i = 0) { + iters_[i].Next(); + if (iters_[i] == iters_[i].GetBitChunks()->end()) { + if (i == FpValues::kTotalBitChunks - 1) { + return false; + } + if (Next(i + 1)) { + iters_[i].Reset(); + return true; + } + return false; + } + return true; + } + + std::array iters_; + const FpValues* fp_values_; + }; + + FpValues(absl::Span chunks, absl::Span offsets) { + CHECK_EQ(chunks.size(), offsets.size() - 1); + CHECK_EQ(chunks.size(), kTotalBitChunks); + std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin()); + std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin()); + + // The last value in `offsets` is the total number of bits. + offsets_[kTotalBitChunks] = offsets[kTotalBitChunks]; + // Validate the input values. + for (int i = 0; i < kTotalBitChunks; ++i) { + int total_bits = offsets[i + 1] - offsets[i]; + if (total_bits < 64) { + uint64 bound = 1ull << total_bits; + CHECK_LT(chunks[i].start_, bound); + CHECK_LT(chunks[i].end_, bound); + } else { + CHECK_EQ(total_bits, 64); + } + } + } + + iterator begin() const { return iterator(this); } + + iterator end() const { + iterator end(this); + return end.MoveToEnd(); + } + + int64 GetTotalNumValues() const { + int64 total = 1; + absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) { + total *= chunks.GetTotalBitChunks(); + }); + return total; + } + + const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; } + + std::string ToString() const { + return absl::StrCat( + "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum), + "]"); + } + + std::array bit_chunks_; + std::array offsets_; +}; + +template +int GetMantissaTotalBits() { + static_assert(std::is_same::value || std::is_same::value, + "Only supports float and double."); + return std::numeric_limits::digits - 1; +} + +template +int GetFpTotalBits() { + return sizeof(T) * 8; +} + +template +int GetExponentTotalBits() { + return GetFpTotalBits() - GetMantissaTotalBits() - 1; +} + +template +uint64 GetAllOneMantissa() { + return (1ull << GetMantissaTotalBits()) - 1ull; +} + +template +uint64 GetAllOneExponent() { + return (1ull << GetExponentTotalBits()) - 1ull; +} + +template +FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) { + static_assert(std::is_same::value || std::is_same::value, + "Only supports float and double."); + int total_bits = GetFpTotalBits(); + return FpValues({mantissa, exponent, sign}, + {0, GetMantissaTotalBits(), total_bits - 1, total_bits}); +} + +template +FpValues GetZeros() { + return GetFpValues(BitChunks(0, 0, 1), BitChunks(0, 0, 1), + BitChunks(0, 1, 1)); +} + +template +FpValues GetSubnormals(int approx_num_values) { + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), mantissa_spacing), + BitChunks(0, 0, 1), BitChunks(0, 1, 1)); +} + +template +FpValues GetInfinites() { + uint64 all_one_exp = GetAllOneExponent(); + return GetFpValues(BitChunks(0, 0, 1), + BitChunks(all_one_exp, all_one_exp, 1), + BitChunks(0, 1, 1)); +} + +template +FpValues GetNans(int approx_num_values) { + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2); + uint64 all_one_exp = GetAllOneExponent(); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), mantissa_spacing), + BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1)); +} + +template +FpValues GetNormals(int approx_num_values) { + float component_total = std::sqrtf(approx_num_values); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), + (1ull << (GetMantissaTotalBits() + 1)) / component_total), + BitChunks(0x1, GetAllOneExponent() - 1, + (1ull << (GetExponentTotalBits() + 1)) / component_total), + BitChunks(0, 1, 1)); +} + +// Returns a vector of FpValues, which together represent about +// `approx_num_values` floating point values of type `T`, with each FpValues +// represents about `num_values_per_group` floating point values. +template +std::vector GetFpValuesWithExponents(uint64 first_exponent, + uint64 exponent_spacing, + uint64 num_exponents, + uint64 approx_num_values, + uint64 num_values_per_group) { + const uint64 num_signs = 2; + uint64 approx_num_mantissa = approx_num_values / (num_exponents * num_signs); + uint64 num_mantissa_per_group = + num_values_per_group / (num_exponents * num_signs); + CHECK_GT(approx_num_mantissa, 0); + CHECK_GT(num_mantissa_per_group, 0); + + CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent()); + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / approx_num_mantissa; + + std::vector result; + for (uint64 group_start = 0; group_start < GetAllOneMantissa(); + group_start += mantissa_spacing * num_mantissa_per_group) { + uint64 group_end = + group_start + (num_mantissa_per_group - 1) * mantissa_spacing; + if (group_end > GetAllOneMantissa()) { + group_end = GetAllOneMantissa(); + } + result.push_back(GetFpValues( + BitChunks(group_start, group_end, mantissa_spacing), + BitChunks(first_exponent, first_exponent + num_exponents - 1, 1), + BitChunks(0, 1, 1))); + } + return result; +} + +// Returns a vector of FpValues together represent about `approx_num_values` +// "very large" floating point values and `approx_num_values` "very small" +// floating point values of type `T`, which each FpValues represent about +// `num_values_per_group` floating point values. Because we use FpValues as +// a parameter for parameterized testing, the number of floating values +// represented by each FpValues affects the input size for each sub-test and +// the hence the peak memory usage of the test. +template +std::vector GetFpValuesForMagnitudeExtremeNormals( + uint64 approx_num_values = 40000, uint64 num_values_per_group = 4000) { + std::vector large = + GetFpValuesWithExponents(GetAllOneExponent() - 5, 1, 5, + approx_num_values / 2, num_values_per_group); + std::vector small = GetFpValuesWithExponents( + 1, 1, 5, approx_num_values / 2, num_values_per_group); + large.insert(large.end(), small.begin(), small.end()); + return large; +} + +template +std::vector CreateFpValuesForBoundaryTest() { + return {GetZeros(), GetSubnormals(1000), GetInfinites(), + GetNans(1000)}; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc index 0186d7d668d..5f82af95245 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc @@ -326,11 +326,6 @@ class Exhaustive32BitOrLessUnaryTest void Run(std::function enqueue_op, F32EvaluateOp evaluate_op, std::function error_spec_gen) { - SetFastMathDisabled(true); - - // Run all HLO passes. In particular, constant folding is disabled by - // default for tests, but we need to run it in order to tickle some bugs. - mutable_debug_options()->clear_xla_disable_hlo_passes(); Literal input_literal = CreateInputLiteral(); switch (ty_) { case F32: @@ -708,4 +703,340 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(std::make_pair(0, 1 << 16)))); #endif +// Exhaustive test for unary operations for double. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - FpValues representing a set of double values. +class ExhaustiveF64UnaryTest : public ExhaustiveRealUnaryTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + typedef double (*F64EvaluateOp)(double); + + ExhaustiveF64UnaryTest() + : ExhaustiveRealUnaryTestBase(std::get<0>(GetParam())) {} + + void Run(std::function enqueue_op, F64EvaluateOp evaluate_op) { + return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_)); + } + + void Run(std::function enqueue_op, F64EvaluateOp evaluate_op, + std::function error_spec_gen) { + CHECK_EQ(ty_, F64); + Literal input_literal = CreateInputLiteral(); + FillInputF64(&input_literal); + RunImpl(enqueue_op, evaluate_op, input_literal, + error_spec_gen); + } + + private: + int64 GetInputSize() override { + FpValues values = std::get<1>(GetParam()); + return values.GetTotalNumValues(); + } + + void FillInputF64(Literal* input_literal) { + FpValues fp_values = std::get<1>(GetParam()); + int64 input_size = input_literal->element_count(); + LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", " + << input_size; + absl::Span input_arr = input_literal->data(); + + uint64 i = 0; + for (auto bits : fp_values) { + input_arr[i] = ConvertAndReplaceKnownIncorrectValueWith(bits, 1); + ++i; + } + CHECK_EQ(i, input_size); + } +}; + +XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); } + +// TODO(bixia): add other unary ops for double + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64UnaryTest, + ::testing::Combine( + ::testing::Values(F64), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalValues, ExhaustiveF64UnaryTest, + ::testing::Combine(::testing::Values(F64), + ::testing::Values(GetNormals(1000)))); + +// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to +// keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest, + ::testing::Combine( + ::testing::Values(F64), + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( + 4000000000ull, 16000000)))); +#endif + +class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase { + public: + explicit ExhaustiveComplexUnaryTestBase(PrimitiveType ty) + : ExhaustiveOpTestBase(ty) {} + + // A helper for implementing the Run method for unary op test of complex + // numbers. + // + // T is the component type of the complex number. + template + void Run(std::function enqueue_op, + std::complex (*evaluate_op)(std::complex), + FpValues* values_real, FpValues* values_imag, + std::function error_spec_gen) { + Literal input_literal = CreateInputLiteral(); + + FillInput(&input_literal, values_real, values_imag); + + XlaBuilder builder(TestName()); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); + enqueue_op(input); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + RunComputation(comp, {&input_literal})); + ExpectNearComplex(input_literal, result_literal, evaluate_op, + error_spec_gen); + } + + // Generates the input complex literal given the FpValues representation for + // the real and imaginary components. + // + // T is the component type of the complex number. + template + void FillInput(Literal* input_literal, FpValues* real_values, + FpValues* imag_values) { + VLOG(2) << " testing input total " + << real_values->GetTotalNumValues() * + imag_values->GetTotalNumValues() + << ", range " << real_values->ToString() << " " + << imag_values->ToString(); + + absl::Span> input_arr = + input_literal->data>(); + + uint64 i = 0; + for (auto real : *real_values) { + for (auto imag : *imag_values) { + input_arr[i] = std::complex( + ConvertAndReplaceKnownIncorrectValueWith(real, 1), + ConvertAndReplaceKnownIncorrectValueWith(imag, 1)); + + ++i; + } + } + } + + template + void ExpectNearComplex(const Literal& input_literal, + const Literal& result_literal, + std::complex (*evaluate_op)(std::complex), + std::function error_spec_gen) { + absl::Span> input_arr = + input_literal.data>(); + absl::Span> result_arr = + result_literal.data>(); + ASSERT_EQ(result_arr.size(), input_arr.size()); + int64 mismatches = 0; + + for (int64 i = 0; i < input_arr.size(); ++i) { + std::complex input = input_arr[i]; + std::complex actual = result_arr[i]; + std::complex expected = evaluate_op(input); + + // TODO(bixia): Need to fix error_spec_gen to consider both components. + // This only affects the value specific error_spec, and before we fix + // this, it means complex operation testing doesn't support value + // specific error_spec yet. We delay the fix to this partially because + // we don't know whether it is enough for the error_spec to only take + // the absolute value of the complex number. + ErrorSpec error_spec = error_spec_gen(input.real()); + + if (IsClose(expected.real(), actual.real(), error_spec) && + IsClose(expected.imag(), actual.imag(), error_spec)) { + continue; + } + + // TODO(bixia): Need to handle complex operands with subnormals in + // real and/or imaginary components. + VLOG(2) << "calculate " << StringifyNum(input) << " ;" + << StringifyNum(actual) << "; " << StringifyNum(expected); + + PrintMismatch(&mismatches, [&] { + return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", + StringifyNum(input), StringifyNum(expected), + StringifyNum(actual)); + }); + } + + EXPECT_EQ(mismatches, 0); + } +}; + +// Unary op test for complex. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - two FpValues representing the values for the real and imaginary +// components. The complex numbers for the test input is the cartesian +// product of the values represented by the two FpValues. +class ExhaustiveC64UnaryTest + : public ExhaustiveComplexUnaryTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + typedef complex64 (*C64EvaluateOp)(complex64); + + ExhaustiveC64UnaryTest() + : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {} + + void Run(std::function enqueue_op, C64EvaluateOp evaluate_op) { + return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_)); + } + + void Run(std::function enqueue_op, C64EvaluateOp evaluate_op, + std::function error_spec_gen) { + FpValues values_real = std::get<1>(GetParam()); + FpValues values_imag = std::get<2>(GetParam()); + ExhaustiveComplexUnaryTestBase::Run( + enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen); + } + + int64 GetInputSize() override { + FpValues values_real = std::get<1>(GetParam()); + FpValues values_imag = std::get<2>(GetParam()); + return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues(); + } +}; + +INSTANTIATE_TEST_SUITE_P( + F32SpecialValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::Values(C64), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32SpecialAndNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::Values(C64), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(10000)))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndSpecialValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::Values(C64), ::testing::Values(GetNormals(10000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine(::testing::Values(C64), + ::testing::Values(GetNormals(10000)), + ::testing::Values(GetNormals(10000)))); + +// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to +// keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + F32LargeAndSmallMagnituedNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::Values(C64), + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, + 4000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 4000)))); + +// Unary op test for complex. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - two FpValues representing the values for the real and imaginary +// components. The complex numbers for the test input is the cartesian +// product of the values represented by the two FpValues. +class ExhaustiveC128UnaryTest + : public ExhaustiveComplexUnaryTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + typedef complex128 (*C128EvaluateOp)(complex128); + + ExhaustiveC128UnaryTest() + : ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {} + + void Run(std::function enqueue_op, C128EvaluateOp evaluate_op) { + return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_)); + } + + void Run(std::function enqueue_op, C128EvaluateOp evaluate_op, + std::function error_spec_gen) { + FpValues values_real = std::get<1>(GetParam()); + FpValues values_imag = std::get<2>(GetParam()); + ExhaustiveComplexUnaryTestBase::Run( + enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen); + } + + int64 GetInputSize() override { + FpValues values_real = std::get<1>(GetParam()); + FpValues values_imag = std::get<2>(GetParam()); + return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues(); + } +}; + +XLA_TEST_P(ExhaustiveC128UnaryTest, Log) { + // TODO(bixia): only test values that are not too big and not too small + // for now and will work on fixing the implementation of XLA + // operations to enable test for other values. + known_incorrect_fn_ = [&](int64 v) { + double f = ConvertValue(v); + return std::fpclassify(f) == FP_NAN || std::abs(f) > 5 || std::abs(f) < 1; + }; + Run(Log, [](complex128 x) { return std::log(x); }); +} + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::Values(C128), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::Values(C128), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(10000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::Values(C128), ::testing::Values(GetNormals(10000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine(::testing::Values(C128), + ::testing::Values(GetNormals(10000)), + ::testing::Values(GetNormals(10000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to +// keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::Values(C128), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#endif + } // namespace xla