Add int16->int32 quantization to TFLite. This is necessary to run a keyword spotting model.

PiperOrigin-RevId: 337569973
Change-Id: I1670a6f7d13835014529be8c309e22f4cd46d7df
Author: Nat Jeffries
Date: 2020-10-16 13:43:37 -07:00 (committed by TensorFlower Gardener)
parent 8c3653d203
commit ba58a5c410
2 changed files with 69 additions and 2 deletions


@@ -120,8 +120,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   } else {
     // Requantize use case.
     if (input->type == kTfLiteInt16) {
-      TF_LITE_ENSURE(
-          context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
+      TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
+                                  output->type == kTfLiteInt16 ||
+                                  output->type == kTfLiteInt32);
     } else {
       TF_LITE_ENSURE(context,
                      input->type == kTfLiteInt8 || input->type == kTfLiteUInt8);
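
The relaxed type check above only admits int32 as a requantize target; the fixed-point scaling parameters consumed by the Eval hunk below (output_multiplier, output_shift) come from the ratio of the input and output scales. A minimal sketch of that derivation using TFLite's QuantizeMultiplier helper follows; the wrapper function and variable names are illustrative, not taken from the kernel verbatim.

#include <cstdint>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

// Sketch: turn the real-valued effective scale into the fixed-point
// multiplier/shift pair that the integer Requantize path consumes.
void DeriveRequantizeParams(double input_scale, double output_scale,
                            int32_t* output_multiplier, int* output_shift) {
  const double effective_output_scale = input_scale / output_scale;
  // QuantizeMultiplier splits a positive real multiplier into a Q31
  // fixed-point multiplier and a power-of-two shift.
  tflite::QuantizeMultiplier(effective_output_scale, output_multiplier,
                             output_shift);
}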
@@ -198,6 +199,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                                   output->params.zero_point,
                                   GetTensorData<int16_t>(output));
           return kTfLiteOk;
+        case kTfLiteInt32:
+          // This case is not supported by the converter or other TFLite tools.
+          // The only use case is for applications that take quantized int32
+          // inference outputs.
+          Requantize<kernel_type>(GetTensorData<int16_t>(input),
+                                  MatchingFlatSize(input_shape, output_shape),
+                                  data->output_multiplier, data->output_shift,
+                                  input->params.zero_point,
+                                  output->params.zero_point,
+                                  GetTensorData<int32_t>(output));
+          return kTfLiteOk;
         default:
           ReportError(context, input->type, output->type);
           return kTfLiteError;
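
As the in-code comment notes, this path is aimed at applications that consume quantized int32 inference outputs directly rather than going through the converter. A minimal, hypothetical consumer using the TFLite C++ interpreter API is sketched below; the model file name and tensor indices are placeholders, and error handling is omitted.

#include <cstdint>
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  // Placeholder model: assumed to end in a QUANTIZE op producing int32.
  auto model = tflite::FlatBufferModel::BuildFromFile("kws_int32_out.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->AllocateTensors();

  // Fill the (assumed) int16 input tensor with already-quantized samples.
  int16_t* in = interpreter->typed_input_tensor<int16_t>(0);
  // ... populate in[...] ...
  interpreter->Invoke();

  // Read the int32 output produced by the int16->int32 requantize path.
  const int32_t* out = interpreter->typed_output_tensor<int32_t>(0);
  return out != nullptr ? 0 : 1;
}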


@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <cstdint>
 #include <initializer_list>
+#include <limits>
 #include <vector>
 
 #include <gtest/gtest.h>
@@ -458,5 +459,59 @@ TEST(QuantizeOpTest, Int16Int8SmallerScaleNeonPath) {
                         19, 17, 15, 13, 11, 9, 7, 5, 3, 1}));
 }
 
+// Input scale 1.0, output scale 1.0, input zeropoint 0, output zeropoint 0
+TEST(QuantizeOpTest, Int16Int32SameScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min(),
+                     std::numeric_limits<int16_t>::max()},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min(),
+                     std::numeric_limits<int32_t>::max()});
+
+  // Input will be quantized to {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}.
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}));
+}
+
+// Input scale 0.500000, output scale 1.000000, input zeropoint -1, output
+// zeropoint 0
+TEST(QuantizeOpTest, Int16Int32LargerScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min() / 2.0,
+                     std::numeric_limits<int16_t>::max() / 2.0},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min(),
+                     std::numeric_limits<int32_t>::max()});
+
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}));
+}
+
+// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output
+// zeropoint 0
+TEST(QuantizeOpTest, Int16Int32SmallerScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min(),
+                     std::numeric_limits<int16_t>::max()},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min() / 2.0,
+                     std::numeric_limits<int32_t>::max() / 2.0});
+
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({2, 4, 6, 8, 10, 12, 14, 16, 18, 20}));
+}
+
 }  // namespace
 }  // namespace tflite
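
The expected values in these tests follow from the requantize relation q_out = zp_out + round((q_in - zp_in) * scale_in / scale_out): for Int16Int32SmallerScale, scale_in is 1.0 and scale_out is 0.5, so a quantized input of 10 maps to 20, matching the {2, 4, ..., 20} expectation above. A standalone float reference of that arithmetic is sketched below; zero points are taken as 0 for simplicity, and this is an illustrative check rather than the kernel's fixed-point path.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Float reference: q_out = zp_out + round((q_in - zp_in) * s_in / s_out).
int32_t RequantizeReference(int16_t q_in, float in_scale, int32_t in_zp,
                            float out_scale, int32_t out_zp) {
  const float real_value = (q_in - in_zp) * in_scale;
  return out_zp + static_cast<int32_t>(std::lround(real_value / out_scale));
}

int main() {
  // Int16Int32SmallerScale: input scale 1.0, output scale 0.5, zero points ~0.
  for (int16_t q = 1; q <= 10; ++q) {
    std::printf("%d -> %d\n", q, RequantizeReference(q, 1.0f, 0, 0.5f, 0));
  }
  return 0;
}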