Merge pull request #39171 from lgeiger:default-quant-int8

PiperOrigin-RevId: 311352383
Change-Id: I3527375b386072dc0a3893f884f0239ca714e66b
This commit is contained in:
TensorFlower Gardener 2020-05-13 10:16:29 -07:00
commit eaacee1738
4 changed files with 17 additions and 10 deletions

View File

@ -90,7 +90,7 @@ struct QuantizationSpecs {
bool RunWeightQuantization() const { return weight_quantization; } bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type. // Whether this inference type represents a signed storage type.
bool IsSignedInferenceType() { bool IsSignedInferenceType() const {
switch (inference_type) { switch (inference_type) {
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16: case tensorflow::DT_QUINT16:
@ -102,7 +102,7 @@ struct QuantizationSpecs {
// Gets the width of this quantization type. Returns 0 if it isn't a // Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type. // quantization type.
int64_t GetQuantizationTypeWidth() { int64_t GetQuantizationTypeWidth() const {
switch (inference_type) { switch (inference_type) {
case tensorflow::DT_QINT8: case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT8:

View File

@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.default_ranges.second.hasValue()) { quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0), quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0))); quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass()); pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass( pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));

View File

@ -46,8 +46,11 @@ namespace {
class DefaultQuantParamsPass class DefaultQuantParamsPass
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> { : public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
public: public:
explicit DefaultQuantParamsPass(double default_min, double default_max) explicit DefaultQuantParamsPass(double default_min, double default_max,
: default_min_(default_min), default_max_(default_max) {} bool is_signed)
: default_min_(default_min),
default_max_(default_max),
is_signed_(is_signed) {}
void runOnFunction() override; void runOnFunction() override;
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
double default_min_; double default_min_;
double default_max_; double default_max_;
bool is_signed_;
quant::QuantParams default_quant_params_; quant::QuantParams default_quant_params_;
}; };
} // namespace } // namespace
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
default_quant_params_ = quant::fakeQuantAttrsToType( default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(), builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false, /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
builder.getF32Type()); builder.getF32Type(), is_signed_);
} }
return default_quant_params_; return default_quant_params_;
} }
// Creates an instance of the default quant parameters pass. // Creates an instance of the default quant parameters pass.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass( std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) { double default_min, double default_max, bool is_signed) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max); return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
is_signed);
} }
// Registers this pass with default values, only for test // Registers this pass with default values, only for test
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant", "tfl-default-quant",
"Apply quantization with default quantization parameter", [] { "Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0, return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
/*default_max=*/1.0); /*default_max=*/1.0,
/*is_signed=*/false);
}); });
} // namespace TFL } // namespace TFL

View File

@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance of the TensorFlow Lite dialect pass to add default // Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters. // quantization parameters.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass( std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max); double default_min, double default_max, bool is_signed);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense // Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format. // tensor to sparse format.