Avoid signed integer overflow in test setup. Compilers are producing different

code and resulting in bad assumptions.

PiperOrigin-RevId: 293859547
Change-Id: I7bc1ac344ee3db1456bc5538cd7dc25c5f52e21b
This commit is contained in:
Brian Atkinson 2020-02-07 11:38:39 -08:00 committed by TensorFlower Gardener
parent 7bc52f2ff5
commit 85f10eb420

View File

@ -14,22 +14,44 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/overflow.h"
#include <cmath>
#ifdef PLATFORM_WINDOWS
#include <Windows.h>
#endif
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
bool HasOverflow(int64 x, int64 y) {
#ifdef PLATFORM_WINDOWS
// `long double` on MSVC is 64 bits not 80 bits - use a windows specific API
// for this test.
return ::MultiplyHigh(x, y) != 0;
#else
long double dxy = static_cast<long double>(x) * static_cast<long double>(y);
return dxy > std::numeric_limits<int64>::max();
#endif
}
TEST(OverflowTest, Nonnegative) {
// Various interesting values
std::vector<int64> interesting = {0, std::numeric_limits<int64>::max()};
std::vector<int64> interesting = {
0,
std::numeric_limits<int64>::max(),
};
for (int i = 0; i < 63; i++) {
int64 bit = static_cast<int64>(1) << i;
interesting.push_back(bit);
interesting.push_back(bit + 1);
interesting.push_back(bit - 1);
}
for (const int64 mid : {static_cast<int64>(1) << 32,
static_cast<int64>(std::pow(2, 63.0 / 2))}) {
for (int i = -5; i < 5; i++) {
@ -38,14 +60,13 @@ TEST(OverflowTest, Nonnegative) {
}
// Check all pairs
for (auto x : interesting) {
for (auto y : interesting) {
for (int64 x : interesting) {
for (int64 y : interesting) {
int64 xy = MultiplyWithoutOverflow(x, y);
long double dxy = static_cast<long double>(x) * y;
if (dxy > std::numeric_limits<int64>::max()) {
EXPECT_LT(xy, 0);
if (HasOverflow(x, y)) {
EXPECT_LT(xy, 0) << x << " " << y;
} else {
EXPECT_EQ(dxy, xy);
EXPECT_EQ(x * y, xy) << x << " " << y;
}
}
}
@ -54,9 +75,9 @@ TEST(OverflowTest, Nonnegative) {
TEST(OverflowTest, Negative) {
const int64 negatives[] = {-1, std::numeric_limits<int64>::min()};
for (const int64 n : negatives) {
EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), "");
EXPECT_DEATH(MultiplyWithoutOverflow(0, n), "");
EXPECT_DEATH(MultiplyWithoutOverflow(n, n), "");
EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), "") << n;
EXPECT_DEATH(MultiplyWithoutOverflow(0, n), "") << n;
EXPECT_DEATH(MultiplyWithoutOverflow(n, n), "") << n;
}
}