Add Int8 -> Int8 requantize operation.
PiperOrigin-RevId: 313800318 Change-Id: I04650f59f8551b482648a7468fb1b2773a64b415
This commit is contained in:
parent
a475c198ec
commit
57749ce64f
|
@ -63,12 +63,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
TF_LITE_ENSURE(context, affine_quantization->scale);
|
TF_LITE_ENSURE(context, affine_quantization->scale);
|
||||||
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
|
TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
|
||||||
|
|
||||||
TF_LITE_ENSURE(context,
|
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 ||
|
||||||
input->type == kTfLiteFloat32 || input->type == kTfLiteInt16);
|
input->type == kTfLiteInt16 ||
|
||||||
|
input->type == kTfLiteInt8);
|
||||||
TF_LITE_ENSURE(context,
|
TF_LITE_ENSURE(context,
|
||||||
output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
|
output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
|
||||||
|
|
||||||
if (input->type == kTfLiteInt16 && output->type == kTfLiteInt8) {
|
if ((input->type == kTfLiteInt16 || input->type == kTfLiteInt8) &&
|
||||||
|
output->type == kTfLiteInt8) {
|
||||||
double effective_scale =
|
double effective_scale =
|
||||||
static_cast<double>(input->params.scale / output->params.scale);
|
static_cast<double>(input->params.scale / output->params.scale);
|
||||||
|
|
||||||
|
@ -122,6 +124,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
TfLiteTypeGetName(output->type));
|
TfLiteTypeGetName(output->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
} else if (input->type == kTfLiteInt8) {
|
||||||
|
// Int8 to Int8 requantization, required if the input and output tensors
|
||||||
|
// have different scales and/or zero points.
|
||||||
|
size_t size = ElementCount(*input->dims);
|
||||||
|
switch (output->type) {
|
||||||
|
case kTfLiteInt8:
|
||||||
|
reference_ops::Requantize(
|
||||||
|
GetTensorData<int8_t>(input), size, data->output_multiplier,
|
||||||
|
data->output_shift, input->params.zero_point,
|
||||||
|
output->params.zero_point, GetTensorData<int8_t>(output));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||||
|
TfLiteTypeGetName(input->type),
|
||||||
|
TfLiteTypeGetName(output->type));
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
TF_LITE_KERNEL_LOG(context, "Input %s, output %s not supported.",
|
||||||
TfLiteTypeGetName(input->type),
|
TfLiteTypeGetName(input->type),
|
||||||
|
|
|
@ -110,13 +110,13 @@ void TestQuantizeFloat(const int* input_dims_data, const float* input_data,
|
||||||
scale, zero_point, output_dims_count, output_data);
|
scale, zero_point, output_dims_count, output_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename InputType, typename OutputType>
|
||||||
void TestQuantizeInt16(const int* input_dims_data, const float* input_data,
|
void TestRequantize(const int* input_dims_data, const float* input_data,
|
||||||
int16_t* input_quantized, const float input_scale,
|
InputType* input_quantized, const float input_scale,
|
||||||
const int input_zero_point, const int* output_dims_data,
|
const int input_zero_point, const int* output_dims_data,
|
||||||
const float* golden, T* golden_quantized,
|
const float* golden, OutputType* golden_quantized,
|
||||||
const float output_scale, const int output_zero_point,
|
const float output_scale, const int output_zero_point,
|
||||||
T* output_data) {
|
OutputType* output_data) {
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||||
const int output_dims_count = ElementCount(*output_dims);
|
const int output_dims_count = ElementCount(*output_dims);
|
||||||
|
@ -212,10 +212,45 @@ TF_LITE_MICRO_TEST(QuantizeOpTestInt16toInt8) {
|
||||||
const float output_scale = 0.5;
|
const float output_scale = 0.5;
|
||||||
const int output_zero_point = 0;
|
const int output_zero_point = 0;
|
||||||
int8_t output_quantized[length];
|
int8_t output_quantized[length];
|
||||||
|
int8_t values_quantized[length];
|
||||||
int16_t input_quantized[length];
|
int16_t input_quantized[length];
|
||||||
tflite::testing::TestQuantizeInt16(dims, values, input_quantized, input_scale,
|
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
|
||||||
input_zero_point, dims, values,
|
input_zero_point, dims, values,
|
||||||
output_quantized, output_scale,
|
values_quantized, output_scale,
|
||||||
|
output_zero_point, output_quantized);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-64, -62, -60, -58, -56, 54, 56, 58, 60, 62};
|
||||||
|
const float input_scale = 2.f;
|
||||||
|
const int input_zero_point = 0;
|
||||||
|
const float output_scale = 0.5;
|
||||||
|
const int output_zero_point = 32;
|
||||||
|
int8_t output_quantized[length];
|
||||||
|
int8_t values_quantized[length];
|
||||||
|
int8_t input_quantized[length];
|
||||||
|
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
|
||||||
|
input_zero_point, dims, values,
|
||||||
|
values_quantized, output_scale,
|
||||||
|
output_zero_point, output_quantized);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_MICRO_TEST(QuantizeOpTestInt8toInt8NoZeroPoint) {
|
||||||
|
const int length = 10;
|
||||||
|
const int dims[] = {2, 2, 5};
|
||||||
|
const float values[] = {-32, -31, -30, -29, -28, 27, 28, 29, 30, 31};
|
||||||
|
const float input_scale = 1.f;
|
||||||
|
const int input_zero_point = 0;
|
||||||
|
const float output_scale = 0.5;
|
||||||
|
const int output_zero_point = 0;
|
||||||
|
int8_t output_quantized[length];
|
||||||
|
int8_t values_quantized[length];
|
||||||
|
int8_t input_quantized[length];
|
||||||
|
tflite::testing::TestRequantize(dims, values, input_quantized, input_scale,
|
||||||
|
input_zero_point, dims, values,
|
||||||
|
values_quantized, output_scale,
|
||||||
output_zero_point, output_quantized);
|
output_zero_point, output_quantized);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue