Add test flag for post-training quantization in prepare quantize pass
Also pass quant_specs to the ConvertLstmStatsToQDQs rewrite pattern. PiperOrigin-RevId: 345008792 Change-Id: I888064ae680958cc751963c9803cd0c955dc5635
This commit is contained in:
parent
d019324c5a
commit
d00fe1e4b7
@ -59,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
|
|||||||
llvm::cl::desc("signed inference type. Only used in tests"),
|
llvm::cl::desc("signed inference type. Only used in tests"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static llvm::cl::opt<bool> post_training_quantize(
|
||||||
|
"tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
|
||||||
|
llvm::cl::desc("enable post training quantization. Only used in tests"),
|
||||||
|
llvm::cl::init(false));
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::opt<bool> disable_per_channel(
|
static llvm::cl::opt<bool> disable_per_channel(
|
||||||
"tfl-disable-per-channel", llvm::cl::value_desc("bool"),
|
"tfl-disable-per-channel", llvm::cl::value_desc("bool"),
|
||||||
@ -89,10 +95,9 @@ class PrepareQuantizePass
|
|||||||
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
||||||
// This is only used by test.
|
// This is only used by test.
|
||||||
explicit PrepareQuantizePass() {
|
explicit PrepareQuantizePass() {
|
||||||
if (quantize_signed)
|
quant_specs_.inference_type =
|
||||||
quant_specs_.inference_type = tensorflow::DT_QINT8;
|
quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
|
||||||
else
|
quant_specs_.post_training_quantization = post_training_quantize;
|
||||||
quant_specs_.inference_type = tensorflow::DT_QUINT8;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Constructor used by manually creating the pass.
|
// Constructor used by manually creating the pass.
|
||||||
@ -362,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
|
|||||||
// Currently, only activation stats are imported, so narrow_range = false.
|
// Currently, only activation stats are imported, so narrow_range = false.
|
||||||
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
||||||
}
|
}
|
||||||
patterns.insert<PrepareLstmQuantStats>(ctx);
|
patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
|
||||||
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||||
|
|
||||||
SanityCheckAndAdjustment(func);
|
SanityCheckAndAdjustment(func);
|
||||||
|
@ -58,8 +58,11 @@ namespace operator_property = ::tflite::optimize::operator_property;
|
|||||||
template <typename SourceOp, typename Q, typename DQ>
|
template <typename SourceOp, typename Q, typename DQ>
|
||||||
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
||||||
public:
|
public:
|
||||||
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
|
ConvertLstmStatsToQDQs(MLIRContext* context,
|
||||||
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
|
const QuantizationSpecs& quant_specs)
|
||||||
|
|
||||||
|
: OpRewritePattern<SourceOp>(context, /*benefit=*/2),
|
||||||
|
quant_specs(quant_specs) {}
|
||||||
LogicalResult matchAndRewrite(SourceOp op,
|
LogicalResult matchAndRewrite(SourceOp op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
operator_property::OpVariant lstm_variant;
|
operator_property::OpVariant lstm_variant;
|
||||||
@ -137,7 +140,7 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
|||||||
op.getLoc(), tensor_property.number_of_bits,
|
op.getLoc(), tensor_property.number_of_bits,
|
||||||
calibrated_type.getMin(), calibrated_type.getMax(),
|
calibrated_type.getMin(), calibrated_type.getMax(),
|
||||||
/*narrowRange=*/false, calibrated_type.getExpressedType(),
|
/*narrowRange=*/false, calibrated_type.getExpressedType(),
|
||||||
/*isSigned=*/false);
|
/*isSigned=*/quant_specs.IsSignedInferenceType());
|
||||||
} else if (tensor_property.number_of_bits == 16) {
|
} else if (tensor_property.number_of_bits == 16) {
|
||||||
double max = std::max(std::abs(calibrated_type.getMin()),
|
double max = std::max(std::abs(calibrated_type.getMin()),
|
||||||
std::abs(calibrated_type.getMax()));
|
std::abs(calibrated_type.getMax()));
|
||||||
@ -185,6 +188,9 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
|||||||
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
|
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
QuantizationSpecs quant_specs;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
|
Loading…
x
Reference in New Issue
Block a user