Only infer the fixed output range when the input graph has dequantize ops

PiperOrigin-RevId: 327312937
Change-Id: Ice4c2e35aeb074e34516dc434ca0c066af947ca8
This commit is contained in:
Feng Liu 2020-08-18 14:54:35 -07:00 committed by TensorFlower Gardener
parent 62793f5afa
commit 81c73b541d
3 changed files with 34 additions and 9 deletions

View File

@ -99,12 +99,14 @@ class QuantizationDriver {
public:
explicit QuantizationDriver(FuncOp fn, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter)
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range)
: fn_(fn),
builder_(fn.getBody()),
is_signed_(is_signed),
disable_per_channel_(disable_per_channel),
op_quant_spec_getter_(op_quant_spec_getter) {}
op_quant_spec_getter_(op_quant_spec_getter),
enforce_fixed_output_range_(enforce_fixed_output_range) {}
// The entry point of the quantization parameters propagation.
void Run();
@ -354,6 +356,8 @@ class QuantizationDriver {
llvm::SmallVector<BlockArgument, 4> args_;
OpQuantSpecGetter op_quant_spec_getter_;
bool enforce_fixed_output_range_;
};
} // namespace
@ -794,7 +798,8 @@ bool QuantizationDriver::PropagateParams() {
}
// TODO(fengliuai): make the bit width configurable.
if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
if (restricted && enforce_fixed_output_range_) {
// TODO(fengliuai): different result can have different fixed range.
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
for (auto i = 0; i < op->getNumResults(); ++i) {
@ -864,10 +869,12 @@ void QuantizationDriver::Run() {
}
}
void ApplyQuantizationParamsPropagation(
mlir::FuncOp func, bool is_signed, bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter) {
QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter)
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool post_training_quantization) {
QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
post_training_quantization)
.Run();
}

View File

@ -490,9 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. Set `disable_per_channel` to true to
// not use per-channel quantization even if the op supports it.
// Set `enforce_fixed_output_range` to true to infer quantization parameters
// from ops with a fixed output range. This is only used for post-training
// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter);
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range);
// The function might contain more stats ops than required, and it will
// introduce requantize if the calibration stats have conflicts. This method

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
@ -122,6 +123,10 @@ class PrepareQuantizePass
// the best quantization practise. This also fixes some simple violations.
void SanityCheckAndAdjustment(FuncOp func);
// Whether the func contains dequantize ops (quant.dcast) — note the method
// name says "Quantize" but the implementation checks for dequantize casts,
// matching this change's intent. This is used to determine whether
bool ContainsQuantizeOps(FuncOp func);
QuantizationSpecs quant_specs_;
};
@ -285,6 +290,13 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
});
}
// Returns true iff `func` contains at least one quant.dcast
// (quant::DequantizeCastOp) operation among the ops of its body.
// NOTE(review): despite the name "ContainsQuantizeOps", this scans for
// *dequantize* casts; per the surrounding change ("Only infer the fixed
// output range when the input graph has dequantize ops") that appears
// intentional — confirm the name, or rename in a follow-up.
bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
  for (const auto& op : func.getOps()) {
    // A single dequantize cast is sufficient; stop at the first match.
    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
  }
  return false;
}
using PrepareQuantStats =
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
@ -309,6 +321,7 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
bool enforce_fixed_output_range = ContainsQuantizeOps(func);
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
@ -327,7 +340,8 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
GetOpQuantSpec);
GetOpQuantSpec,
enforce_fixed_output_range || quant_specs_.post_training_quantization);
ConvertMlirQuantOpsToTFLQuantOps(func);
}