diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc index b58a1cb368e..efaf2e583cd 100644 --- a/tensorflow/lite/micro/kernels/quantize.cc +++ b/tensorflow/lite/micro/kernels/quantize.cc @@ -66,11 +66,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 || input->type == kTfLiteInt16 || input->type == kTfLiteInt8); - TF_LITE_ENSURE(context, - output->type == kTfLiteUInt8 || output->type == kTfLiteInt8); + TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 || + output->type == kTfLiteInt8 || + output->type == kTfLiteInt16); if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) && - output->type == kTfLiteInt8) { + output->type == kTfLiteInt8 || + (input->type == kTfLiteInt16 && output->type == kTfLiteInt16)) { double effective_scale = static_cast(input->params.scale / output->params.scale); @@ -103,6 +105,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); break; + case kTfLiteInt16: + reference_ops::AffineQuantize( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(output), GetTensorData(output)); + return kTfLiteOk; default: TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), @@ -118,6 +125,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { data->output_shift, input->params.zero_point, output->params.zero_point, GetTensorData(output)); break; + case kTfLiteInt16: + reference_ops::Requantize( + GetTensorData(input), size, data->output_multiplier, + data->output_shift, input->params.zero_point, + output->params.zero_point, GetTensorData(output)); + return kTfLiteOk; default: TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.", TfLiteTypeGetName(input->type), diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc index b6f885d09e7..8e097429ca0 100644 --- a/tensorflow/lite/micro/kernels/quantize_test.cc +++ b/tensorflow/lite/micro/kernels/quantize_test.cc @@ -198,6 +198,32 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) { dims, values, dims, values, values_quantized, scale, zero_point, output); } +TF_LITE_MICRO_TEST(QuantizeOpTestInt16) { + const int length = 10; + const int dims[] = {2, 2, 5}; + const float values[] = {-63.5, -63, -62.5, -62, -61.5, + 62, 62.5, 63, 63.5, 64}; + const float scale = 0.5; + const int zero_point = -1; + int16_t output[length]; + int16_t values_quantized[length]; + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); +} + +TF_LITE_MICRO_TEST(QuantizeOpTestInt16NoScale) { + const int length = 10; + const int dims[] = {2, 2, 5}; + const float values[] = {-128, -127, -126, -125, -124, + 123, 124, 125, 126, 127}; + const float scale = 1.0; + const int zero_point = 0; + int16_t output[length]; + int16_t values_quantized[length]; + tflite::testing::TestQuantizeFloat( + dims, values, dims, values, values_quantized, scale, zero_point, output); +} + TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) { const int length = 10; const int dims[] = {2, 2, 5}; @@ -215,6 +241,40 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) { output_zero_point, output_quantized); } +TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16) { + const int length = 10; + const int dims[] = {2, 2, 5}; + const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62}; + const float input_scale = 2.f; + const int input_zero_point = 0; + const float output_scale = 0.5; + const int output_zero_point = 32; + int16_t output_quantized[length]; + int16_t values_quantized[length]; + int16_t input_quantized[length]; + tflite::testing::TestRequantize(dims, values, input_quantized, input_scale, + input_zero_point, dims, values, + values_quantized, output_scale, + output_zero_point, output_quantized); +} + +TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16NoZeroPoint) { + const int length = 10; + const int dims[] = {2, 2, 5}; + const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31}; + const float input_scale = 1.f; + const int input_zero_point = 0; + const float output_scale = 0.5; + const int output_zero_point = 0; + int16_t output_quantized[length]; + int16_t values_quantized[length]; + int16_t input_quantized[length]; + tflite::testing::TestRequantize(dims, values, input_quantized, input_scale, + input_zero_point, dims, values, + values_quantized, output_scale, + output_zero_point, output_quantized); +} + TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) { const int length = 10; const int dims[] = {2, 2, 5};