Merge pull request #39171 from lgeiger:default-quant-int8

PiperOrigin-RevId: 311352383
Change-Id: I3527375b386072dc0a3893f884f0239ca714e66b
This commit is contained in:
TensorFlower Gardener 2020-05-13 10:16:29 -07:00
commit eaacee1738
4 changed files with 17 additions and 10 deletions

View File

@ -90,7 +90,7 @@ struct QuantizationSpecs {
bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type.
-bool IsSignedInferenceType() {
+bool IsSignedInferenceType() const {
switch (inference_type) {
case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16:
@ -102,7 +102,7 @@ struct QuantizationSpecs {
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
-int64_t GetQuantizationTypeWidth() {
+int64_t GetQuantizationTypeWidth() const {
switch (inference_type) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:

View File

@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
-quant_specs.default_ranges.second.getValueOr(0.0)));
+quant_specs.default_ranges.second.getValueOr(0.0),
+quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));

View File

@ -46,8 +46,11 @@ namespace {
class DefaultQuantParamsPass
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
public:
-explicit DefaultQuantParamsPass(double default_min, double default_max)
-: default_min_(default_min), default_max_(default_max) {}
+explicit DefaultQuantParamsPass(double default_min, double default_max,
+bool is_signed)
+: default_min_(default_min),
+default_max_(default_max),
+is_signed_(is_signed) {}
void runOnFunction() override;
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
double default_min_;
double default_max_;
+bool is_signed_;
quant::QuantParams default_quant_params_;
};
} // namespace
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
-builder.getF32Type());
+builder.getF32Type(), is_signed_);
}
return default_quant_params_;
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
-double default_min, double default_max) {
-return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
+double default_min, double default_max, bool is_signed) {
+return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
+is_signed);
}
// Registers this pass with default values, only for test
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant",
"Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
-/*default_max=*/1.0);
+/*default_max=*/1.0,
+/*is_signed=*/false);
});
} // namespace TFL

View File

@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
-double default_min, double default_max);
+double default_min, double default_max, bool is_signed);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format.