Add test flag for post-training quantization in prepare quantize pass

Also pass quant_specs to ConvertLstmStatsToQDQs pass.

PiperOrigin-RevId: 345008792
Change-Id: I888064ae680958cc751963c9803cd0c955dc5635
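Since the flag is test-only, a lit test would presumably drive it through tf-opt; the pass registration name tfl-prepare-quantize is assumed here from PrepareQuantizePass's usual PassRegistration:

    // RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-post-training-quantize | FileCheck %s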
parent d019324c5a
commit d00fe1e4b7
@@ -59,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
     llvm::cl::desc("signed inference type. Only used in tests"),
     llvm::cl::init(false));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> post_training_quantize(
+    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
+    llvm::cl::desc("enable post training quantization. Only used in tests"),
+    llvm::cl::init(false));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt<bool> disable_per_channel(
     "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
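The new flag follows the same pattern as the existing quantize_signed option above: a file-scope llvm::cl::opt that LLVM's command-line machinery fills in when the hosting tool parses argv. A minimal standalone sketch of that mechanism (the main function here is illustrative, not tf-opt's; only the flag definition comes from the hunk above):

    #include "llvm/Support/CommandLine.h"

    // NOLINTNEXTLINE
    static llvm::cl::opt<bool> post_training_quantize(
        "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
        llvm::cl::desc("enable post training quantization. Only used in tests"),
        llvm::cl::init(false));

    int main(int argc, char** argv) {
      // tf-opt performs this parse once for all registered cl::opts;
      // shown standalone so the flag's lifecycle is visible.
      llvm::cl::ParseCommandLineOptions(argc, argv);
      return post_training_quantize ? 0 : 1;
    }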
@@ -89,10 +95,9 @@ class PrepareQuantizePass
   // Constructor used by the PassRegistration and enforce uint8 quantization.
   // This is only used by test.
   explicit PrepareQuantizePass() {
-    if (quantize_signed)
-      quant_specs_.inference_type = tensorflow::DT_QINT8;
-    else
-      quant_specs_.inference_type = tensorflow::DT_QUINT8;
+    quant_specs_.inference_type =
+        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
+    quant_specs_.post_training_quantization = post_training_quantize;
   }
 
   // Constructor used by manually creating the pass.
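quant_specs_ is the pass's QuantizationSpecs from quantization_config.h; a simplified sketch of just the two fields this constructor now sets (defaults assumed, the real struct has many more fields):

    // Simplified view of QuantizationSpecs for this commit's purposes.
    struct QuantizationSpecs {
      // Type used at inference time; DT_QINT8 when signed quantization is
      // requested by the test flag, DT_QUINT8 otherwise.
      tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

      // Whether this run is post-training quantization (calibration stats
      // already attached to the graph) rather than quantization-aware
      // training.
      bool post_training_quantization = false;
    };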
@@ -362,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
     // Currently, only activation stats are imported, so narrow_range = false.
     patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
-  patterns.insert<PrepareLstmQuantStats>(ctx);
+  patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
   applyPatternsAndFoldGreedily(func, std::move(patterns));
 
   SanityCheckAndAdjustment(func);
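PrepareLstmQuantStats gains the quant_specs_ argument because it is an instantiation of the ConvertLstmStatsToQDQs pattern changed below; a plausible form of the alias, where the exact SourceOp/Q/DQ choices are assumptions rather than copied from the actual file:

    // Hypothetical alias; the concrete ops may differ in the real source.
    using PrepareLstmQuantStats =
        ConvertLstmStatsToQDQs<TFL::LSTMOp, quant::QuantizeCastOp,
                               quant::DequantizeCastOp>;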
@@ -58,8 +58,11 @@ namespace operator_property = ::tflite::optimize::operator_property;
 template <typename SourceOp, typename Q, typename DQ>
 struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
  public:
-  explicit ConvertLstmStatsToQDQs(MLIRContext* context)
-      : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
+  ConvertLstmStatsToQDQs(MLIRContext* context,
+                         const QuantizationSpecs& quant_specs)
+
+      : OpRewritePattern<SourceOp>(context, /*benefit=*/2),
+        quant_specs(quant_specs) {}
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter& rewriter) const override {
     operator_property::OpVariant lstm_variant;
@@ -137,7 +140,7 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
         op.getLoc(), tensor_property.number_of_bits,
         calibrated_type.getMin(), calibrated_type.getMax(),
         /*narrowRange=*/false, calibrated_type.getExpressedType(),
-        /*isSigned=*/false);
+        /*isSigned=*/quant_specs.IsSignedInferenceType());
   } else if (tensor_property.number_of_bits == 16) {
     double max = std::max(std::abs(calibrated_type.getMin()),
                           std::abs(calibrated_type.getMax()));
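IsSignedInferenceType() is an existing helper on QuantizationSpecs; a sketch of the behavior it is expected to have given the DT_QINT8/DT_QUINT8 selection in the constructor above (only unsigned quantized inference types report false):

    // Sketch: signedness derived from the chosen inference type.
    bool QuantizationSpecs::IsSignedInferenceType() const {
      switch (inference_type) {
        case tensorflow::DT_QUINT8:
        case tensorflow::DT_QUINT16:
          return false;
        default:
          return true;
      }
    }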
@@ -185,6 +188,9 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
     rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
     return success();
   }
+
+ private:
+  QuantizationSpecs quant_specs;
 };
 
 }  // namespace TFL