Support float32->int16 and int16->int16 quantization in TFLu (TensorFlow Lite Micro)
This commit is contained in:
parent
e66a20bb44
commit
412da53d65
@ -66,11 +66,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
|
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
|
||||||
input->type == kTfLiteInt16 ||
|
input->type == kTfLiteInt16 ||
|
||||||
input->type == kTfLiteInt8);
|
input->type == kTfLiteInt8);
|
||||||
TF_LITE_ENSURE(context,
|
TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
|
||||||
output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
|
output->type == kTfLiteInt8 ||
|
||||||
|
output->type == kTfLiteInt16);
|
||||||
|
|
||||||
if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
|
if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
|
||||||
output->type == kTfLiteInt8) {
|
output->type == kTfLiteInt8 ||
|
||||||
|
(input->type == kTfLiteInt16 && output->type == kTfLiteInt16)) {
|
||||||
double effective_scale =
|
double effective_scale =
|
||||||
static_cast<double>(input->params.scale / output->params.scale);
|
static_cast<double>(input->params.scale / output->params.scale);
|
||||||
|
|
||||||
@ -103,6 +105,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||||
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt16:
|
||||||
|
reference_ops::AffineQuantize(
|
||||||
|
op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||||
|
GetTensorShape(output), GetTensorData<int16_t>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
default:
|
default:
|
||||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||||
TfLiteTypeGetName(input->type),
|
TfLiteTypeGetName(input->type),
|
||||||
@ -118,6 +125,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
data->output_shift, input->params.zero_point,
|
data->output_shift, input->params.zero_point,
|
||||||
output->params.zero_point, GetTensorData<int8_t>(output));
|
output->params.zero_point, GetTensorData<int8_t>(output));
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt16:
|
||||||
|
reference_ops::Requantize(
|
||||||
|
GetTensorData<int16_t>(input), size, data->output_multiplier,
|
||||||
|
data->output_shift, input->params.zero_point,
|
||||||
|
output->params.zero_point, GetTensorData<int16_t>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
default:
|
default:
|
||||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||||
TfLiteTypeGetName(input->type),
|
TfLiteTypeGetName(input->type),
|
||||||
|
@ -198,6 +198,32 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt8NoScale) {
|
|||||||
dims, values, dims, values, values_quantized, scale, zero_point, output);
|
dims, values, dims, values, values_quantized, scale, zero_point, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt16) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-63.5, -63, -62.5, -62, -61.5,
|
||||||
|
62, 62.5, 63, 63.5, 64};
|
||||||
|
const float scale = 0.5;
|
||||||
|
const int zero_point = -1;
|
||||||
|
int16_t output[length];
|
||||||
|
int16_t values_quantized[length];
|
||||||
|
tflite::testing::TestQuantizeFloat(
|
||||||
|
dims, values, dims, values, values_quantized, scale, zero_point, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt16NoScale) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-128, -127, -126, -125, -124,
|
||||||
|
123, 124, 125, 126, 127};
|
||||||
|
const float scale = 1.0;
|
||||||
|
const int zero_point = 0;
|
||||||
|
int16_t output[length];
|
||||||
|
int16_t values_quantized[length];
|
||||||
|
tflite::testing::TestQuantizeFloat(
|
||||||
|
dims, values, dims, values, values_quantized, scale, zero_point, output);
|
||||||
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
|
||||||
const int length = 10;
|
const int length = 10;
|
||||||
const int dims[] = {2, 2, 5};
|
const int dims[] = {2, 2, 5};
|
||||||
@ -215,6 +241,40 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
|
|||||||
output_zero_point, output_quantized);
|
output_zero_point, output_quantized);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62};
|
||||||
|
const float input_scale = 2.f;
|
||||||
|
const int input_zero_point = 0;
|
||||||
|
const float output_scale = 0.5;
|
||||||
|
const int output_zero_point = 32;
|
||||||
|
int16_t output_quantized[length];
|
||||||
|
int16_t values_quantized[length];
|
||||||
|
int16_t input_quantized[length];
|
||||||
|
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
|
||||||
|
input_zero_point, dims, values,
|
||||||
|
values_quantized, output_scale,
|
||||||
|
output_zero_point, output_quantized);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt16NoZeroPoint) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
|
||||||
|
const float input_scale = 1.f;
|
||||||
|
const int input_zero_point = 0;
|
||||||
|
const float output_scale = 0.5;
|
||||||
|
const int output_zero_point = 0;
|
||||||
|
int16_t output_quantized[length];
|
||||||
|
int16_t values_quantized[length];
|
||||||
|
int16_t input_quantized[length];
|
||||||
|
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
|
||||||
|
input_zero_point, dims, values,
|
||||||
|
values_quantized, output_scale,
|
||||||
|
output_zero_point, output_quantized);
|
||||||
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) {
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) {
|
||||||
const int length = 10;
|
const int length = 10;
|
||||||
const int dims[] = {2, 2, 5};
|
const int dims[] = {2, 2, 5};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user