Only infer the fixed output range when the input graph has dequantize ops
PiperOrigin-RevId: 327312937
Change-Id: Ice4c2e35aeb074e34516dc434ca0c066af947ca8
parent 62793f5afa
commit 81c73b541d
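In short, the change makes fixed-output-range enforcement conditional rather than unconditional. A condensed sketch of the resulting control flow, paraphrasing the hunks below and using only names that appear in them:

    // PrepareQuantizePass::runOnFunction() now derives the flag from the graph:
    //   bool enforce_fixed_output_range = ContainsQuantizeOps(func);
    //   ApplyQuantizationParamsPropagation(
    //       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
    //       GetOpQuantSpec,
    //       enforce_fixed_output_range || quant_specs_.post_training_quantization);
    // and QuantizationDriver::PropagateParams() only applies the fixed output
    // range when that flag was set:
    //   if (restricted && enforce_fixed_output_range_) { ... }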
@@ -99,12 +99,14 @@ class QuantizationDriver {
  public:
   explicit QuantizationDriver(FuncOp fn, bool is_signed,
                               bool disable_per_channel,
-                              OpQuantSpecGetter op_quant_spec_getter)
+                              OpQuantSpecGetter op_quant_spec_getter,
+                              bool enforce_fixed_output_range)
       : fn_(fn),
         builder_(fn.getBody()),
         is_signed_(is_signed),
         disable_per_channel_(disable_per_channel),
-        op_quant_spec_getter_(op_quant_spec_getter) {}
+        op_quant_spec_getter_(op_quant_spec_getter),
+        enforce_fixed_output_range_(enforce_fixed_output_range) {}
 
   // The entry point of the quantization parameters propagation.
   void Run();
@@ -354,6 +356,8 @@ class QuantizationDriver {
   llvm::SmallVector<BlockArgument, 4> args_;
 
   OpQuantSpecGetter op_quant_spec_getter_;
+
+  bool enforce_fixed_output_range_;
 };
 }  // namespace
 
@@ -794,7 +798,8 @@ bool QuantizationDriver::PropagateParams() {
   }
 
   // TODO(fengliuai): make the bit width configurable.
-  if (auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op)) {
+  auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
+  if (restricted && enforce_fixed_output_range_) {
     // TODO(fengliuai): different result can have different fixed range.
     auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
     for (auto i = 0; i < op->getNumResults(); ++i) {
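For orientation, a hedged worked example of the kind of parameters GetFixedOutputRange(is_signed_, /*bit_width=*/8) is expected to yield. The [0, 1) range and the 1/256 convention below are illustrative assumptions (each op's interface implementation defines its own fixed range); they are not stated in this diff:

    // Assumed: an op whose output is fixed to [0.0, 1.0), stored as signed
    // 8-bit values in [-128, 127].
    //   scale      = (rmax - rmin) / 2^bit_width = 1.0 / 256 = 0.00390625
    //   zero_point = qmin - rmin / scale         = -128 - 0  = -128
    // A UniformQuantizedType with these parameters (i8 storage, scale 1/256,
    // zero point -128) would then be assigned to each of the op's results by
    // the loop above.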
@@ -864,10 +869,12 @@ void QuantizationDriver::Run() {
   }
 }
 
-void ApplyQuantizationParamsPropagation(
-    mlir::FuncOp func, bool is_signed, bool disable_per_channel,
-    OpQuantSpecGetter op_quant_spec_getter) {
-  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter)
+void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
+                                        bool disable_per_channel,
+                                        OpQuantSpecGetter op_quant_spec_getter,
+                                        bool post_training_quantization) {
+  QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
+                     post_training_quantization)
       .Run();
 }
 
@@ -490,9 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
 // and the propagation results are materialized by inserting pairs of quantize
 // and dequantize ops to this function. Set `disable_per_channel` to true to not
 // use per channel quantization even the op supports it.
+// Setting `enforce_fixed_output_range` to true, to infer quantization
+// parameters from the fixed output range ops. This is only used for
+// post-training quantization.
 void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
                                         bool disable_per_channel,
-                                        OpQuantSpecGetter op_quant_spec_getter);
+                                        OpQuantSpecGetter op_quant_spec_getter,
+                                        bool enforce_fixed_output_range);
 
 // The function might contain more stats ops than required, and it will
 // introduce requantize if the calibration stats have conflicts. This method
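A minimal call-site sketch for the extended declaration above. PropagateForPostTraining is a hypothetical helper; ApplyQuantizationParamsPropagation and GetOpQuantSpec are the functions shown in this diff, and the flag values are illustrative, not prescribed by the change:

    // Hypothetical wrapper: propagate quantization parameters for one
    // function with fixed output ranges enforced, as post-training
    // quantization requires.
    void PropagateForPostTraining(mlir::FuncOp func) {
      ApplyQuantizationParamsPropagation(
          func, /*is_signed=*/true, /*disable_per_channel=*/false,
          GetOpQuantSpec, /*enforce_fixed_output_range=*/true);
    }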
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/Value.h"  // from @llvm-project
@@ -122,6 +123,10 @@ class PrepareQuantizePass
   // the best quantization practise. This also fixes some simple violations.
   void SanityCheckAndAdjustment(FuncOp func);
 
+  // Whether the func contains Quantize ops. This is used to determine whether
+  // to use the quantization parameters from the fixed output range property.
+  bool ContainsQuantizeOps(FuncOp func);
+
   QuantizationSpecs quant_specs_;
 };
 
@@ -285,6 +290,13 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
   });
 }
 
+bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
+  for (const auto& op : func.getOps()) {
+    if (llvm::isa<quant::DequantizeCastOp>(op)) return true;
+  }
+  return false;
+}
+
 using PrepareQuantStats =
     quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
 
@@ -309,6 +321,7 @@ void PrepareQuantizePass::runOnFunction() {
   OwningRewritePatternList patterns;
   bool is_signed = quant_specs_.IsSignedInferenceType();
   int bit_width = quant_specs_.GetQuantizationTypeWidth();
+  bool enforce_fixed_output_range = ContainsQuantizeOps(func);
   if (is_signed) {
     patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
     // Convert quant stats to int8 quantization parameters.
@@ -327,7 +340,8 @@ void PrepareQuantizePass::runOnFunction() {
   // values (tensors).
   ApplyQuantizationParamsPropagation(
       func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
-      GetOpQuantSpec);
+      GetOpQuantSpec,
+      enforce_fixed_output_range || quant_specs_.post_training_quantization);
 
   ConvertMlirQuantOpsToTFLQuantOps(func);
 }