Check for overflow in # of bytes computation of tensor allocation.

We check both for product of shape dimensions (# of elements)
and number of bytes (elements * sizeof(data_type)).

@joyalbin provided an initial adaption of the TF overflowl.h impl in
#26859. This is cleaned up and optimized for a refactor that
occurred after.

PiperOrigin-RevId: 288539371
Change-Id: I9298de40afd22ae72da151cc457c1f0937e97d7a
This commit is contained in:
Andrew Selle 2020-01-07 11:38:05 -08:00 committed by TensorFlower Gardener
parent 22ccf0267a
commit 2522ce7dd5
2 changed files with 91 additions and 4 deletions

View File

@ -559,16 +559,39 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
return kTfLiteOk;
}
namespace {
// Multiply two sizes and return true if overflow occurred;
// This is based off tensorflow/overflow.h but is simpler as we already
// have unsigned numbers. It is also generalized to work where sizeof(size_t)
// is not 8.
TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
constexpr size_t overflow_threshold = (8 * sizeof(size_t)) >> 1;
*product = a * b;
// If neither integers have non-zero bits past 32 bits can't overflow.
// Otherwise check using slow devision.
if (__builtin_expect((a | b) >> overflow_threshold != 0, false)) {
if (a != 0 && *product / a != b) return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
size_t dims_size, size_t* bytes) {
// TODO(aselle): Check for overflow here using overflow.h in TensorFlow
// MultiplyWithoutOverflow.
TF_LITE_ENSURE(&context_, bytes != nullptr);
size_t count = 1;
for (int k = 0; k < dims_size; k++) count *= dims[k];
for (int k = 0; k < dims_size; k++) {
size_t old_count = count;
TF_LITE_ENSURE_MSG(
&context_,
MultiplyAndCheckOverflow(old_count, dims[k], &count) == kTfLiteOk,
"BytesRequired number of elements overflowed.\n");
}
size_t type_size = 0;
TF_LITE_ENSURE_OK(&context_, GetSizeOfType(&context_, type, &type_size));
*bytes = type_size * count;
TF_LITE_ENSURE_MSG(
&context_, MultiplyAndCheckOverflow(type_size, count, bytes) == kTfLiteOk,
"BytesRequired number of bytes overflowed.\n");
return kTfLiteOk;
}

View File

@ -820,6 +820,70 @@ TEST(BasicInterpreter, TestCustomErrorReporter) {
ASSERT_EQ(reporter.num_calls(), 1);
}
TEST(BasicInterpreter, TestOverflow) {
TestErrorReporter reporter;
Interpreter interpreter(&reporter);
TfLiteQuantizationParams quantized;
ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
// Overflow testing is pointer word size dependent.
if (sizeof(size_t) == 8) {
// #bits for bytecount = 30 + 30 + 2 = 62 < 64
ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30}, quantized),
kTfLiteOk);
// #bits for element count = 30 + 30 + 2 = 62 < 64 (no overflow)
// #bits for byte count = 30 + 30 + 2 + 2 = 64 == 64 (overflow)
ASSERT_NE(
interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2}, quantized),
kTfLiteOk);
EXPECT_THAT(
reporter.error_messages(),
testing::EndsWith("BytesRequired number of bytes overflowed.\n"));
// #bits for element count = 30 + 30 + 2 + 4 = 66 > 64 (overflow).
// #bits for byte count = 30 + 30 + 2 + 4 + 2 = 68 > 64 (overflow).
reporter.Reset();
ASSERT_NE(interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2, 1 << 4},
quantized),
kTfLiteOk);
EXPECT_THAT(
reporter.error_messages(),
testing::EndsWith("BytesRequired number of elements overflowed.\n"));
} else if (sizeof(size_t) == 4) {
// #bits for bytecount = 14 + 14 + 2 = 30 < 32
ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14}, quantized),
kTfLiteOk);
// #bits for element count = 14 + 14 + 3 = 31 < 32 (no overflow).
// #bits for byte count = 14 + 14 + 3 + 2 = 33 > 32 (overflow).
ASSERT_NE(
interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 3}, quantized),
kTfLiteOk);
EXPECT_THAT(
reporter.error_messages(),
testing::EndsWith("BytesRequired number of bytes overflowed.\n"));
// #bits for element count = 14 + 14 + 4 = 32 == 32 (overflow).
// byte count also overflows, but we don't get to that check.
reporter.Reset();
ASSERT_NE(
interpreter.SetTensorParametersReadWrite(
0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 4}, quantized),
kTfLiteOk);
EXPECT_THAT(
reporter.error_messages(),
testing::EndsWith("BytesRequired number of elements overflowed.\n"));
} else {
// This test failing means that we are using a non 32/64 bit architecture.
ASSERT_TRUE(false);
}
}
TEST(BasicInterpreter, TestUseNNAPI) {
TestErrorReporter reporter;
Interpreter interpreter(&reporter);