From 57749ce64fa2e7626b5b4ed9650a4b5c48956afd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 29 May 2020 10:01:29 -0700
Subject: [PATCH] Add Int8 -> Int8 requantize operation.

PiperOrigin-RevId: 313800318
Change-Id: I04650f59f8551b482648a7468fb1b2773a64b415
---
 tensorflow/lite/micro/kernels/quantize.cc     | 25 +++++++-
 .../lite/micro/kernels/quantize_test.cc       | 57 +++++++++++++++----
 2 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc
index b5bba83beb8..b58a1cb368e 100644
--- a/tensorflow/lite/micro/kernels/quantize.cc
+++ b/tensorflow/lite/micro/kernels/quantize.cc
@@ -63,12 +63,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, affine_quantization->scale);
   TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
 
-  TF_LITE_ENSURE(context,
-                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt16);
+  TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
+                              input->type == kTfLiteInt16 ||
+                              input->type == kTfLiteInt8);
   TF_LITE_ENSURE(context,
                  output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
 
-  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt8) {
+  if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
+      output->type == kTfLiteInt8) {
     double effective_scale = static_cast<double>(input->params.scale /
                                                  output->params.scale);
 
@@ -122,6 +124,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                            TfLiteTypeGetName(output->type));
         return kTfLiteError;
     }
+  } else if (input->type == kTfLiteInt8) {
+    // Int8 to Int8 requantization, required if the input and output tensors
+    // have different scales and/or zero points.
+    size_t size = ElementCount(*input->dims);
+    switch (output->type) {
+      case kTfLiteInt8:
+        reference_ops::Requantize(
+            GetTensorData<int8_t>(input), size, data->output_multiplier,
+            data->output_shift, input->params.zero_point,
+            output->params.zero_point, GetTensorData<int8_t>(output));
+        break;
+      default:
+        TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
+                           TfLiteTypeGetName(input->type),
+                           TfLiteTypeGetName(output->type));
+        return kTfLiteError;
+    }
   } else {
     TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
                        TfLiteTypeGetName(input->type),
diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc
index 359abbd73db..0364fbc57ec 100644
--- a/tensorflow/lite/micro/kernels/quantize_test.cc
+++ b/tensorflow/lite/micro/kernels/quantize_test.cc
@@ -110,13 +110,13 @@ void TestQuantizeFloat(const int* input_dims_data, const float* input_data,
                           scale, zero_point, output_dims_count, output_data);
 }
 
-template <typename T>
-void TestQuantizeInt16(const int* input_dims_data, const float* input_data,
-                       int16_t* input_quantized, const float input_scale,
-                       const int input_zero_point, const int* output_dims_data,
-                       const float* golden, T* golden_quantized,
-                       const float output_scale, const int output_zero_point,
-                       T* output_data) {
+template <typename InputType, typename OutputType>
+void TestRequantize(const int* input_dims_data, const float* input_data,
+                    InputType* input_quantized, const float input_scale,
+                    const int input_zero_point, const int* output_dims_data,
+                    const float* golden, OutputType* golden_quantized,
+                    const float output_scale, const int output_zero_point,
+                    OutputType* output_data) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
   const int output_dims_count = ElementCount(*output_dims);
@@ -212,11 +212,46 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
   const float output_scale = 0.5;
   const int output_zero_point = 0;
   int8_t output_quantized[length];
+  int8_t values_quantized[length];
   int16_t input_quantized[length];
-  tflite::testing::TestQuantizeInt16(dims, values, input_quantized, input_scale,
-                                     input_zero_point, dims, values,
-                                     output_quantized, output_scale,
-                                     output_zero_point, output_quantized);
+  tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
+                                  input_zero_point, dims, values,
+                                  values_quantized, output_scale,
+                                  output_zero_point, output_quantized);
+}
+
+TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) {
+  const int length = 10;
+  const int dims[] = {2, 2, 5};
+  const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62};
+  const float input_scale = 2.f;
+  const int input_zero_point = 0;
+  const float output_scale = 0.5;
+  const int output_zero_point = 32;
+  int8_t output_quantized[length];
+  int8_t values_quantized[length];
+  int8_t input_quantized[length];
+  tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
+                                  input_zero_point, dims, values,
+                                  values_quantized, output_scale,
+                                  output_zero_point, output_quantized);
+}
+
+TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8NoZeroPoint) {
+  const int length = 10;
+  const int dims[] = {2, 2, 5};
+  const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
+  const float input_scale = 1.f;
+  const int input_zero_point = 0;
+  const float output_scale = 0.5;
+  const int output_zero_point = 0;
+  int8_t output_quantized[length];
+  int8_t values_quantized[length];
+  int8_t input_quantized[length];
+  tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
+                                  input_zero_point, dims, values,
+                                  values_quantized, output_scale,
+                                  output_zero_point, output_quantized);
 }
 
 TF_LITE_MICRO_TESTS_END
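
Note (not part of the patch): the new Int8 -> Int8 path reuses the effective scale that Prepare() already computes for Int16 -> Int8, i.e. input_scale / output_scale folded into data->output_multiplier and data->output_shift, and reference_ops::Requantize applies that rescaling in fixed point. Below is a minimal standalone sketch of the same arithmetic in the float domain, using the parameters from the QuantizeOpTestInt8toInt8 test above (input scale 2.0 / zero point 0, output scale 0.5 / zero point 32). The RequantizeValue helper is purely illustrative and is not part of TFLite Micro.

// Illustration only: requantize one int8 value by dequantizing with the
// input parameters and re-quantizing with the output parameters. The kernel
// performs the equivalent rescaling in fixed point via QuantizeMultiplier
// and reference_ops::Requantize.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t RequantizeValue(int8_t in, float input_scale, int input_zero_point,
                       float output_scale, int output_zero_point) {
  // Real value represented by the quantized input.
  const float real_value = input_scale * (in - input_zero_point);
  // Re-quantize with the output scale and zero point.
  const int32_t requantized =
      static_cast<int32_t>(std::round(real_value / output_scale)) +
      output_zero_point;
  // Clamp to the int8 range.
  return static_cast<int8_t>(
      std::min<int32_t>(127, std::max<int32_t>(-128, requantized)));
}

int main() {
  // Parameters from QuantizeOpTestInt8toInt8: the real value -64 is stored
  // as -32 with scale 2.0 / zero point 0 and should map to -96 with
  // scale 0.5 / zero point 32.
  std::printf("%d\n", RequantizeValue(-32, 2.0f, 0, 0.5f, 32));  // prints -96
  return 0;
}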