Make NaNs equal for testing

The default behavior is that NaN != NaN, making it impossible to check
values against known good values that contains NaN. Change the behavior
so NaN is equal (and close) to just NaN. While this is incompatible with
IEEE 754, it makes writing tests much easier.

PiperOrigin-RevId: 356463179
Change-Id: Iad2bd77708eb21e86398413425f77423442de5ad
This commit is contained in:
Benjamin Kramer 2021-02-09 03:14:10 -08:00 committed by TensorFlower Gardener
parent 5eac5b75b4
commit 1b41b5a30d
2 changed files with 26 additions and 12 deletions

View File

@ -48,16 +48,23 @@ static ::testing::AssertionResult EqualFailure(const T& x, const T& y) {
<< " not equal to " << y;
}
static ::testing::AssertionResult IsEqual(float x, float y) {
if (::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y))
// We consider NaNs equal for testing.
if ((isnan(x) && isnan(y)) ||
::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y))
return ::testing::AssertionSuccess();
return EqualFailure(x, y);
}
static ::testing::AssertionResult IsEqual(double x, double y) {
if (::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y))
// We consider NaNs equal for testing.
if ((isnan(x) && isnan(y)) ||
::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y))
return ::testing::AssertionSuccess();
return EqualFailure(x, y);
}
static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y) {
// We consider NaNs equal for testing.
if (isnan(x) && isnan(y)) return ::testing::AssertionSuccess();
// Below is a reimplementation of CmpHelperFloatingPointEQ<Eigen::half>, which
// we cannot use because Eigen::half is not default-constructible.
@ -107,6 +114,8 @@ static void ExpectEqual(const Tensor& x, const Tensor& y) {
template <typename T>
static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol,
const T& rtol) {
// We consider NaNs equal for testing.
if (isnan(x) && isnan(y)) return ::testing::AssertionSuccess();
if (x == y) return ::testing::AssertionSuccess(); // Handle infinity.
auto tolerance = atol + rtol * Eigen::numext::abs(x);
if (Eigen::numext::abs(x - y) <= tolerance)

View File

@ -34,11 +34,14 @@ void TestEdgeCasesNear() {
EXPECT_FALSE(
IsClose(Eigen::NumTraits<T>::lowest(), Eigen::NumTraits<T>::highest(),
static_cast<double>(Eigen::NumTraits<T>::highest()), 0.0));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(), 0.0, 0.0));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(),
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(), T(0.0), 0.0, 0.0));
EXPECT_TRUE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(), 0.0, 0.0));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(), T(0.0),
Eigen::NumTraits<double>::infinity(), 0.0));
EXPECT_TRUE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<double>::infinity(), 0.0));
}
// For debug printing. Example usage:
@ -207,12 +210,14 @@ void TestEdgeCasesClose() {
Eigen::NumTraits<T>::highest(),
static_cast<double>(Eigen::NumTraits<T>::highest()),
static_cast<double>(Eigen::NumTraits<T>::highest())));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(), 0.0, 0.0));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<double>::infinity(),
Eigen::NumTraits<double>::infinity()));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(), T(0.0), 0.0, 0.0));
EXPECT_TRUE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(), 0.0, 0.0));
EXPECT_FALSE(IsClose(Eigen::NumTraits<T>::quiet_NaN(), T(0.0),
Eigen::NumTraits<double>::infinity(), 0.0));
EXPECT_TRUE(IsClose(Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<T>::quiet_NaN(),
Eigen::NumTraits<double>::infinity(), 0.0));
}
TEST(TensorTestUtilTest, ExpectTensorCloseHalf) {