diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index a9e10a485bf..87cae3dd957 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -70,6 +70,7 @@ class PrepareQuantizePass : public PassWrapper { public: // Constructor used by the PassRegistration and enforce uint8 quantization. + // This is only used by test. explicit PrepareQuantizePass() { if (quantize_signed) quant_specs_.inference_type = tensorflow::DT_QINT8; @@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() { // convert all of them to signed. OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); + int bit_width = quant_specs_.GetQuantizationTypeWidth(); if (is_signed) { patterns.insert>(ctx); // Convert quant stats to int8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, true, ctx); + patterns.insert(bit_width, false, true, ctx); } else { // Convert quant stats to uint8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, false, ctx); + patterns.insert(bit_width, false, false, ctx); } applyPatternsAndFoldGreedily(func, patterns);