From fb23e44515781d6a22d84b10768eb583a57cd566 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Wed, 4 Dec 2019 11:37:37 -0800 Subject: [PATCH] Handle edge cases gracefully in GetInvSqrtQuantizedMultiplierExp. PiperOrigin-RevId: 283799303 Change-Id: Iecffd57dbce7fa231cc20f3db5efa3f2bb9d474a --- tensorflow/lite/kernels/internal/BUILD | 1 + tensorflow/lite/kernels/internal/common.h | 13 +++++++++- .../internal/quantization_util_test.cc | 25 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 646f14680ac..d71b36547f2 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -361,6 +361,7 @@ cc_test( name = "quantization_util_test", srcs = ["quantization_util_test.cc"], deps = [ + ":common", ":quantization_util", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/kernels/internal/common.h b/tensorflow/lite/kernels/internal/common.h index 0c4fbc1e84e..5e4ba25b711 100644 --- a/tensorflow/lite/kernels/internal/common.h +++ b/tensorflow/lite/kernels/internal/common.h @@ -432,12 +432,23 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits, inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift, int32* output_inv_sqrt, int* output_shift) { + TFLITE_DCHECK_GE(input, 0); + if (input <= 1) { + // Handle the input value 1 separately to avoid overflow in that case + // in the general computation below (b/143972021). Also handle 0 as if it + // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid + // but rare/unrealistic input value. We can expect both to occur in some + // incompletely trained models, but probably not in fully trained models. + *output_inv_sqrt = std::numeric_limits::max(); + *output_shift = 0; + return; + } + TFLITE_DCHECK_GT(input, 1); *output_shift = 11; while (input >= (1 << 29)) { input /= 4; ++*output_shift; } - TFLITE_DCHECK_GT(input, 0); const unsigned max_left_shift_bits = CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc index 132befbb020..053b3116a15 100644 --- a/tensorflow/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include + #include #include +#include "tensorflow/lite/kernels/internal/common.h" namespace tflite { namespace { @@ -397,6 +400,28 @@ TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) { } #endif +TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) { + auto inv_sqrt = [](std::int32_t input) { + int32_t output; + int output_shift; + GetInvSqrtQuantizedMultiplierExp(input, 1, &output, &output_shift); + return std::pair{output, output_shift}; + }; + + const auto kInt32Max = std::numeric_limits::max(); + EXPECT_THAT(inv_sqrt(0), Pair(kInt32Max, 0)); + EXPECT_THAT(inv_sqrt(1), Pair(kInt32Max, 0)); + EXPECT_THAT(inv_sqrt(2), Pair(1518498372, 0)); + EXPECT_THAT(inv_sqrt(3), Pair(1239850284, 0)); + EXPECT_THAT(inv_sqrt(4), Pair(1073741828, 0)); + EXPECT_THAT(inv_sqrt(100), Pair(214748363, 0)); + EXPECT_THAT(inv_sqrt(10000), Pair(343597361, 4)); + EXPECT_THAT(inv_sqrt(1000000), Pair(274877901, 7)); + EXPECT_THAT(inv_sqrt(100000000), Pair(219902323, 10)); + EXPECT_THAT(inv_sqrt((1 << 30)), Pair(268435457, 12)); + EXPECT_THAT(inv_sqrt(kInt32Max), Pair(189812531, 12)); +} + TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) { auto quantize = [](double beta, double scale, int integer_bits) { int32_t q;