Add int16 support to Quant.
PiperOrigin-RevId: 258563058
commit f0d6424da5
parent 4fd6623585
--- a/tensorflow/lite/kernels/quantize.cc
+++ b/tensorflow/lite/kernels/quantize.cc
@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);

   TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);

   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
   // Currently this only support affine per-layer quantization.
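
For context, the Prepare check above can admit int16 because affine quantization generalizes directly to wider integer types: divide the float value by the scale, round, add the zero point, and clamp to the storage type's range. A minimal sketch of that scheme (illustrative only; the helper name is made up, and TFLite's real implementation is optimized_ops::AffineQuantize):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <limits>

    // Hypothetical helper: quantize one float to integer type T using the
    // affine scheme q = zero_point + round(value / scale), clamped to T's range.
    template <typename T>
    T AffineQuantizeOne(float value, float scale, int32_t zero_point) {
      const int32_t q =
          zero_point + static_cast<int32_t>(std::round(value / scale));
      const int32_t lo = std::numeric_limits<T>::min();  // -32768 for int16_t
      const int32_t hi = std::numeric_limits<T>::max();  //  32767 for int16_t
      return static_cast<T>(std::min(hi, std::max(lo, q)));
    }
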
@@ -69,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

   // For requantize use case.
   const bool is_requantize = (op_context.input->type == kTfLiteUInt8 ||
-                              op_context.input->type == kTfLiteInt8) &&
+                              op_context.input->type == kTfLiteInt8 ||
+                              op_context.input->type == kTfLiteInt16) &&
                              (op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);
   if (is_requantize) {
     const double effective_output_scale =
         static_cast<double>(op_context.input->params.scale) /
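
The is_requantize branch handles inputs that are already quantized. The hunk is cut off mid-expression, but the numerator of effective_output_scale is the input scale, presumably divided by the output scale, which is the standard effective rescaling factor. A sketch of the arithmetic under that assumption (the function name is illustrative, and the kernel converts this scale to a fixed-point multiplier rather than using double math):

    #include <cmath>
    #include <cstdint>

    // Rescale a quantized value from the input's (scale, zero_point) pair to
    // the output's, given effective_output_scale = input_scale / output_scale.
    // Clamping to the output type's range is omitted for brevity.
    int32_t RequantizeOne(int32_t q_in, double effective_output_scale,
                          int32_t input_zero_point, int32_t output_zero_point) {
      const double rescaled = effective_output_scale * (q_in - input_zero_point);
      return output_zero_point + static_cast<int32_t>(std::round(rescaled));
    }
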
@@ -104,6 +107,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       optimized_ops::AffineQuantize(
           op_params, GetTensorShape(input), GetTensorData<float>(input),
           GetTensorShape(output), GetTensorData<uint8_t>(output));
+    } else if (output->type == kTfLiteInt16) {
+      optimized_ops::AffineQuantize(
+          op_params, GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(output), GetTensorData<int16_t>(output));
     } else {
       context->ReportError(
           context,
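
The new Eval branch mirrors the existing uint8 case line for line: optimized_ops::AffineQuantize resolves on the output buffer's element type, so the int16 path differs only in fetching the output as GetTensorData<int16_t>.
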
--- a/tensorflow/lite/kernels/quantize_test.cc
+++ b/tensorflow/lite/kernels/quantize_test.cc
@@ -79,6 +79,17 @@ TEST(QuantizeOpTest, INT8) {
                   {-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
 }

+TEST(QuantizeOpTest, INT16) {
+  QuantizeOpModel m({TensorType_FLOAT32, {2, 5}},
+                    {TensorType_INT16, {2, 5}, 0, 0, 0.005, 0});
+
+  m.SetInput({-63.5, -63, -3, -2, -1, 1, 2, 3, 63.5, 64});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({-12700, -12600, -600, -400, -200, 200, 400, 600,
+                                12700, 12800}));
+}
+
 // Input scale 0.500000, output scale 0.500000, input zeropoint -1, output
 // zeropoint -1
 TEST(QuantizeOpTest, Int8Int8SameScale) {
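
The expected values in the INT16 test follow directly from the affine formula with scale 0.005 and zero point 0: for example, -63.5 / 0.005 = -12700 and 64 / 0.005 = 12800, both comfortably inside int16's [-32768, 32767] range, whereas the same inputs would saturate an 8-bit output.
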
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -376,7 +376,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_ELU, Register_ELU());
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
   AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
-  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());

   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
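
Registering QUANTIZE with /* min_version */ 1 and /* max_version */ 2 lets the op resolver accept models serialized against either version, with version 2 here corresponding to the int16 support this commit adds.
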
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -169,7 +169,7 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
     case BuiltinOperator_QUANTIZE:
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
-      property.version = 1;
+      property.version = 2;
       break;
     case BuiltinOperator_RESHAPE:
       property.inputs = {{0, {}}};
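
With property.version bumped to 2, the quantization tooling now stamps version 2 on the QUANTIZE ops it emits, which is what the updated test expectation below checks.
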
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -370,7 +370,7 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
             BuiltinOperator_CONCATENATION);
   EXPECT_EQ(model_.operator_codes[0]->version, 2);
   EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
-  EXPECT_EQ(model_.operator_codes[1]->version, 1);
+  EXPECT_EQ(model_.operator_codes[1]->version, 2);
 }

 class QuantizeConvModel1Test : public QuantizeModelTest {