Add test flag for post-training quantization in prepare quantize pass

Also pass quant_spec to ConvertLstmStatsToQDQs pass.

PiperOrigin-RevId: 345008792
Change-Id: I888064ae680958cc751963c9803cd0c955dc5635
This commit is contained in:
Taehee Jeong 2020-12-01 06:37:29 -08:00 committed by TensorFlower Gardener
parent d019324c5a
commit d00fe1e4b7
2 changed files with 19 additions and 8 deletions

View File

@@ -59,6 +59,12 @@ static llvm::cl::opt<bool> quantize_signed(
     llvm::cl::desc("signed inference type. Only used in tests"),
     llvm::cl::init(false));

+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> post_training_quantize(
+    "tfl-test-post-training-quantize", llvm::cl::value_desc("bool"),
+    llvm::cl::desc("enable post training quantization. Only used in tests"),
+    llvm::cl::init(false));
+
 // NOLINTNEXTLINE
 static llvm::cl::opt<bool> disable_per_channel(
     "tfl-disable-per-channel", llvm::cl::value_desc("bool"),
@@ -89,10 +95,9 @@ class PrepareQuantizePass
   // Constructor used by the PassRegistration and enforce uint8 quantization.
   // This is only used by test.
   explicit PrepareQuantizePass() {
-    if (quantize_signed)
-      quant_specs_.inference_type = tensorflow::DT_QINT8;
-    else
-      quant_specs_.inference_type = tensorflow::DT_QUINT8;
+    quant_specs_.inference_type =
+        quantize_signed ? tensorflow::DT_QINT8 : tensorflow::DT_QUINT8;
+    quant_specs_.post_training_quantization = post_training_quantize;
   }
// Constructor used by manually creating the pass. // Constructor used by manually creating the pass.
@@ -362,7 +367,7 @@ void PrepareQuantizePass::runOnFunction() {
     // Currently, only activation stats are imported, so narrow_range = false.
     patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
   }
-  patterns.insert<PrepareLstmQuantStats>(ctx);
+  patterns.insert<PrepareLstmQuantStats>(ctx, quant_specs_);
   applyPatternsAndFoldGreedily(func, std::move(patterns));
   SanityCheckAndAdjustment(func);

View File

@@ -58,8 +58,11 @@ namespace operator_property = ::tflite::optimize::operator_property;
 template <typename SourceOp, typename Q, typename DQ>
 struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
  public:
-  explicit ConvertLstmStatsToQDQs(MLIRContext* context)
-      : OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
+  ConvertLstmStatsToQDQs(MLIRContext* context,
+                         const QuantizationSpecs& quant_specs)
+      : OpRewritePattern<SourceOp>(context, /*benefit=*/2),
+        quant_specs(quant_specs) {}
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter& rewriter) const override {
     operator_property::OpVariant lstm_variant;
@@ -137,7 +140,7 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
           op.getLoc(), tensor_property.number_of_bits,
           calibrated_type.getMin(), calibrated_type.getMax(),
           /*narrowRange=*/false, calibrated_type.getExpressedType(),
-          /*isSigned=*/false);
+          /*isSigned=*/quant_specs.IsSignedInferenceType());
     } else if (tensor_property.number_of_bits == 16) {
       double max = std::max(std::abs(calibrated_type.getMin()),
                             std::abs(calibrated_type.getMax()));
@@ -185,6 +188,9 @@ struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
     rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
     return success();
   }
+
+ private:
+  QuantizationSpecs quant_specs;
 };
} // namespace TFL } // namespace TFL