Handle edge cases gracefully in GetInvSqrtQuantizedMultiplierExp.
PiperOrigin-RevId: 283799303 Change-Id: Iecffd57dbce7fa231cc20f3db5efa3f2bb9d474a
This commit is contained in:
parent
4abe9e27bc
commit
fb23e44515
@ -361,6 +361,7 @@ cc_test(
|
||||
name = "quantization_util_test",
|
||||
srcs = ["quantization_util_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
":quantization_util",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
|
@ -432,12 +432,23 @@ inline int32 GetReciprocal(int32 x, int x_integer_digits,
|
||||
inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int reverse_shift,
|
||||
int32* output_inv_sqrt,
|
||||
int* output_shift) {
|
||||
TFLITE_DCHECK_GE(input, 0);
|
||||
if (input <= 1) {
|
||||
// Handle the input value 1 separately to avoid overflow in that case
|
||||
// in the general computation below (b/143972021). Also handle 0 as if it
|
||||
// were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
|
||||
// but rare/unrealistic input value. We can expect both to occur in some
|
||||
// incompletely trained models, but probably not in fully trained models.
|
||||
*output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
|
||||
*output_shift = 0;
|
||||
return;
|
||||
}
|
||||
TFLITE_DCHECK_GT(input, 1);
|
||||
*output_shift = 11;
|
||||
while (input >= (1 << 29)) {
|
||||
input /= 4;
|
||||
++*output_shift;
|
||||
}
|
||||
TFLITE_DCHECK_GT(input, 0);
|
||||
const unsigned max_left_shift_bits =
|
||||
CountLeadingZeros(static_cast<uint32>(input)) - 1;
|
||||
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
|
||||
|
@ -14,8 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
@ -397,6 +400,28 @@ TEST(QuantizationUtilTest, QuantizeMultiplierUnderflow) {
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST(QuantizationUtilTest, GetInvSqrtQuantizedMultiplierExp) {
|
||||
auto inv_sqrt = [](std::int32_t input) {
|
||||
int32_t output;
|
||||
int output_shift;
|
||||
GetInvSqrtQuantizedMultiplierExp(input, 1, &output, &output_shift);
|
||||
return std::pair<int32_t, int>{output, output_shift};
|
||||
};
|
||||
|
||||
const auto kInt32Max = std::numeric_limits<std::int32_t>::max();
|
||||
EXPECT_THAT(inv_sqrt(0), Pair(kInt32Max, 0));
|
||||
EXPECT_THAT(inv_sqrt(1), Pair(kInt32Max, 0));
|
||||
EXPECT_THAT(inv_sqrt(2), Pair(1518498372, 0));
|
||||
EXPECT_THAT(inv_sqrt(3), Pair(1239850284, 0));
|
||||
EXPECT_THAT(inv_sqrt(4), Pair(1073741828, 0));
|
||||
EXPECT_THAT(inv_sqrt(100), Pair(214748363, 0));
|
||||
EXPECT_THAT(inv_sqrt(10000), Pair(343597361, 4));
|
||||
EXPECT_THAT(inv_sqrt(1000000), Pair(274877901, 7));
|
||||
EXPECT_THAT(inv_sqrt(100000000), Pair(219902323, 10));
|
||||
EXPECT_THAT(inv_sqrt((1 << 30)), Pair(268435457, 12));
|
||||
EXPECT_THAT(inv_sqrt(kInt32Max), Pair(189812531, 12));
|
||||
}
|
||||
|
||||
TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
|
||||
auto quantize = [](double beta, double scale, int integer_bits) {
|
||||
int32_t q;
|
||||
|
Loading…
Reference in New Issue
Block a user