Merge pull request #40913 from patriklaurell:tflu-int16-quantize

PiperOrigin-RevId: 326041350
Change-Id: Iaadea4c1ee878c44710598a91885d18c6266995a
This commit is contained in:
TensorFlower Gardener 2020-08-11 10:04:10 -07:00
commit 1cca3190a7
2 changed files with 80 additions and 4 deletions

View File

@ -65,11 +65,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
input->type == kTfLiteInt16 ||
input->type == kTfLiteInt8);
TF_LITE_ENSURE(context,
output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
output->type == kTfLiteInt8 ||
output->type == kTfLiteInt16);
if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
output->type == kTfLiteInt8) {
if (((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
output->type == kTfLiteInt8) ||
(input->type == kTfLiteInt16 && output->type == kTfLiteInt16)) {
double effective_scale =
static_cast<double>(input->params.scale / output->params.scale);
@ -107,6 +109,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
break;
case kTfLiteInt16:
reference_ops::AffineQuantize(
data->quantization_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
return kTfLiteOk;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
@ -123,6 +132,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
data->quantization_params.zero_point,
tflite::micro::GetTensorData<int8_t>(output));
break;
case kTfLiteInt16:
reference_ops::Requantize(
tflite::micro::GetTensorData<int16_t>(input), size,
data->output_multiplier, data->output_shift, data->input_zero_point,
data->quantization_params.zero_point,
tflite::micro::GetTensorData<int16_t>(output));
return kTfLiteOk;
default:
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),

View File

@ -173,6 +173,32 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) {
dims, values, dims, values, values_quantized, scale, zero_point, output);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt16) {
const int length = 10;
const int dims[] = {2, 2, 5};
const float values[] = {-63.5, -63, -62.5, -62, -61.5,
62, 62.5, 63, 63.5, 64};
const float scale = 0.5;
const int zero_point = -1;
int16_t output[length];
int16_t values_quantized[length];
tflite::testing::TestQuantizeFloat(
dims, values, dims, values, values_quantized, scale, zero_point, output);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt16NoScale) {
const int length = 10;
const int dims[] = {2, 2, 5};
const float values[] = {-128, -127, -126, -125, -124,
123, 124, 125, 126, 127};
const float scale = 1.0;
const int zero_point = 0;
int16_t output[length];
int16_t values_quantized[length];
tflite::testing::TestQuantizeFloat(
dims, values, dims, values, values_quantized, scale, zero_point, output);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
const int length = 10;
const int dims[] = {2, 2, 5};
@ -190,6 +216,40 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
output_zero_point, output_quantized);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16) {
const int length = 10;
const int dims[] = {2, 2, 5};
const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62};
const float input_scale = 2.f;
const int input_zero_point = 0;
const float output_scale = 0.5;
const int output_zero_point = 32;
int16_t output_quantized[length];
int16_t values_quantized[length];
int16_t input_quantized[length];
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
input_zero_point, dims, values,
values_quantized, output_scale,
output_zero_point, output_quantized);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16NoZeroPoint) {
const int length = 10;
const int dims[] = {2, 2, 5};
const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
const float input_scale = 1.f;
const int input_zero_point = 0;
const float output_scale = 0.5;
const int output_zero_point = 0;
int16_t output_quantized[length];
int16_t values_quantized[length];
int16_t input_quantized[length];
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
input_zero_point, dims, values,
values_quantized, output_scale,
output_zero_point, output_quantized);
}
TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) {
const int length = 10;
const int dims[] = {2, 2, 5};