diff --git a/tensorflow/core/util/overflow_test.cc b/tensorflow/core/util/overflow_test.cc index f93ba885e6d..0f9b3571611 100644 --- a/tensorflow/core/util/overflow_test.cc +++ b/tensorflow/core/util/overflow_test.cc @@ -14,22 +14,44 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/util/overflow.h" + #include <cmath> + +#ifdef PLATFORM_WINDOWS +#include <Windows.h> +#endif + #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { +bool HasOverflow(int64 x, int64 y) { +#ifdef PLATFORM_WINDOWS + // `long double` on MSVC is 64 bits not 80 bits - use a windows specific API + // for this test. + return ::MultiplyHigh(x, y) != 0; +#else + long double dxy = static_cast<long double>(x) * static_cast<long double>(y); + return dxy > std::numeric_limits<int64>::max(); +#endif +} + TEST(OverflowTest, Nonnegative) { // Various interesting values - std::vector<int64> interesting = {0, std::numeric_limits<int64>::max()}; + std::vector<int64> interesting = { + 0, + std::numeric_limits<int64>::max(), + }; + for (int i = 0; i < 63; i++) { int64 bit = static_cast<int64>(1) << i; interesting.push_back(bit); interesting.push_back(bit + 1); interesting.push_back(bit - 1); } + for (const int64 mid : {static_cast<int64>(1) << 32, static_cast<int64>(std::pow(2, 63.0 / 2))}) { for (int i = -5; i < 5; i++) { @@ -38,14 +60,13 @@ TEST(OverflowTest, Nonnegative) { } // Check all pairs - for (auto x : interesting) { - for (auto y : interesting) { + for (int64 x : interesting) { + for (int64 y : interesting) { int64 xy = MultiplyWithoutOverflow(x, y); - long double dxy = static_cast<long double>(x) * y; - if (dxy > std::numeric_limits<int64>::max()) { - EXPECT_LT(xy, 0); + if (HasOverflow(x, y)) { + EXPECT_LT(xy, 0) << x << " " << y; } else { - EXPECT_EQ(dxy, xy); + EXPECT_EQ(x * y, xy) << x << " " << y; } } } @@ -54,9 +75,9 @@ TEST(OverflowTest, Nonnegative) { TEST(OverflowTest, Negative) { const int64 negatives[] = {-1, std::numeric_limits<int64>::min()}; for (const int64 n : negatives) { - EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), ""); - EXPECT_DEATH(MultiplyWithoutOverflow(0, n), ""); - EXPECT_DEATH(MultiplyWithoutOverflow(n, n), ""); + EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), "") << n; + EXPECT_DEATH(MultiplyWithoutOverflow(0, n), "") << n; + EXPECT_DEATH(MultiplyWithoutOverflow(n, n), "") << n; } }