Support int16 quantization type
This patch just changes a hard-coded 8-bit setting so that it is configured by the inference type instead.

PiperOrigin-RevId: 311816528
Change-Id: I8da61fb0751122e29134d13e5f8200c89980e131
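At a high level, the pass now asks the quantization specs for the bit width rather than assuming 8. As a rough illustration only (the real logic lives in QuantizationSpecs::GetQuantizationTypeWidth, which is not part of this diff), the mapping from inference type to bit width behaves roughly like the standalone sketch below; the helper name QuantizationTypeWidthSketch and the fallback value are assumptions, not the library's API.

#include "tensorflow/core/framework/types.pb.h"  // tensorflow::DataType, DT_QINT8, ...

// Standalone sketch of an inference-type -> storage-bit-width mapping.
// This is NOT the real QuantizationSpecs::GetQuantizationTypeWidth; it only
// illustrates why DT_QINT16 can now flow through as a 16-bit width.
inline int QuantizationTypeWidthSketch(tensorflow::DataType inference_type) {
  switch (inference_type) {
    case tensorflow::DT_QINT8:
    case tensorflow::DT_QUINT8:
      return 8;
    case tensorflow::DT_QINT16:
    case tensorflow::DT_QUINT16:
      return 16;
    case tensorflow::DT_QINT32:
      return 32;
    default:
      return 8;  // Assumed fallback for non-quantized inference types.
  }
}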
This commit is contained in: parent ec52e0fcd3, commit da27ac6878
@@ -70,6 +70,7 @@ class PrepareQuantizePass
     : public PassWrapper<PrepareQuantizePass, FunctionPass> {
  public:
   // Constructor used by the PassRegistration and enforce uint8 quantization.
+  // This is only used by test.
   explicit PrepareQuantizePass() {
     if (quantize_signed)
       quant_specs_.inference_type = tensorflow::DT_QINT8;
@@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
   // convert all of them to signed.
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
+  int bit_width = quant_specs_.GetQuantizationTypeWidth();
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(8, false, true, ctx);
+    patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
   } else {
     // Convert quant stats to uint8 quantization parameters.
     // Currently, only activation stats are imported, so narrow_range = false.
-    patterns.insert<PrepareQuantStats>(8, false, false, ctx);
+    patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
   applyPatternsAndFoldGreedily(func, patterns);

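For callers that want the new 16-bit activation path, the only change needed on the configuration side is the inference type; the pass then derives bit_width = 16 and instantiates the PrepareQuantStats patterns accordingly. Below is a minimal, hypothetical configuration sketch; the include path and the mlir::TFL namespace are assumptions based on the TFLite MLIR converter layout, and how the specs are actually threaded into the pass depends on the converter driver.

#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/core/framework/types.pb.h"

// Hypothetical helper: build QuantizationSpecs that request int16 inference.
// With DT_QINT16, GetQuantizationTypeWidth() is expected to report 16, so the
// PrepareQuantStats patterns above are created with bit_width = 16 instead of 8.
mlir::TFL::QuantizationSpecs MakeInt16QuantSpecs() {
  mlir::TFL::QuantizationSpecs specs;
  specs.inference_type = tensorflow::DT_QINT16;
  return specs;
}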