Apply default quantization parameter before quantization pass

By this order, the default quantization parameter is only applied on the
activations and the weight quantization parameter will use the parameters from
the weight content.

PiperOrigin-RevId: 317672365
Change-Id: Ib7b02ae19105db124721242ea51a5ccc1d5aa68e
This commit is contained in:
Feng Liu 2020-06-22 09:43:32 -07:00 committed by TensorFlower Gardener
parent e25fcc8393
commit 22f6939be9
1 changed files with 5 additions and 9 deletions

View File

@ -39,22 +39,18 @@ namespace tensorflow {
void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::OpPassManager* pass_manager) { mlir::OpPassManager* pass_manager) {
pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs)); pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
if (quant_specs.default_ranges.first.hasValue() || if (quant_specs.default_ranges.first.hasValue() ||
quant_specs.default_ranges.second.hasValue()) { quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0), quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0), quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType())); quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
} }
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
bool emit_quant_adaptor_ops =
quant_specs.inference_type != quant_specs.inference_input_type;
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
} }
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,