diff --git a/tensorflow/lite/kernels/quantize.cc b/tensorflow/lite/kernels/quantize.cc
index 8f396355777..8ddc18be2b1 100644
--- a/tensorflow/lite/kernels/quantize.cc
+++ b/tensorflow/lite/kernels/quantize.cc
@@ -120,8 +120,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   } else {
     // Requantize use case.
     if (input->type == kTfLiteInt16) {
-      TF_LITE_ENSURE(
-          context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
+      TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
+                                  output->type == kTfLiteInt16 ||
+                                  output->type == kTfLiteInt32);
     } else {
       TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
                                   input->type == kTfLiteUInt8);
@@ -198,6 +199,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                                 output->params.zero_point,
                                 GetTensorData<int16_t>(output));
          return kTfLiteOk;
+        case kTfLiteInt32:
+          // This case is not supported by the converter or other TFLite tools.
+          // The only use case is for applications that take quantized int32
+          // inference outputs.
+          Requantize<kernel_type>(GetTensorData<int16_t>(input),
+                                  MatchingFlatSize(input_shape, output_shape),
+                                  data->output_multiplier, data->output_shift,
+                                  input->params.zero_point,
+                                  output->params.zero_point,
+                                  GetTensorData<int32_t>(output));
+          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
diff --git a/tensorflow/lite/kernels/quantize_test.cc b/tensorflow/lite/kernels/quantize_test.cc
index d7392b3e3ea..a8d68f6875b 100644
--- a/tensorflow/lite/kernels/quantize_test.cc
+++ b/tensorflow/lite/kernels/quantize_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include <cstdint>
 #include <vector>
+#include <limits>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -458,5 +459,59 @@ TEST(QuantizeOpTest, Int16Int8SmallerScaleNeonPath) {
                                 19, 17, 15, 13, 11, 9, 7, 5, 3, 1}));
 }
 
+// Input scale 1.0, output scale 1.0, input zeropoint 0, output zeropoint 0
+TEST(QuantizeOpTest, Int16Int32SameScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min(),
+                     std::numeric_limits<int16_t>::max()},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min(),
+                     std::numeric_limits<int32_t>::max()});
+
+  // Input will be quantized to {1,3,5,7,9,11,13,15,17,19}.
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}));
+}
+
+// Input scale 0.500000, output scale 1.000000, input zeropoint -1, output
+// zeropoint 0
+TEST(QuantizeOpTest, Int16Int32LargerScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min() / 2.0,
+                     std::numeric_limits<int16_t>::max() / 2.0},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min(),
+                     std::numeric_limits<int32_t>::max()});
+
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}));
+}
+
+// Input scale 1.000000, output scale 0.500000, input zeropoint -1, output
+// zeropoint 0
+TEST(QuantizeOpTest, Int16Int32SmallerScale) {
+  QuantizeOpModel m({TensorType_INT16,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int16_t>::min(),
+                     std::numeric_limits<int16_t>::max()},
+                    {TensorType_INT32,
+                     {1, 1, 2, 5},
+                     std::numeric_limits<int32_t>::min() / 2.0,
+                     std::numeric_limits<int32_t>::max() / 2.0});
+
+  m.SetInputAndQuantize<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int32_t>(),
+              ElementsAreArray({2, 4, 6, 8, 10, 12, 14, 16, 18, 20}));
+}
+
 }  // namespace
 }  // namespace tflite
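
For context on what the new `kTfLiteInt32` branch computes: the requantize path rescales each zero-point-adjusted int16 value by `input_scale / output_scale` (folded into `data->output_multiplier` / `data->output_shift` during `Prepare`) and then adds the output zero point. The sketch below mirrors that arithmetic with plain double math instead of the kernel's fixed-point multiplier; the function name and the use of `std::lround` are illustrative only and are not part of this patch.

```cpp
// Hypothetical reference sketch of int16 -> int32 requantization, assuming
// standard affine-quantization semantics; not the kernel's fixed-point path.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int32_t> RequantizeInt16ToInt32(const std::vector<int16_t>& input,
                                            double input_scale,
                                            int32_t input_zero_point,
                                            double output_scale,
                                            int32_t output_zero_point) {
  // The effective scale maps input quantization units to output units. The
  // real kernel turns this ratio into a multiplier/shift pair in Prepare().
  const double effective_scale = input_scale / output_scale;
  std::vector<int32_t> output;
  output.reserve(input.size());
  for (int16_t v : input) {
    const double unscaled = static_cast<double>(v) - input_zero_point;
    const int32_t requantized =
        static_cast<int32_t>(std::lround(unscaled * effective_scale)) +
        output_zero_point;
    output.push_back(requantized);
  }
  return output;
}
```

With input scale 1.0 / zero point 0 and output scale 0.5 / zero point 0, quantized inputs {1, ..., 10} come out as {2, 4, ..., 20}, which matches what the Int16Int32SmallerScale test above expects.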