Check for overflow in the byte-count computation of tensor allocation.

We check both the product of the shape dimensions (# of elements) and the
number of bytes (elements * sizeof(data_type)). @joyalbin provided an initial
adaptation of the TF overflow.h implementation in #26859; this version is
cleaned up and optimized for a refactor that has landed since then.

PiperOrigin-RevId: 288539371
Change-Id: I9298de40afd22ae72da151cc457c1f0937e97d7a
Parent: 22ccf0267a
Commit: 2522ce7dd5
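As a concrete illustration of the two checks (figures taken from the test added below, assuming a 64-bit size_t): a kTfLiteFloat32 tensor of shape {1 << 30, 1 << 30} has 2^60 elements and 2^60 * 4 = 2^62 bytes, both representable; appending a dimension of 1 << 2 keeps the element count representable at 2^62 but pushes the byte count to 2^62 * 4 = 2^64, which wraps a 64-bit size_t, so BytesRequired must now fail rather than silently under-allocate.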
@@ -559,16 +559,39 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
   return kTfLiteOk;
 }
 
+namespace {
+// Multiply two sizes and return kTfLiteError if overflow occurred.
+// This is based off tensorflow/overflow.h but is simpler as we already
+// have unsigned numbers. It is also generalized to work where sizeof(size_t)
+// is not 8.
+TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
+  constexpr size_t overflow_threshold = (8 * sizeof(size_t)) >> 1;
+  *product = a * b;
+  // If neither integer has non-zero bits in its upper half, the product
+  // can't overflow. Otherwise check using slow division.
+  if (__builtin_expect((a | b) >> overflow_threshold != 0, false)) {
+    if (a != 0 && *product / a != b) return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
 TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims,
                                      size_t dims_size, size_t* bytes) {
-  // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
-  // MultiplyWithoutOverflow.
   TF_LITE_ENSURE(&context_, bytes != nullptr);
   size_t count = 1;
-  for (int k = 0; k < dims_size; k++) count *= dims[k];
+  for (int k = 0; k < dims_size; k++) {
+    size_t old_count = count;
+    TF_LITE_ENSURE_MSG(
+        &context_,
+        MultiplyAndCheckOverflow(old_count, dims[k], &count) == kTfLiteOk,
+        "BytesRequired number of elements overflowed.\n");
+  }
   size_t type_size = 0;
   TF_LITE_ENSURE_OK(&context_, GetSizeOfType(&context_, type, &type_size));
-  *bytes = type_size * count;
+  TF_LITE_ENSURE_MSG(
+      &context_, MultiplyAndCheckOverflow(type_size, count, bytes) == kTfLiteOk,
+      "BytesRequired number of bytes overflowed.\n");
   return kTfLiteOk;
 }
 
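For readers who want to poke at the arithmetic in isolation, below is a minimal standalone sketch of the same checked-multiply pattern. It is not part of the commit; the file name, the bool-returning wrapper, and the main() driver are made up for illustration, and a 64-bit size_t is assumed.

// multiply_overflow_demo.cc -- standalone sketch, not part of TensorFlow Lite.
// Mirrors the MultiplyAndCheckOverflow pattern above, with a bool result.
#include <cstddef>
#include <cstdio>

bool MultiplyOverflows(size_t a, size_t b, size_t* product) {
  // Half the bit width of size_t: 32 on a typical 64-bit platform.
  constexpr size_t kHalfBits = (8 * sizeof(size_t)) >> 1;
  *product = a * b;
  // Fast path: if both operands fit in the lower half of size_t, the product
  // cannot wrap. Otherwise verify with a division.
  if ((a | b) >> kHalfBits != 0) {
    if (a != 0 && *product / a != b) return true;  // wrapped around
  }
  return false;
}

int main() {
  size_t p;
  // 2^31 * 2^31 = 2^62 still fits in 64 bits -> prints 0.
  printf("%d\n", MultiplyOverflows(size_t{1} << 31, size_t{1} << 31, &p));
  // 2^33 * 2^31 = 2^64 wraps to 0 on a 64-bit size_t -> prints 1.
  printf("%d\n", MultiplyOverflows(size_t{1} << 33, size_t{1} << 31, &p));
  return 0;
}

The slow fallback is the same division trick the committed code relies on: if a != 0 and (a * b) / a != b, the multiplication must have wrapped.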
@@ -820,6 +820,70 @@ TEST(BasicInterpreter, TestCustomErrorReporter) {
   ASSERT_EQ(reporter.num_calls(), 1);
 }
 
+TEST(BasicInterpreter, TestOverflow) {
+  TestErrorReporter reporter;
+  Interpreter interpreter(&reporter);
+  TfLiteQuantizationParams quantized;
+
+  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+  // Overflow testing is pointer word size dependent.
+  if (sizeof(size_t) == 8) {
+    // #bits for byte count = 30 + 30 + 2 = 62 < 64
+    ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+                  0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30}, quantized),
+              kTfLiteOk);
+    // #bits for element count = 30 + 30 + 2 = 62 < 64 (no overflow)
+    // #bits for byte count = 30 + 30 + 2 + 2 = 64 == 64 (overflow)
+    ASSERT_NE(
+        interpreter.SetTensorParametersReadWrite(
+            0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2}, quantized),
+        kTfLiteOk);
+    EXPECT_THAT(
+        reporter.error_messages(),
+        testing::EndsWith("BytesRequired number of bytes overflowed.\n"));
+    // #bits for element count = 30 + 30 + 2 + 4 = 66 > 64 (overflow).
+    // #bits for byte count = 30 + 30 + 2 + 4 + 2 = 68 > 64 (overflow).
+    reporter.Reset();
+    ASSERT_NE(interpreter.SetTensorParametersReadWrite(
+                  0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2, 1 << 4},
+                  quantized),
+              kTfLiteOk);
+    EXPECT_THAT(
+        reporter.error_messages(),
+        testing::EndsWith("BytesRequired number of elements overflowed.\n"));
+
+  } else if (sizeof(size_t) == 4) {
+    // #bits for byte count = 14 + 14 + 2 = 30 < 32
+    ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+                  0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14}, quantized),
+              kTfLiteOk);
+    // #bits for element count = 14 + 14 + 3 = 31 < 32 (no overflow).
+    // #bits for byte count = 14 + 14 + 3 + 2 = 33 > 32 (overflow).
+    ASSERT_NE(
+        interpreter.SetTensorParametersReadWrite(
+            0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 3}, quantized),
+        kTfLiteOk);
+    EXPECT_THAT(
+        reporter.error_messages(),
+        testing::EndsWith("BytesRequired number of bytes overflowed.\n"));
+    // #bits for element count = 14 + 14 + 4 = 32 == 32 (overflow).
+    // Byte count also overflows, but we don't get to that check.
+    reporter.Reset();
+    ASSERT_NE(
+        interpreter.SetTensorParametersReadWrite(
+            0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 4}, quantized),
+        kTfLiteOk);
+    EXPECT_THAT(
+        reporter.error_messages(),
+        testing::EndsWith("BytesRequired number of elements overflowed.\n"));
+  } else {
+    // This test failing means that we are using a non-32/64-bit architecture.
+    ASSERT_TRUE(false);
+  }
+}
+
 TEST(BasicInterpreter, TestUseNNAPI) {
   TestErrorReporter reporter;
   Interpreter interpreter(&reporter);