Support int16 quantization type
This patch just changes a hard-coded 8-bit setting so that the bit width is derived from the configured inference type, which is what allows int16 in addition to int8/uint8.

PiperOrigin-RevId: 311816528
Change-Id: I8da61fb0751122e29134d13e5f8200c89980e131
parent ec52e0fcd3 · commit da27ac6878
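For context, the new quant_specs_.GetQuantizationTypeWidth() call below is what replaces the hard-coded 8. A minimal sketch of such a helper, assuming it simply switches on the TensorFlow inference DataType (its actual body is not part of this diff):

#include "tensorflow/core/framework/types.pb.h"

// Sketch only: maps the configured inference type to a quantized bit width.
// The diff implies this helper exists on QuantizationSpecs; the case list
// here is an assumption, not copied from the real source.
inline int GetQuantizationTypeWidth(tensorflow::DataType inference_type) {
  switch (inference_type) {
    case tensorflow::DT_QINT8:
    case tensorflow::DT_QUINT8:
      return 8;
    case tensorflow::DT_QINT16:  // the 16-bit path this patch enables
      return 16;
    case tensorflow::DT_QINT32:
      return 32;
    default:
      return 0;  // not a quantized inference type
  }
}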
@@ -70,6 +70,7 @@ class PrepareQuantizePass
     : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  public:
   // Constructor used by the PassRegistration and enforce uint8 quantization.
   // This is only used by test.
   explicit PrepareQuantizePass() {
     if (quantize_signed)
       quant_specs_.inference_type = tensorflow::DT_QINT8;
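The quantize_signed branch above matters because signedness, together with the bit width, fixes the nominal integer range that the stats-to-parameters rewrites in the next hunk target. Purely as an illustration (this function is not in the patch), the full ranges with narrow_range = false are:

#include <cstdint>

// Illustration only: nominal quantized range for bit width w.
//   signed:   [-2^(w-1), 2^(w-1) - 1], e.g. int8 -> [-128, 127], int16 -> [-32768, 32767]
//   unsigned: [0, 2^w - 1],            e.g. uint8 -> [0, 255]
void NominalRange(int w, bool is_signed, int64_t* qmin, int64_t* qmax) {
  if (is_signed) {
    *qmin = -(int64_t{1} << (w - 1));
    *qmax = (int64_t{1} << (w - 1)) - 1;
  } else {
    *qmin = 0;
    *qmax = (int64_t{1} << w) - 1;
  }
}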
@@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
   // convert all of them to signed.
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
+  int bit_width = quant_specs_.GetQuantizationTypeWidth();
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(8, false, true, ctx);
+    patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
   } else {
     // Convert quant stats to uint8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(8, false, false, ctx);
+    patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
   applyPatternsAndFoldGreedily(func, patterns);
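Judging from the call sites above, the PrepareQuantStats arguments are (num_bits, narrow_range, is_signed, ctx), so only the first argument needed to change. A hypothetical caller-side sketch of how 16-bit quantization would then be requested through the same plumbing (the converter-level wiring is outside this diff, and DT_QINT16 as the chosen enum value is an assumption):

// Hypothetical usage sketch, not code from this patch.
QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT16;  // assumed int16 marker
// Inside PrepareQuantizePass::runOnFunction():
//   is_signed == true (16-bit quantized types are signed here)
//   bit_width == 16
// so PrepareQuantStats now emits 16-bit quantization parameters
// where it previously always emitted 8-bit ones.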