From 1b41b5a30d744a4b6cfdb27246cf19d5984db3d5 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 9 Feb 2021 03:14:10 -0800 Subject: [PATCH] Make NaNs equal for testing The default behavior is that NaN != NaN, making it impossible to check values against known good values that contains NaN. Change the behavior so NaN is equal (and close) to just NaN. While this is incompatible with IEEE 754, it makes writing tests much easier. PiperOrigin-RevId: 356463179 Change-Id: Iad2bd77708eb21e86398413425f77423442de5ad --- tensorflow/core/framework/tensor_testutil.cc | 13 ++++++++-- .../core/framework/tensor_testutil_test.cc | 25 +++++++++++-------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc index 804d5df31ed..4d971da21e5 100644 --- a/tensorflow/core/framework/tensor_testutil.cc +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -48,16 +48,23 @@ static ::testing::AssertionResult EqualFailure(const T& x, const T& y) { << " not equal to " << y; } static ::testing::AssertionResult IsEqual(float x, float y) { - if (::testing::internal::CmpHelperFloatingPointEQ("", "", x, y)) + // We consider NaNs equal for testing. + if ((isnan(x) && isnan(y)) || + ::testing::internal::CmpHelperFloatingPointEQ("", "", x, y)) return ::testing::AssertionSuccess(); return EqualFailure(x, y); } static ::testing::AssertionResult IsEqual(double x, double y) { - if (::testing::internal::CmpHelperFloatingPointEQ("", "", x, y)) + // We consider NaNs equal for testing. + if ((isnan(x) && isnan(y)) || + ::testing::internal::CmpHelperFloatingPointEQ("", "", x, y)) return ::testing::AssertionSuccess(); return EqualFailure(x, y); } static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y) { + // We consider NaNs equal for testing. + if (isnan(x) && isnan(y)) return ::testing::AssertionSuccess(); + // Below is a reimplementation of CmpHelperFloatingPointEQ, which // we cannot use because Eigen::half is not default-constructible. @@ -107,6 +114,8 @@ static void ExpectEqual(const Tensor& x, const Tensor& y) { template static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol, const T& rtol) { + // We consider NaNs equal for testing. + if (isnan(x) && isnan(y)) return ::testing::AssertionSuccess(); if (x == y) return ::testing::AssertionSuccess(); // Handle infinity. auto tolerance = atol + rtol * Eigen::numext::abs(x); if (Eigen::numext::abs(x - y) <= tolerance) diff --git a/tensorflow/core/framework/tensor_testutil_test.cc b/tensorflow/core/framework/tensor_testutil_test.cc index 8c02f18d77f..677b157294e 100644 --- a/tensorflow/core/framework/tensor_testutil_test.cc +++ b/tensorflow/core/framework/tensor_testutil_test.cc @@ -34,11 +34,14 @@ void TestEdgeCasesNear() { EXPECT_FALSE( IsClose(Eigen::NumTraits::lowest(), Eigen::NumTraits::highest(), static_cast(Eigen::NumTraits::highest()), 0.0)); - EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), - Eigen::NumTraits::quiet_NaN(), 0.0, 0.0)); - EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), - Eigen::NumTraits::quiet_NaN(), + EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), T(0.0), 0.0, 0.0)); + EXPECT_TRUE(IsClose(Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::quiet_NaN(), 0.0, 0.0)); + EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), T(0.0), Eigen::NumTraits::infinity(), 0.0)); + EXPECT_TRUE(IsClose(Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::infinity(), 0.0)); } // For debug printing. Example usage: @@ -207,12 +210,14 @@ void TestEdgeCasesClose() { Eigen::NumTraits::highest(), static_cast(Eigen::NumTraits::highest()), static_cast(Eigen::NumTraits::highest()))); - EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), - Eigen::NumTraits::quiet_NaN(), 0.0, 0.0)); - EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), - Eigen::NumTraits::quiet_NaN(), - Eigen::NumTraits::infinity(), - Eigen::NumTraits::infinity())); + EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), T(0.0), 0.0, 0.0)); + EXPECT_TRUE(IsClose(Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::quiet_NaN(), 0.0, 0.0)); + EXPECT_FALSE(IsClose(Eigen::NumTraits::quiet_NaN(), T(0.0), + Eigen::NumTraits::infinity(), 0.0)); + EXPECT_TRUE(IsClose(Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::quiet_NaN(), + Eigen::NumTraits::infinity(), 0.0)); } TEST(TensorTestUtilTest, ExpectTensorCloseHalf) {