Avoid signed integer overflow in test setup. Compilers are producing different
code and resulting in bad assumptions. PiperOrigin-RevId: 293859547 Change-Id: I7bc1ac344ee3db1456bc5538cd7dc25c5f52e21b
This commit is contained in:
parent
7bc52f2ff5
commit
85f10eb420
@ -14,22 +14,44 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/overflow.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
bool HasOverflow(int64 x, int64 y) {
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
// `long double` on MSVC is 64 bits not 80 bits - use a windows specific API
|
||||
// for this test.
|
||||
return ::MultiplyHigh(x, y) != 0;
|
||||
#else
|
||||
long double dxy = static_cast<long double>(x) * static_cast<long double>(y);
|
||||
return dxy > std::numeric_limits<int64>::max();
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(OverflowTest, Nonnegative) {
|
||||
// Various interesting values
|
||||
std::vector<int64> interesting = {0, std::numeric_limits<int64>::max()};
|
||||
std::vector<int64> interesting = {
|
||||
0,
|
||||
std::numeric_limits<int64>::max(),
|
||||
};
|
||||
|
||||
for (int i = 0; i < 63; i++) {
|
||||
int64 bit = static_cast<int64>(1) << i;
|
||||
interesting.push_back(bit);
|
||||
interesting.push_back(bit + 1);
|
||||
interesting.push_back(bit - 1);
|
||||
}
|
||||
|
||||
for (const int64 mid : {static_cast<int64>(1) << 32,
|
||||
static_cast<int64>(std::pow(2, 63.0 / 2))}) {
|
||||
for (int i = -5; i < 5; i++) {
|
||||
@ -38,14 +60,13 @@ TEST(OverflowTest, Nonnegative) {
|
||||
}
|
||||
|
||||
// Check all pairs
|
||||
for (auto x : interesting) {
|
||||
for (auto y : interesting) {
|
||||
for (int64 x : interesting) {
|
||||
for (int64 y : interesting) {
|
||||
int64 xy = MultiplyWithoutOverflow(x, y);
|
||||
long double dxy = static_cast<long double>(x) * y;
|
||||
if (dxy > std::numeric_limits<int64>::max()) {
|
||||
EXPECT_LT(xy, 0);
|
||||
if (HasOverflow(x, y)) {
|
||||
EXPECT_LT(xy, 0) << x << " " << y;
|
||||
} else {
|
||||
EXPECT_EQ(dxy, xy);
|
||||
EXPECT_EQ(x * y, xy) << x << " " << y;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -54,9 +75,9 @@ TEST(OverflowTest, Nonnegative) {
|
||||
TEST(OverflowTest, Negative) {
|
||||
const int64 negatives[] = {-1, std::numeric_limits<int64>::min()};
|
||||
for (const int64 n : negatives) {
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), "");
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(0, n), "");
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(n, n), "");
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(n, 0), "") << n;
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(0, n), "") << n;
|
||||
EXPECT_DEATH(MultiplyWithoutOverflow(n, n), "") << n;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user