Support int16 quantization type

This patch is just changing a hard-coded 8 bits setting to be configured by the inference type.

PiperOrigin-RevId: 311816528
Change-Id: I8da61fb0751122e29134d13e5f8200c89980e131
This commit is contained in:
Feng Liu 2020-05-15 15:54:21 -07:00 committed by TensorFlower Gardener
parent ec52e0fcd3
commit da27ac6878

View File

@ -70,6 +70,7 @@ class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> { : public PassWrapper<PrepareQuantizePass, FunctionPass> {
public: public:
// Constructor used by the PassRegistration and enforce uint8 quantization. // Constructor used by the PassRegistration and enforce uint8 quantization.
// This is only used by test.
explicit PrepareQuantizePass() { explicit PrepareQuantizePass() {
if (quantize_signed) if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8; quant_specs_.inference_type = tensorflow::DT_QINT8;
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
// convert all of them to signed. // convert all of them to signed.
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType(); bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
if (is_signed) { if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx); patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters. // Convert quant stats to int8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false. // Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx); patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
} else { } else {
// Convert quant stats to uint8 quantization parameters. // Convert quant stats to uint8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false. // Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, false, ctx); patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
} }
applyPatternsAndFoldGreedily(func, patterns); applyPatternsAndFoldGreedily(func, patterns);