[XLA] Add support for exhaustive tests of operations with more than 32 bits of input.
For operations that require 64 bits or more of input data, we can't actually exhaustively test all input bit patterns. Instead, we define a data structure, FpValues, that lets a test specify the subset of bit patterns to be tested. Add exhaustive tests for transcendental operations of F64, C64 and C128. PiperOrigin-RevId: 259014020
This commit is contained in:
parent
be42a1eb12
commit
84d5ed5ba6
@ -45,7 +45,13 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
|
|
||||||
// `ty` is the primitive type being tested.
|
// `ty` is the primitive type being tested.
|
||||||
explicit ExhaustiveOpTestBase(PrimitiveType ty)
|
explicit ExhaustiveOpTestBase(PrimitiveType ty)
|
||||||
: ty_(ty), platform_(client_->platform()->Name()) {}
|
: ty_(ty), platform_(client_->platform()->Name()) {
|
||||||
|
SetFastMathDisabled(true);
|
||||||
|
|
||||||
|
// Run all HLO passes. In particular, constant folding is disabled by
|
||||||
|
// default for tests, but we need to run it in order to tickle some bugs.
|
||||||
|
mutable_debug_options()->clear_xla_disable_hlo_passes();
|
||||||
|
}
|
||||||
|
|
||||||
// Builds and runs the computation using the LocalClient API, rather than the
|
// Builds and runs the computation using the LocalClient API, rather than the
|
||||||
// plain Client API, which is used by ClientLibraryTestBase. This is because
|
// plain Client API, which is used by ClientLibraryTestBase. This is because
|
||||||
@ -227,5 +233,410 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
|||||||
bool relaxed_denormal_signs_ = platform_ != "CUDA";
|
bool relaxed_denormal_signs_ = platform_ != "CUDA";
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Represents a set of 64 bit chunks by representing the starting bit chunk,
|
||||||
|
// the last bit chunk, and the spacing between two adjacent bit chunks, without
|
||||||
|
// actually storing all the bit chunks being generated. The bit chunk iterator
|
||||||
|
// is provided to retrieve all the bit chunks.
|
||||||
|
//
|
||||||
|
// This data structure is used to generate the bit representation to test
|
||||||
|
// operations that requires more than 64 bit input data. In this case,
|
||||||
|
// truly exhaustive testing is not possible and we want to test a value every
|
||||||
|
// n values, where n == spacing_.
|
||||||
|
//
|
||||||
|
// Currently, the iterator of BitChunks adds the `spacing_` to a bit chunk to
|
||||||
|
// compute the next bit chunk. We can change this to use values generated
|
||||||
|
// by a random number generator that can achieve the average spacing
|
||||||
|
// statistically, if we will find this is necessary.
|
||||||
|
class BitChunks {
|
||||||
|
public:
|
||||||
|
class iterator
|
||||||
|
: public std::iterator<std::input_iterator_tag, // iterator_category
|
||||||
|
uint64, // value_type
|
||||||
|
uint64, // difference_type
|
||||||
|
const uint64*, // pointer
|
||||||
|
uint64 // reference
|
||||||
|
> {
|
||||||
|
public:
|
||||||
|
iterator() {}
|
||||||
|
|
||||||
|
explicit iterator(const BitChunks* bit_chunks)
|
||||||
|
: bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {}
|
||||||
|
|
||||||
|
iterator& operator++() {
|
||||||
|
Next();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator operator++(int) {
|
||||||
|
iterator retval = *this;
|
||||||
|
Next();
|
||||||
|
return retval;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(iterator other) const {
|
||||||
|
return bit_chunks_ == other.bit_chunks_ &&
|
||||||
|
next_bit_chunk_ == other.next_bit_chunk_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(iterator other) const { return !(*this == other); }
|
||||||
|
|
||||||
|
iterator MoveToEnd() {
|
||||||
|
MoveNextBitChunkToOnePassEnd();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
reference operator*() const {
|
||||||
|
CHECK(*this != this->bit_chunks_->end());
|
||||||
|
return next_bit_chunk_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BitChunks* GetBitChunks() const { return bit_chunks_; }
|
||||||
|
|
||||||
|
void Reset() { next_bit_chunk_ = bit_chunks_->start_; }
|
||||||
|
|
||||||
|
void Next() {
|
||||||
|
CHECK(*this != this->bit_chunks_->end());
|
||||||
|
if (next_bit_chunk_ == bit_chunks_->end_) {
|
||||||
|
MoveNextBitChunkToOnePassEnd();
|
||||||
|
} else {
|
||||||
|
next_bit_chunk_ += bit_chunks_->spacing_;
|
||||||
|
if (next_bit_chunk_ > bit_chunks_->end_) {
|
||||||
|
next_bit_chunk_ = bit_chunks_->end_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
return absl::StrFormat("0x%08x", next_bit_chunk_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Move next_bit_chunk_ to 1 pass the bit_chunks_->end, to mark that the
|
||||||
|
// iterator has reached the end. When spacing_ is not one, or if we will
|
||||||
|
// change to use a random value instead of spacing_ in function Next(),
|
||||||
|
// normalizing the representation of the iterator ending this way can
|
||||||
|
// can simplify the checking for iterator ending.
|
||||||
|
void MoveNextBitChunkToOnePassEnd() {
|
||||||
|
next_bit_chunk_ = bit_chunks_->end_ + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BitChunks* bit_chunks_;
|
||||||
|
uint64 next_bit_chunk_;
|
||||||
|
};
|
||||||
|
|
||||||
|
iterator begin() const { return iterator(this); }
|
||||||
|
iterator end() const {
|
||||||
|
iterator end(this);
|
||||||
|
return end.MoveToEnd();
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit BitChunks(uint64 start = 0, uint64 end = 0, uint64 spacing = 1)
|
||||||
|
: start_(start), end_(end), spacing_(spacing) {
|
||||||
|
CHECK_GE(end_, start_);
|
||||||
|
CHECK_NE(spacing, 0) << ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 GetTotalBitChunks() const {
|
||||||
|
if (start_ == end_) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1 + (end_ - start_ + spacing_ - 1) / spacing_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64 start_;
|
||||||
|
uint64 end_;
|
||||||
|
uint64 spacing_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns the textual form of `c`, as produced by BitChunks::ToString().
inline string StringifyNum(BitChunks c) {
  string text = c.ToString();
  return text;
}
|
||||||
|
|
||||||
|
// Returns the textual form of the position of `c`, as produced by
// BitChunks::iterator::ToString().
inline string StringifyNum(BitChunks::iterator c) {
  string text = c.ToString();
  return text;
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void AppendStringifyNum(std::string* s, T x) {
|
||||||
|
absl::StrAppend(s, StringifyNum(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Represents a set of floating point values through the possible values for
|
||||||
|
// the three components: mantissa, exponent, and sign. Also implements an
|
||||||
|
// iterator for retrieving all the represented floating point values.
|
||||||
|
class FpValues {
|
||||||
|
public:
|
||||||
|
static constexpr uint kTotalBitChunks = 3;
|
||||||
|
|
||||||
|
class iterator
|
||||||
|
: public std::iterator<std::input_iterator_tag, // iterator_category
|
||||||
|
uint64, // value_type
|
||||||
|
uint64, // difference_type
|
||||||
|
const uint64*, // pointer
|
||||||
|
uint64 // reference
|
||||||
|
> {
|
||||||
|
public:
|
||||||
|
explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) {
|
||||||
|
for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
|
||||||
|
iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator& operator++() {
|
||||||
|
Next();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator operator++(int) {
|
||||||
|
iterator retval = *this;
|
||||||
|
Next();
|
||||||
|
return retval;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(iterator other) const {
|
||||||
|
for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
|
||||||
|
if (iters_[i] != other.GetBitChunksIter(i)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(iterator other) const { return !(*this == other); }
|
||||||
|
|
||||||
|
iterator MoveToEnd() {
|
||||||
|
for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
|
||||||
|
iters_[i].MoveToEnd();
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64 operator*() const {
|
||||||
|
uint64 value = 0;
|
||||||
|
for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
|
||||||
|
value = value | (*iters_[i]) << fp_values_->offsets_[i];
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; }
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
return absl::StrJoin(iters_, ",",
|
||||||
|
AppendStringifyNum<BitChunks::iterator>);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Moves the iterator for the ith BitChunks to the next value, and
|
||||||
|
// returns true if the new state is not the end of the iterator.
|
||||||
|
bool Next(int i = 0) {
|
||||||
|
iters_[i].Next();
|
||||||
|
if (iters_[i] == iters_[i].GetBitChunks()->end()) {
|
||||||
|
if (i == FpValues::kTotalBitChunks - 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (Next(i + 1)) {
|
||||||
|
iters_[i].Reset();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<BitChunks::iterator, FpValues::kTotalBitChunks> iters_;
|
||||||
|
const FpValues* fp_values_;
|
||||||
|
};
|
||||||
|
|
||||||
|
FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
|
||||||
|
CHECK_EQ(chunks.size(), offsets.size() - 1);
|
||||||
|
CHECK_EQ(chunks.size(), kTotalBitChunks);
|
||||||
|
std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin());
|
||||||
|
std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin());
|
||||||
|
|
||||||
|
// The last value in `offsets` is the total number of bits.
|
||||||
|
offsets_[kTotalBitChunks] = offsets[kTotalBitChunks];
|
||||||
|
// Validate the input values.
|
||||||
|
for (int i = 0; i < kTotalBitChunks; ++i) {
|
||||||
|
int total_bits = offsets[i + 1] - offsets[i];
|
||||||
|
if (total_bits < 64) {
|
||||||
|
uint64 bound = 1ull << total_bits;
|
||||||
|
CHECK_LT(chunks[i].start_, bound);
|
||||||
|
CHECK_LT(chunks[i].end_, bound);
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(total_bits, 64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator begin() const { return iterator(this); }
|
||||||
|
|
||||||
|
iterator end() const {
|
||||||
|
iterator end(this);
|
||||||
|
return end.MoveToEnd();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 GetTotalNumValues() const {
|
||||||
|
int64 total = 1;
|
||||||
|
absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) {
|
||||||
|
total *= chunks.GetTotalBitChunks();
|
||||||
|
});
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; }
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
return absl::StrCat(
|
||||||
|
"[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum<BitChunks>),
|
||||||
|
"]");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<BitChunks, kTotalBitChunks> bit_chunks_;
|
||||||
|
std::array<int, kTotalBitChunks + 1> offsets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns the number of mantissa bits of the floating point type `T`,
// computed as std::numeric_limits<T>::digits minus one.
template <typename T>
int GetMantissaTotalBits() {
  constexpr bool kIsSupported =
      std::is_same<T, float>::value || std::is_same<T, double>::value;
  static_assert(kIsSupported, "Only supports float and double.");
  const int precision_bits = std::numeric_limits<T>::digits;
  return precision_bits - 1;
}
|
||||||
|
|
||||||
|
// Returns the size of type `T` in bits, counting eight bits per byte.
template <typename T>
int GetFpTotalBits() {
  constexpr int kBitsPerByte = 8;
  return sizeof(T) * kBitsPerByte;
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
int GetExponentTotalBits() {
|
||||||
|
return GetFpTotalBits<T>() - GetMantissaTotalBits<T>() - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
uint64 GetAllOneMantissa() {
|
||||||
|
return (1ull << GetMantissaTotalBits<T>()) - 1ull;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
uint64 GetAllOneExponent() {
|
||||||
|
return (1ull << GetExponentTotalBits<T>()) - 1ull;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
|
||||||
|
static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
|
||||||
|
"Only supports float and double.");
|
||||||
|
int total_bits = GetFpTotalBits<T>();
|
||||||
|
return FpValues({mantissa, exponent, sign},
|
||||||
|
{0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetZeros() {
|
||||||
|
return GetFpValues<T>(BitChunks(0, 0, 1), BitChunks(0, 0, 1),
|
||||||
|
BitChunks(0, 1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetSubnormals(int approx_num_values) {
|
||||||
|
int mantissa = GetMantissaTotalBits<T>();
|
||||||
|
uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
|
||||||
|
return GetFpValues<T>(
|
||||||
|
BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
|
||||||
|
BitChunks(0, 0, 1), BitChunks(0, 1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetInfinites() {
|
||||||
|
uint64 all_one_exp = GetAllOneExponent<T>();
|
||||||
|
return GetFpValues<T>(BitChunks(0, 0, 1),
|
||||||
|
BitChunks(all_one_exp, all_one_exp, 1),
|
||||||
|
BitChunks(0, 1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetNans(int approx_num_values) {
|
||||||
|
int mantissa = GetMantissaTotalBits<T>();
|
||||||
|
uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
|
||||||
|
uint64 all_one_exp = GetAllOneExponent<T>();
|
||||||
|
return GetFpValues<T>(
|
||||||
|
BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
|
||||||
|
BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FpValues GetNormals(int approx_num_values) {
|
||||||
|
float component_total = std::sqrtf(approx_num_values);
|
||||||
|
return GetFpValues<T>(
|
||||||
|
BitChunks(0x1, GetAllOneMantissa<T>(),
|
||||||
|
(1ull << (GetMantissaTotalBits<T>() + 1)) / component_total),
|
||||||
|
BitChunks(0x1, GetAllOneExponent<T>() - 1,
|
||||||
|
(1ull << (GetExponentTotalBits<T>() + 1)) / component_total),
|
||||||
|
BitChunks(0, 1, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a vector of FpValues, which together represent about
|
||||||
|
// `approx_num_values` floating point values of type `T`, with each FpValues
|
||||||
|
// represents about `num_values_per_group` floating point values.
|
||||||
|
template <typename T>
|
||||||
|
std::vector<FpValues> GetFpValuesWithExponents(uint64 first_exponent,
|
||||||
|
uint64 exponent_spacing,
|
||||||
|
uint64 num_exponents,
|
||||||
|
uint64 approx_num_values,
|
||||||
|
uint64 num_values_per_group) {
|
||||||
|
const uint64 num_signs = 2;
|
||||||
|
uint64 approx_num_mantissa = approx_num_values / (num_exponents * num_signs);
|
||||||
|
uint64 num_mantissa_per_group =
|
||||||
|
num_values_per_group / (num_exponents * num_signs);
|
||||||
|
CHECK_GT(approx_num_mantissa, 0);
|
||||||
|
CHECK_GT(num_mantissa_per_group, 0);
|
||||||
|
|
||||||
|
CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent<T>());
|
||||||
|
int mantissa = GetMantissaTotalBits<T>();
|
||||||
|
uint64 mantissa_spacing = (1ull << mantissa) / approx_num_mantissa;
|
||||||
|
|
||||||
|
std::vector<FpValues> result;
|
||||||
|
for (uint64 group_start = 0; group_start < GetAllOneMantissa<T>();
|
||||||
|
group_start += mantissa_spacing * num_mantissa_per_group) {
|
||||||
|
uint64 group_end =
|
||||||
|
group_start + (num_mantissa_per_group - 1) * mantissa_spacing;
|
||||||
|
if (group_end > GetAllOneMantissa<T>()) {
|
||||||
|
group_end = GetAllOneMantissa<T>();
|
||||||
|
}
|
||||||
|
result.push_back(GetFpValues<T>(
|
||||||
|
BitChunks(group_start, group_end, mantissa_spacing),
|
||||||
|
BitChunks(first_exponent, first_exponent + num_exponents - 1, 1),
|
||||||
|
BitChunks(0, 1, 1)));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a vector of FpValues together represent about `approx_num_values`
|
||||||
|
// "very large" floating point values and `approx_num_values` "very small"
|
||||||
|
// floating point values of type `T`, which each FpValues represent about
|
||||||
|
// `num_values_per_group` floating point values. Because we use FpValues as
|
||||||
|
// a parameter for parameterized testing, the number of floating values
|
||||||
|
// represented by each FpValues affects the input size for each sub-test and
|
||||||
|
// the hence the peak memory usage of the test.
|
||||||
|
template <typename T>
|
||||||
|
std::vector<FpValues> GetFpValuesForMagnitudeExtremeNormals(
|
||||||
|
uint64 approx_num_values = 40000, uint64 num_values_per_group = 4000) {
|
||||||
|
std::vector<FpValues> large =
|
||||||
|
GetFpValuesWithExponents<T>(GetAllOneExponent<T>() - 5, 1, 5,
|
||||||
|
approx_num_values / 2, num_values_per_group);
|
||||||
|
std::vector<FpValues> small = GetFpValuesWithExponents<T>(
|
||||||
|
1, 1, 5, approx_num_values / 2, num_values_per_group);
|
||||||
|
large.insert(large.end(), small.begin(), small.end());
|
||||||
|
return large;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<FpValues> CreateFpValuesForBoundaryTest() {
|
||||||
|
return {GetZeros<T>(), GetSubnormals<T>(1000), GetInfinites<T>(),
|
||||||
|
GetNans<T>(1000)};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
|
#endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
|
||||||
|
@ -326,11 +326,6 @@ class Exhaustive32BitOrLessUnaryTest
|
|||||||
|
|
||||||
void Run(std::function<XlaOp(XlaOp)> enqueue_op, F32EvaluateOp evaluate_op,
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, F32EvaluateOp evaluate_op,
|
||||||
std::function<ErrorSpec(float)> error_spec_gen) {
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
SetFastMathDisabled(true);
|
|
||||||
|
|
||||||
// Run all HLO passes. In particular, constant folding is disabled by
|
|
||||||
// default for tests, but we need to run it in order to tickle some bugs.
|
|
||||||
mutable_debug_options()->clear_xla_disable_hlo_passes();
|
|
||||||
Literal input_literal = CreateInputLiteral();
|
Literal input_literal = CreateInputLiteral();
|
||||||
switch (ty_) {
|
switch (ty_) {
|
||||||
case F32:
|
case F32:
|
||||||
@ -708,4 +703,340 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
::testing::Values(std::make_pair(0, 1 << 16))));
|
::testing::Values(std::make_pair(0, 1 << 16))));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Exhaustive test for unary operations for double.
|
||||||
|
//
|
||||||
|
// Test parameter is a tuple containing
|
||||||
|
// - primitive type under test,
|
||||||
|
// - FpValues representing a set of double values.
|
||||||
|
class ExhaustiveF64UnaryTest : public ExhaustiveRealUnaryTestBase,
|
||||||
|
public ::testing::WithParamInterface<
|
||||||
|
std::tuple<PrimitiveType, FpValues>> {
|
||||||
|
public:
|
||||||
|
typedef double (*F64EvaluateOp)(double);
|
||||||
|
|
||||||
|
ExhaustiveF64UnaryTest()
|
||||||
|
: ExhaustiveRealUnaryTestBase(std::get<0>(GetParam())) {}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, F64EvaluateOp evaluate_op) {
|
||||||
|
return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, F64EvaluateOp evaluate_op,
|
||||||
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
|
CHECK_EQ(ty_, F64);
|
||||||
|
Literal input_literal = CreateInputLiteral();
|
||||||
|
FillInputF64(&input_literal);
|
||||||
|
RunImpl<double, double>(enqueue_op, evaluate_op, input_literal,
|
||||||
|
error_spec_gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 GetInputSize() override {
|
||||||
|
FpValues values = std::get<1>(GetParam());
|
||||||
|
return values.GetTotalNumValues();
|
||||||
|
}
|
||||||
|
|
||||||
|
void FillInputF64(Literal* input_literal) {
|
||||||
|
FpValues fp_values = std::get<1>(GetParam());
|
||||||
|
int64 input_size = input_literal->element_count();
|
||||||
|
LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", "
|
||||||
|
<< input_size;
|
||||||
|
absl::Span<double> input_arr = input_literal->data<double>();
|
||||||
|
|
||||||
|
uint64 i = 0;
|
||||||
|
for (auto bits : fp_values) {
|
||||||
|
input_arr[i] = ConvertAndReplaceKnownIncorrectValueWith<double>(bits, 1);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
CHECK_EQ(i, input_size);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Runs the Log op with std::log as the evaluation function.
XLA_TEST_P(ExhaustiveF64UnaryTest, Log) {
  Run(Log, std::log);
}
|
||||||
|
|
||||||
|
// TODO(bixia): add other unary ops for double
|
||||||
|
|
||||||
|
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SpecialValues, ExhaustiveF64UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(F64),
|
||||||
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
NormalValues, ExhaustiveF64UnaryTest,
|
||||||
|
::testing::Combine(::testing::Values(F64),
|
||||||
|
::testing::Values(GetNormals<double>(1000))));
|
||||||
|
|
||||||
|
// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to
|
||||||
|
// keep the peak memory usage low.
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(F64),
|
||||||
|
::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<double>(
|
||||||
|
4000000000ull, 16000000))));
|
||||||
|
#endif
|
||||||
|
|
||||||
|
class ExhaustiveComplexUnaryTestBase : public ExhaustiveOpTestBase {
|
||||||
|
public:
|
||||||
|
explicit ExhaustiveComplexUnaryTestBase(PrimitiveType ty)
|
||||||
|
: ExhaustiveOpTestBase(ty) {}
|
||||||
|
|
||||||
|
// A helper for implementing the Run method for unary op test of complex
|
||||||
|
// numbers.
|
||||||
|
//
|
||||||
|
// T is the component type of the complex number.
|
||||||
|
template <typename T>
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op,
|
||||||
|
std::complex<T> (*evaluate_op)(std::complex<T>),
|
||||||
|
FpValues* values_real, FpValues* values_imag,
|
||||||
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
|
Literal input_literal = CreateInputLiteral();
|
||||||
|
|
||||||
|
FillInput<T>(&input_literal, values_real, values_imag);
|
||||||
|
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto input = Parameter(&builder, 0, input_literal.shape(), "input");
|
||||||
|
enqueue_op(input);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
|
||||||
|
RunComputation(comp, {&input_literal}));
|
||||||
|
ExpectNearComplex<T>(input_literal, result_literal, evaluate_op,
|
||||||
|
error_spec_gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates the input complex literal given the FpValues representation for
|
||||||
|
// the real and imaginary components.
|
||||||
|
//
|
||||||
|
// T is the component type of the complex number.
|
||||||
|
template <typename T>
|
||||||
|
void FillInput(Literal* input_literal, FpValues* real_values,
|
||||||
|
FpValues* imag_values) {
|
||||||
|
VLOG(2) << " testing input total "
|
||||||
|
<< real_values->GetTotalNumValues() *
|
||||||
|
imag_values->GetTotalNumValues()
|
||||||
|
<< ", range " << real_values->ToString() << " "
|
||||||
|
<< imag_values->ToString();
|
||||||
|
|
||||||
|
absl::Span<std::complex<T>> input_arr =
|
||||||
|
input_literal->data<std::complex<T>>();
|
||||||
|
|
||||||
|
uint64 i = 0;
|
||||||
|
for (auto real : *real_values) {
|
||||||
|
for (auto imag : *imag_values) {
|
||||||
|
input_arr[i] = std::complex<T>(
|
||||||
|
ConvertAndReplaceKnownIncorrectValueWith<T>(real, 1),
|
||||||
|
ConvertAndReplaceKnownIncorrectValueWith<T>(imag, 1));
|
||||||
|
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ExpectNearComplex(const Literal& input_literal,
|
||||||
|
const Literal& result_literal,
|
||||||
|
std::complex<T> (*evaluate_op)(std::complex<T>),
|
||||||
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
|
absl::Span<const std::complex<T>> input_arr =
|
||||||
|
input_literal.data<std::complex<T>>();
|
||||||
|
absl::Span<const std::complex<T>> result_arr =
|
||||||
|
result_literal.data<std::complex<T>>();
|
||||||
|
ASSERT_EQ(result_arr.size(), input_arr.size());
|
||||||
|
int64 mismatches = 0;
|
||||||
|
|
||||||
|
for (int64 i = 0; i < input_arr.size(); ++i) {
|
||||||
|
std::complex<T> input = input_arr[i];
|
||||||
|
std::complex<T> actual = result_arr[i];
|
||||||
|
std::complex<T> expected = evaluate_op(input);
|
||||||
|
|
||||||
|
// TODO(bixia): Need to fix error_spec_gen to consider both components.
|
||||||
|
// This only affects the value specific error_spec, and before we fix
|
||||||
|
// this, it means complex operation testing doesn't support value
|
||||||
|
// specific error_spec yet. We delay the fix to this partially because
|
||||||
|
// we don't know whether it is enough for the error_spec to only take
|
||||||
|
// the absolute value of the complex number.
|
||||||
|
ErrorSpec error_spec = error_spec_gen(input.real());
|
||||||
|
|
||||||
|
if (IsClose(expected.real(), actual.real(), error_spec) &&
|
||||||
|
IsClose(expected.imag(), actual.imag(), error_spec)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(bixia): Need to handle complex operands with subnormals in
|
||||||
|
// real and/or imaginary components.
|
||||||
|
VLOG(2) << "calculate " << StringifyNum(input) << " ;"
|
||||||
|
<< StringifyNum(actual) << "; " << StringifyNum(expected);
|
||||||
|
|
||||||
|
PrintMismatch(&mismatches, [&] {
|
||||||
|
return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
|
||||||
|
StringifyNum(input), StringifyNum(expected),
|
||||||
|
StringifyNum(actual));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(mismatches, 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Unary op test for complex<float>.
|
||||||
|
//
|
||||||
|
// Test parameter is a tuple containing
|
||||||
|
// - primitive type under test,
|
||||||
|
// - two FpValues representing the values for the real and imaginary
|
||||||
|
// components. The complex numbers for the test input is the cartesian
|
||||||
|
// product of the values represented by the two FpValues.
|
||||||
|
class ExhaustiveC64UnaryTest
|
||||||
|
: public ExhaustiveComplexUnaryTestBase,
|
||||||
|
public ::testing::WithParamInterface<
|
||||||
|
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||||
|
public:
|
||||||
|
typedef complex64 (*C64EvaluateOp)(complex64);
|
||||||
|
|
||||||
|
ExhaustiveC64UnaryTest()
|
||||||
|
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, C64EvaluateOp evaluate_op) {
|
||||||
|
return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, C64EvaluateOp evaluate_op,
|
||||||
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
|
FpValues values_real = std::get<1>(GetParam());
|
||||||
|
FpValues values_imag = std::get<2>(GetParam());
|
||||||
|
ExhaustiveComplexUnaryTestBase::Run<float>(
|
||||||
|
enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 GetInputSize() override {
|
||||||
|
FpValues values_real = std::get<1>(GetParam());
|
||||||
|
FpValues values_imag = std::get<2>(GetParam());
|
||||||
|
return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Special values for both the real and imaginary components.
INSTANTIATE_TEST_SUITE_P(
    F32SpecialValues, ExhaustiveC64UnaryTest,
    ::testing::Combine(
        ::testing::Values(C64),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));

// Special values for the real component, normal values for the imaginary
// component.
INSTANTIATE_TEST_SUITE_P(
    F32SpecialAndNormalValues, ExhaustiveC64UnaryTest,
    ::testing::Combine(
        ::testing::Values(C64),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()),
        ::testing::Values(GetNormals<float>(10000))));

// Normal values for the real component, special values for the imaginary
// component.
INSTANTIATE_TEST_SUITE_P(
    F32NormalAndSpecialValues, ExhaustiveC64UnaryTest,
    ::testing::Combine(
        ::testing::Values(C64), ::testing::Values(GetNormals<float>(10000)),
        ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>())));

// Normal values for both components.
INSTANTIATE_TEST_SUITE_P(
    F32NormalAndNormalValues, ExhaustiveC64UnaryTest,
    ::testing::Combine(::testing::Values(C64),
                       ::testing::Values(GetNormals<float>(10000)),
                       ::testing::Values(GetNormals<float>(10000))));

// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to
// keep the peak memory usage low.
INSTANTIATE_TEST_SUITE_P(
    F32LargeAndSmallMagnitudeNormalValues, ExhaustiveC64UnaryTest,
    ::testing::Combine(
        ::testing::Values(C64),
        ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals<float>(40000,
                                                                         4000)),
        ::testing::ValuesIn(
            GetFpValuesForMagnitudeExtremeNormals<float>(40000, 4000))));
|
||||||
|
|
||||||
|
// Unary op test for complex<double>.
|
||||||
|
//
|
||||||
|
// Test parameter is a tuple containing
|
||||||
|
// - primitive type under test,
|
||||||
|
// - two FpValues representing the values for the real and imaginary
|
||||||
|
// components. The complex numbers for the test input is the cartesian
|
||||||
|
// product of the values represented by the two FpValues.
|
||||||
|
class ExhaustiveC128UnaryTest
|
||||||
|
: public ExhaustiveComplexUnaryTestBase,
|
||||||
|
public ::testing::WithParamInterface<
|
||||||
|
std::tuple<PrimitiveType, FpValues, FpValues>> {
|
||||||
|
public:
|
||||||
|
typedef complex128 (*C128EvaluateOp)(complex128);
|
||||||
|
|
||||||
|
ExhaustiveC128UnaryTest()
|
||||||
|
: ExhaustiveComplexUnaryTestBase(std::get<0>(GetParam())) {}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, C128EvaluateOp evaluate_op) {
|
||||||
|
return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Run(std::function<XlaOp(XlaOp)> enqueue_op, C128EvaluateOp evaluate_op,
|
||||||
|
std::function<ErrorSpec(float)> error_spec_gen) {
|
||||||
|
FpValues values_real = std::get<1>(GetParam());
|
||||||
|
FpValues values_imag = std::get<2>(GetParam());
|
||||||
|
ExhaustiveComplexUnaryTestBase::Run<double>(
|
||||||
|
enqueue_op, evaluate_op, &values_real, &values_imag, error_spec_gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 GetInputSize() override {
|
||||||
|
FpValues values_real = std::get<1>(GetParam());
|
||||||
|
FpValues values_imag = std::get<2>(GetParam());
|
||||||
|
return values_real.GetTotalNumValues() * values_imag.GetTotalNumValues();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Runs the Log op with std::log on complex128 as the evaluation function.
XLA_TEST_P(ExhaustiveC128UnaryTest, Log) {
  // TODO(bixia): only test values that are not too big and not too small
  //  for now and will work on fixing the implementation of XLA
  //  operations to enable test for other values.
  //
  // A component is flagged as known-incorrect when it is NaN or when its
  // magnitude lies outside [1, 5].
  known_incorrect_fn_ = [&](int64 v) {
    const double component = ConvertValue<double>(v);
    const bool is_nan = std::fpclassify(component) == FP_NAN;
    return is_nan || std::abs(component) > 5 || std::abs(component) < 1;
  };
  Run(Log, [](complex128 x) { return std::log(x); });
}
|
||||||
|
|
||||||
|
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SpecialValues, ExhaustiveC128UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(C128),
|
||||||
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||||
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
SpecialAndNormalValues, ExhaustiveC128UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(C128),
|
||||||
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>()),
|
||||||
|
::testing::Values(GetNormals<double>(10000))));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
NormalAndSpecialValues, ExhaustiveC128UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(C128), ::testing::Values(GetNormals<double>(10000)),
|
||||||
|
::testing::ValuesIn(CreateFpValuesForBoundaryTest<double>())));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
F32NormalAndNormalValues, ExhaustiveC128UnaryTest,
|
||||||
|
::testing::Combine(::testing::Values(C128),
|
||||||
|
::testing::Values(GetNormals<double>(10000)),
|
||||||
|
::testing::Values(GetNormals<double>(10000))));
|
||||||
|
|
||||||
|
// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to
|
||||||
|
// keep the peak memory usage low.
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
LargeAndSmallMagnituedNormalValues, ExhaustiveC128UnaryTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(C128),
|
||||||
|
::testing::ValuesIn(
|
||||||
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000)),
|
||||||
|
::testing::ValuesIn(
|
||||||
|
GetFpValuesForMagnitudeExtremeNormals<double>(40000, 2000))));
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user