Merge pull request #39171 from lgeiger:default-quant-int8
PiperOrigin-RevId: 311352383 Change-Id: I3527375b386072dc0a3893f884f0239ca714e66b
This commit is contained in:
commit
eaacee1738
|
@ -90,7 +90,7 @@ struct QuantizationSpecs {
|
|||
bool RunWeightQuantization() const { return weight_quantization; }
|
||||
|
||||
// Whether this inference type represents a signed storage type.
|
||||
bool IsSignedInferenceType() {
|
||||
bool IsSignedInferenceType() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QUINT8:
|
||||
case tensorflow::DT_QUINT16:
|
||||
|
@ -102,7 +102,7 @@ struct QuantizationSpecs {
|
|||
|
||||
// Gets the width of this quantization type. Returns 0 if it isn't a
|
||||
// quantization type.
|
||||
int64_t GetQuantizationTypeWidth() {
|
||||
int64_t GetQuantizationTypeWidth() const {
|
||||
switch (inference_type) {
|
||||
case tensorflow::DT_QINT8:
|
||||
case tensorflow::DT_QUINT8:
|
||||
|
|
|
@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
|||
quant_specs.default_ranges.second.hasValue()) {
|
||||
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
|
||||
quant_specs.default_ranges.first.getValueOr(0.0),
|
||||
quant_specs.default_ranges.second.getValueOr(0.0)));
|
||||
quant_specs.default_ranges.second.getValueOr(0.0),
|
||||
quant_specs.IsSignedInferenceType()));
|
||||
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
|
|
|
@ -46,8 +46,11 @@ namespace {
|
|||
class DefaultQuantParamsPass
|
||||
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
|
||||
public:
|
||||
explicit DefaultQuantParamsPass(double default_min, double default_max)
|
||||
: default_min_(default_min), default_max_(default_max) {}
|
||||
explicit DefaultQuantParamsPass(double default_min, double default_max,
|
||||
bool is_signed)
|
||||
: default_min_(default_min),
|
||||
default_max_(default_max),
|
||||
is_signed_(is_signed) {}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
|
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
|
|||
|
||||
double default_min_;
|
||||
double default_max_;
|
||||
bool is_signed_;
|
||||
quant::QuantParams default_quant_params_;
|
||||
};
|
||||
} // namespace
|
||||
|
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
|
|||
default_quant_params_ = quant::fakeQuantAttrsToType(
|
||||
builder.getUnknownLoc(),
|
||||
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
|
||||
builder.getF32Type());
|
||||
builder.getF32Type(), is_signed_);
|
||||
}
|
||||
return default_quant_params_;
|
||||
}
|
||||
|
||||
// Creates an instance of the default quant parameters pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
|
||||
double default_min, double default_max, bool is_signed) {
|
||||
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
|
||||
is_signed);
|
||||
}
|
||||
|
||||
// Registers this pass with default values; intended for testing only.
|
||||
|
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
|
|||
"tfl-default-quant",
|
||||
"Apply quantization with default quantization parameter", [] {
|
||||
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
|
||||
/*default_max=*/1.0);
|
||||
/*default_max=*/1.0,
|
||||
/*is_signed=*/false);
|
||||
});
|
||||
|
||||
} // namespace TFL
|
||||
|
|
|
@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
|
|||
// Creates an instance of the TensorFlow Lite dialect pass to add default
|
||||
// quantization parameters.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
|
||||
double default_min, double default_max);
|
||||
double default_min, double default_max, bool is_signed);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
|
||||
// tensor to sparse format.
|
||||
|
|
Loading…
Reference in New Issue