Add test flag for post-training quantization in prepare quantize pass

Also pass quant_specs to the ConvertLstmStatsToQDQs pattern.

PiperOrigin-RevId: 345008792
Change-Id: I888064ae680958cc751963c9803cd0c955dc5635
Author:    Taehee Jeong
Committer: TensorFlower Gardener
Date:      2020-12-01 06:37:29 -08:00
Parent:    d019324c5a
Commit:    d00fe1e4b7

2 changed files with 19 additions and 8 deletions
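
For context, the pass stores its configuration in a QuantizationSpecs object. Below is a trimmed sketch of the members this commit touches (the full struct lives in quantization_config.h; the defaults and exact switch cases here are recalled, not verbatim):

// Trimmed sketch of QuantizationSpecs; only the members relevant to this
// commit are shown, and defaults are approximate.
struct QuantizationSpecs {
  // Storage type used for inference: DT_QUINT8 selects unsigned 8-bit,
  // DT_QINT8 selects signed 8-bit.
  tensorflow::DataType inference_type = tensorflow::DT_FLOAT;

  // Run calibrated post-training quantization rather than
  // quantization-aware-training rewrites.
  bool post_training_quantization = false;

  // False only for the unsigned inference types.
  bool IsSignedInferenceType() {
    switch (inference_type) {
      case tensorflow::DT_QUINT8:
      case tensorflow::DT_QUINT16:
        return false;
      default:
        return true;
    }
  }
};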

@@ -59,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
     llvm::cl::desc("signed inference type. Only used in tests"),
     llvm::cl::init(false));
 
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> post_training_quantize(
+    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
+    llvm::cl::desc("enable post training quantization. Only used in tests"),
+    llvm::cl::init(false));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt<bool> disable_per_channel(
     "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
@@ -89,10 +95,9 @@ class PrepareQuantizePass
   // Constructor used by the PassRegistration and enforce uint8 quantization.
   // This is only used by test.
   explicit PrepareQuantizePass() {
-    if (quantize_signed)
-      quant_specs_.inference_type = tensorflow::DT_QINT8;
-    else
-      quant_specs_.inference_type = tensorflow::DT_QUINT8;
+    quant_specs_.inference_type =
+        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
+    quant_specs_.post_training_quantization = post_training_quantize;
   }
 
   // Constructor used by manually creating the pass.
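
For contrast with the test-only constructor above, the constructor named in the trailing comment takes the specs directly from the caller. Its shape is reconstructed here and may not match the source verbatim:

// Constructor used by manually creating the pass (reconstructed sketch;
// quant_specs_ is the member the test constructor fills from flags).
explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs)
    : quant_specs_(quant_specs) {}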
@@ -362,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
     // Currently, only activation stats are imported, so narrow_range = false.
     patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
-  patterns.insert<PrepareLstmQuantStats>(ctx);
+  patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
   applyPatternsAndFoldGreedily(func, std::move(patterns));
   SanityCheckAndAdjustment(func);

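The remaining hunks live in the header that defines the ConvertLstmStatsToQDQs pattern; threading quant_specs through lets the 8-bit LSTM path emit signed quantized types instead of hard-coding unsigned ones. What flips with isSigned is the integer storage range that the scale and zero point are derived from. A standalone illustration, not TensorFlow code:

#include <cstdio>

int main() {
  const int bits = 8;
  // Unsigned 8-bit storage covers [0, 255]; signed covers [-128, 127].
  const long long uq_min = 0;
  const long long uq_max = (1LL << bits) - 1;
  const long long sq_min = -(1LL << (bits - 1));
  const long long sq_max = (1LL << (bits - 1)) - 1;
  // For a calibrated float range [rmin, rmax],
  // scale = (rmax - rmin) / (qmax - qmin), and the zero point is the
  // quantized value that maps to 0.0f, so signedness shifts both.
  std::printf("unsigned: [%lld, %lld]  signed: [%lld, %lld]\n",
              uq_min, uq_max, sq_min, sq_max);
  return 0;
}
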
@@ -58,8 +58,11 @@ namespace operator_property = ::tflite::optimize::operator_property;
 template <typename SourceOp, typename Q, typename DQ>
 struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
  public:
-  explicit ConvertLstmStatsToQDQs(MLIRContext* context)
-      : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
+  ConvertLstmStatsToQDQs(MLIRContext* context,
+                         const QuantizationSpecs& quant_specs)
+      : OpRewritePattern<SourceOp>(context, /*benefit=*/2),
+        quant_specs(quant_specs) {}
+
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter& rewriter) const override {
     operator_property::OpVariant lstm_variant;
@@ -137,7 +140,7 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
             op.getLoc(), tensor_property.number_of_bits,
             calibrated_type.getMin(), calibrated_type.getMax(),
             /*narrowRange=*/false, calibrated_type.getExpressedType(),
-            /*isSigned=*/false);
+            /*isSigned=*/quant_specs.IsSignedInferenceType());
       } else if (tensor_property.number_of_bits == 16) {
         double max = std::max(std::abs(calibrated_type.getMin()),
                               std::abs(calibrated_type.getMax()));
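
The 16-bit branch visible above stays symmetric regardless of signedness: it takes the larger absolute calibrated bound and centers the range on zero. A sketch of that arithmetic, inferred from the std::max(std::abs(...)) in the hunk rather than copied from the source:

#include <algorithm>
#include <cmath>

// Inferred sketch: symmetric 16-bit parameters from a calibrated [min, max].
// Symmetric schemes use +/-32767 on both sides so the zero point stays 0.
struct QParams {
  double scale;
  int zero_point;
};

QParams SymmetricInt16(double cal_min, double cal_max) {
  const double bound = std::max(std::abs(cal_min), std::abs(cal_max));
  return {/*scale=*/bound / 32767.0, /*zero_point=*/0};
}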
@@ -185,6 +188,9 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
     rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
     return success();
   }
+
+ private:
+  QuantizationSpecs quant_specs;
 };
 
 }  // namespace TFL
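
Because the new flags are read only in the default constructor, they take effect when the pass is instantiated through pass registration, which for this pass uses the tfl-prepare-quantize name. A sketch of that registration side (the description string is paraphrased):

// Registration path that invokes the flag-reading constructor above
// (sketch; the actual description string differs).
static PassRegistration<PrepareQuantizePass> pass(
    "tfl-prepare-quantize",
    "Prepare a TFL function for quantization");

A lit test would then drive both tool flags together, along the lines of: tf-opt -tfl-prepare-quantize -tfl-test-post-training-quantize input.mlir.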