Handle edge cases gracefully in GetInvSqrtQuantizedMultiplierExp.

PiperOrigin-RevId: 283799303
Change-Id: Iecffd57dbce7fa231cc20f3db5efa3f2bb9d474a
This commit is contained in:
Benoit Jacob 2019-12-04 11:37:37 -08:00 committed by TensorFlower Gardener
parent 4abe9e27bc
commit fb23e44515
3 changed files with 38 additions and 1 deletions

View File

@ -361,6 +361,7 @@ cc_test(
name = "quantization_util_test",
srcs = ["quantization_util_test.cc"],
deps = [
":common",
":quantization_util",
"@com_google_googletest//:gtest_main",
],

View File

@ -432,12 +432,23 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits,
inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
int32* output_inv_sqrt,
int* output_shift) {
TFLITE_DCHECK_GE(input, 0);
if (input <= 1) {
// Handle the input value 1 separately to avoid overflow in that case
// in the general computation below (b/143972021). Also handle 0 as if it
// were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
// but rare/unrealistic input value. We can expect both to occur in some
// incompletely trained models, but probably not in fully trained models.
*output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
*output_shift = 0;
return;
}
TFLITE_DCHECK_GT(input, 1);
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
++*output_shift;
}
TFLITE_DCHECK_GT(input, 0);
const unsigned max_left_shift_bits =
CountLeadingZeros(static_cast<uint32>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;

View File

@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include <limits>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace {
@ -397,6 +400,28 @@ TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
}
#endif
TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
auto inv_sqrt = [](std::int32_t input) {
int32_t output;
int output_shift;
GetInvSqrtQuantizedMultiplierExp(input, 1, &output, &output_shift);
return std::pair<int32_t, int>{output, output_shift};
};
const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
EXPECT_THAT(inv_sqrt(0), Pair(kInt32Max, 0));
EXPECT_THAT(inv_sqrt(1), Pair(kInt32Max, 0));
EXPECT_THAT(inv_sqrt(2), Pair(1518498372, 0));
EXPECT_THAT(inv_sqrt(3), Pair(1239850284, 0));
EXPECT_THAT(inv_sqrt(4), Pair(1073741828, 0));
EXPECT_THAT(inv_sqrt(100), Pair(214748363, 0));
EXPECT_THAT(inv_sqrt(10000), Pair(343597361, 4));
EXPECT_THAT(inv_sqrt(1000000), Pair(274877901, 7));
EXPECT_THAT(inv_sqrt(100000000), Pair(219902323, 10));
EXPECT_THAT(inv_sqrt((1 << 30)), Pair(268435457, 12));
EXPECT_THAT(inv_sqrt(kInt32Max), Pair(189812531, 12));
}
TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
auto quantize = [](double beta, double scale, int integer_bits) {
int32_t q;