Add int16 support to Quant.

PiperOrigin-RevId: 258563058
Author: Jian Li, 2019-07-17 07:18:23 -07:00 (committed by TensorFlower Gardener)
parent 4fd6623585
commit f0d6424da5
5 changed files with 26 additions and 6 deletions


@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
   TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);
   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
   // Currently this only support affine per-layer quantization.
@@ -69,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // For requantize use case.
   const bool is_requantize = (op_context.input->type == kTfLiteUInt8 ||
-                              op_context.input->type == kTfLiteInt8) &&
+                              op_context.input->type == kTfLiteInt8 ||
+                              op_context.input->type == kTfLiteInt16) &&
                              (op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);
   if (is_requantize) {
     const double effective_output_scale =
         static_cast<double>(op_context.input->params.scale) /
@@ -104,6 +107,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       optimized_ops::AffineQuantize(
           op_params, GetTensorShape(input), GetTensorData<float>(input),
           GetTensorShape(output), GetTensorData<uint8_t>(output));
+    } else if (output->type == kTfLiteInt16) {
+      optimized_ops::AffineQuantize(
+          op_params, GetTensorShape(input), GetTensorData<float>(input),
+          GetTensorShape(output), GetTensorData<int16_t>(output));
     } else {
       context->ReportError(
           context,
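For context on the Eval change above: the int16 branch feeds the same float input into optimized_ops::AffineQuantize, only with an int16_t output buffer. Below is a minimal standalone sketch of what affine quantization to int16 computes, assuming the standard formula q = zero_point + round(x / scale) saturated to the int16 range; this is an illustration only, not the actual optimized_ops implementation, and the helper name is made up.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical sketch: q = zero_point + round(x / scale), clamped to
// [-32768, 32767]. Not the real optimized_ops::AffineQuantize.
std::vector<int16_t> AffineQuantizeInt16Sketch(const std::vector<float>& input,
                                               float scale, int zero_point) {
  std::vector<int16_t> output;
  output.reserve(input.size());
  for (float x : input) {
    int q = zero_point + static_cast<int>(std::lround(x / scale));
    q = std::min(32767, std::max(-32768, q));  // Saturate to the int16 range.
    output.push_back(static_cast<int16_t>(q));
  }
  return output;
}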


@@ -79,6 +79,17 @@ TEST(QuantizeOpTest, INT8) {
                                    {-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
 }
+
+TEST(QuantizeOpTest, INT16) {
+  QuantizeOpModel m({TensorType_FLOAT32, {2, 5}},
+                    {TensorType_INT16, {2, 5}, 0, 0, 0.005, 0});
+  m.SetInput({-63.5, -63, -3, -2, -1, 1, 2, 3, 63.5, 64});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({-12700, -12600, -600, -400, -200, 200, 400, 600,
+                                12700, 12800}));
+}
+
 // Input scale 0.500000, output scale 0.500000, input zeropoint -1, output
 // zeropoint -1
 TEST(QuantizeOpTest, Int8Int8SameScale) {
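The expected values in the new INT16 test follow directly from the output tensor's quantization parameters (scale 0.005, zero point 0): q = round(x / 0.005), so -63.5 maps to -12700 and 64 maps to 12800, both well inside the int16 range. As a usage note, the hypothetical sketch introduced earlier would reproduce the same values:

// Usage fragment, relying on the hypothetical AffineQuantizeInt16Sketch above.
auto q = AffineQuantizeInt16Sketch(
    {-63.5f, -63.0f, -3.0f, -2.0f, -1.0f, 1.0f, 2.0f, 3.0f, 63.5f, 64.0f},
    /*scale=*/0.005f, /*zero_point=*/0);
// q == {-12700, -12600, -600, -400, -200, 200, 400, 600, 12700, 12800}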


@@ -376,7 +376,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_ELU, Register_ELU());
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
   AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
-  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that


@@ -169,7 +169,7 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
     case BuiltinOperator_QUANTIZE:
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
-      property.version = 1;
+      property.version = 2;
       break;
     case BuiltinOperator_RESHAPE:
       property.inputs = {{0, {}}};


@@ -370,7 +370,7 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
             BuiltinOperator_CONCATENATION);
   EXPECT_EQ(model_.operator_codes[0]->version, 2);
   EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
-  EXPECT_EQ(model_.operator_codes[1]->version, 1);
+  EXPECT_EQ(model_.operator_codes[1]->version, 2);
 }
 
 class QuantizeConvModel1Test : public QuantizeModelTest {