From 2522ce7dd5d28c9733824a66133fc918290e3ed0 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Tue, 7 Jan 2020 11:38:05 -0800 Subject: [PATCH] Check for overflow in # of bytes computation of tensor allocation. We check both for product of shape dimensions (# of elements) and number of bytes (elements * sizeof(data_type)). @joyalbin provided an initial adaption of the TF overflowl.h impl in #26859. This is cleaned up and optimized for a refactor that occurred after. PiperOrigin-RevId: 288539371 Change-Id: I9298de40afd22ae72da151cc457c1f0937e97d7a --- tensorflow/lite/core/subgraph.cc | 31 ++++++++++++-- tensorflow/lite/interpreter_test.cc | 64 +++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 188bb6f70e8..5fcf754d244 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -559,16 +559,39 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices, return kTfLiteOk; } +namespace { +// Multiply two sizes and return true if overflow occurred; +// This is based off tensorflow/overflow.h but is simpler as we already +// have unsigned numbers. It is also generalized to work where sizeof(size_t) +// is not 8. +TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) { + constexpr size_t overflow_threshold = (8 * sizeof(size_t)) >> 1; + *product = a * b; + // If neither integers have non-zero bits past 32 bits can't overflow. + // Otherwise check using slow devision. + if (__builtin_expect((a | b) >> overflow_threshold != 0, false)) { + if (a != 0 && *product / a != b) return kTfLiteError; + } + return kTfLiteOk; +} +} // namespace + TfLiteStatus Subgraph::BytesRequired(TfLiteType type, const int* dims, size_t dims_size, size_t* bytes) { - // TODO(aselle): Check for overflow here using overflow.h in TensorFlow - // MultiplyWithoutOverflow. TF_LITE_ENSURE(&context_, bytes != nullptr); size_t count = 1; - for (int k = 0; k < dims_size; k++) count *= dims[k]; + for (int k = 0; k < dims_size; k++) { + size_t old_count = count; + TF_LITE_ENSURE_MSG( + &context_, + MultiplyAndCheckOverflow(old_count, dims[k], &count) == kTfLiteOk, + "BytesRequired number of elements overflowed.\n"); + } size_t type_size = 0; TF_LITE_ENSURE_OK(&context_, GetSizeOfType(&context_, type, &type_size)); - *bytes = type_size * count; + TF_LITE_ENSURE_MSG( + &context_, MultiplyAndCheckOverflow(type_size, count, bytes) == kTfLiteOk, + "BytesRequired number of bytes overflowed.\n"); return kTfLiteOk; } diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index df0ab67c410..7d5babc43d2 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -820,6 +820,70 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.num_calls(), 1); } +TEST(BasicInterpreter, TestOverflow) { + TestErrorReporter reporter; + Interpreter interpreter(&reporter); + TfLiteQuantizationParams quantized; + + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + // Overflow testing is pointer word size dependent. + if (sizeof(size_t) == 8) { + // #bits for bytecount = 30 + 30 + 2 = 62 < 64 + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30}, quantized), + kTfLiteOk); + // #bits for element count = 30 + 30 + 2 = 62 < 64 (no overflow) + // #bits for byte count = 30 + 30 + 2 + 2 = 64 == 64 (overflow) + ASSERT_NE( + interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2}, quantized), + kTfLiteOk); + EXPECT_THAT( + reporter.error_messages(), + testing::EndsWith("BytesRequired number of bytes overflowed.\n")); + // #bits for element count = 30 + 30 + 2 + 4 = 66 > 64 (overflow). + // #bits for byte count = 30 + 30 + 2 + 4 + 2 = 68 > 64 (overflow). + reporter.Reset(); + ASSERT_NE(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 30, 1 << 30, 1 << 2, 1 << 4}, + quantized), + kTfLiteOk); + EXPECT_THAT( + reporter.error_messages(), + testing::EndsWith("BytesRequired number of elements overflowed.\n")); + + } else if (sizeof(size_t) == 4) { + // #bits for bytecount = 14 + 14 + 2 = 30 < 32 + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14}, quantized), + kTfLiteOk); + // #bits for element count = 14 + 14 + 3 = 31 < 32 (no overflow). + // #bits for byte count = 14 + 14 + 3 + 2 = 33 > 32 (overflow). + ASSERT_NE( + interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 3}, quantized), + kTfLiteOk); + EXPECT_THAT( + reporter.error_messages(), + testing::EndsWith("BytesRequired number of bytes overflowed.\n")); + // #bits for element count = 14 + 14 + 4 = 32 == 32 (overflow). + // byte count also overflows, but we don't get to that check. + reporter.Reset(); + ASSERT_NE( + interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "in1", {1 << 14, 1 << 14, 1 << 4}, quantized), + kTfLiteOk); + EXPECT_THAT( + reporter.error_messages(), + testing::EndsWith("BytesRequired number of elements overflowed.\n")); + } else { + // This test failing means that we are using a non 32/64 bit architecture. + ASSERT_TRUE(false); + } +} + TEST(BasicInterpreter, TestUseNNAPI) { TestErrorReporter reporter; Interpreter interpreter(&reporter);