Switch legacy quantize mode on by default for MLIR quantizer

Also fix the code so that the legacy pass is applied together with the other patterns. This ensures that, when numeric_verify is set, float ops can be duplicated before quantization happens.

PiperOrigin-RevId: 352948582
Change-Id: I6550697ee6a4508bc6518ec08cccc854b34d2321
This commit is contained in:
Taehee Jeong 2021-01-20 22:44:09 -08:00 committed by TensorFlower Gardener
parent 9c530d1204
commit 310b42a801
2 changed files with 10 additions and 5 deletions

View File

@ -29,6 +29,11 @@ namespace lite {
// The `input_type`, `output_type` and `inference_type` can be
// float32/qint8/int8/int16.
// Return partially quantized model if `fully_quantize` is false.
// When `verify_numeric` is true, the model will have its original float ops
// and NumericVerify ops to compare output values from the quantized and float
// ops. When `legacy_float_scale` is true, the quantizer will use float scale
// instead of double, and call TOCO's quantization routines to maintain
// bit-exactness of the values with the TOCO quantizer.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
@ -37,7 +42,7 @@ TfLiteStatus QuantizeModel(
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter, bool verify_numeric = false,
bool legacy_float_scale = false);
bool legacy_float_scale = true);
} // namespace lite
} // namespace mlir

View File

@ -88,8 +88,10 @@ struct TFLFullQuantization
};
struct LegacyQuantizePass : public OpRewritePattern<QuantizeOp> {
// This pattern should be applied before existing quantize pattern in
// `quantize_patterns.td`, so the benefit is set to some value larger than 1.
explicit LegacyQuantizePass(MLIRContext* context)
: OpRewritePattern<QuantizeOp>(context) {}
: OpRewritePattern<QuantizeOp>(context, /*benefit=*/10) {}
LogicalResult matchAndRewrite(QuantizeOp op,
PatternRewriter& rewriter) const override {
DenseFPElementsAttr attr;
@ -127,9 +129,7 @@ void QuantizePass::runOnFunction() {
auto func = getFunction();
auto* ctx = func.getContext();
if (legacy_float_scale) {
OwningRewritePatternList legacy_patterns;
legacy_patterns.insert<LegacyQuantizePass>(ctx);
applyPatternsAndFoldGreedily(func, std::move(legacy_patterns));
patterns.insert<LegacyQuantizePass>(ctx);
}
TFL::populateWithGenerated(ctx, patterns);
patterns.insert<TFLFullQuantization>(