From 5be613ef4f3ec2608deed653ab4815bbbcfbe7f8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 6 May 2020 12:09:14 -0700
Subject: [PATCH] Expose disable_per_channel in MLIR to be used experimentally by tflite tooling

PiperOrigin-RevId: 310201122
Change-Id: I3fb460a182a23ae1cacb7f346d756a6e36eee748
---
 .../mlir/lite/quantization/lite/quantize_model.cc         | 4 +++-
 .../compiler/mlir/lite/quantization/lite/quantize_model.h | 3 ++-
 .../compiler/mlir/lite/quantization/lite/tfl_quantizer.cc | 1 +
 .../compiler/mlir/lite/quantization/quantization_config.h | 6 ++++++
 .../compiler/mlir/lite/transforms/prepare_quantize.cc     | 5 +++--
 tensorflow/lite/python/convert.py                         | 7 +++++--
 tensorflow/lite/python/wrap_toco.py                       | 5 +++--
 tensorflow/lite/toco/python/toco_python_api.cc            | 7 ++++---
 tensorflow/lite/toco/python/toco_python_api.h             | 3 ++-
 tensorflow/python/lite/toco_python_api_wrapper.cc         | 8 +++++---
 10 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
index 9b49757fd3f..0ac3fa419bc 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -38,7 +38,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set& operator_names, bool fully_quantize,
+    const std::unordered_set& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter) {
   // TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
@@ -74,6 +75,7 @@ TfLiteStatus QuantizeModel(
   TFL::QuantizationSpecs quant_specs;
   quant_specs.inference_type = tensorflow::DT_QINT8;
   quant_specs.post_training_quantization = true;
+  quant_specs.disable_per_channel = disable_per_channel;
 
   bool emit_adaptor = false;
   auto input_tf_type = tflite::TflTypeToTfType(input_type);
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h
index 473e97e07df..578aa6438de 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h
@@ -31,7 +31,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set& operator_names, bool fully_quantize,
+    const std::unordered_set& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter);
 }  // namespace lite
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc
index 7530cdf008f..77bd87a3c03 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc
@@ -47,6 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
   tflite::StderrReporter error_reporter;
   return mlir::lite::QuantizeModel(
       *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
+      /*disable_per_channel=*/false,
       /*fully_quantize=*/true, builder, &error_reporter);
 }
 
diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
index 5b1c73e7887..cac1df9eee1 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h
@@ -46,6 +46,12 @@ struct QuantizationSpecs {
   // post-training quantization. We need to deprecate the `weight_quantization`.
   bool post_training_quantization = false;
 
+  // When set to true, quantization will be done per-tensor. Currently, this
+  // option is only valid when the quantization parameters need to be created by
+  // scanning the constant content (post-training quantization or QAT without
+  // weight FakeQuant).
+  bool disable_per_channel = false;
+
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
   // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used,
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
index 4f25e434fac..a9e10a485bf 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc
@@ -273,8 +273,9 @@ void PrepareQuantizePass::runOnFunction() {
 
   // Finally, the quantization parameters can be propagated to the rest of the
   // values (tensors).
-  ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
-                                     GetOpQuantSpec);
+  ApplyQuantizationParamsPropagation(
+      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
+      GetOpQuantSpec);
 
   ConvertMlirQuantOpsToTFLQuantOps(func);
 }
diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 70bea536433..ae70afd6962 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -108,18 +108,21 @@ class ConverterError(Exception):
   pass
 
 
-def mlir_quantize(input_data_str):
+def mlir_quantize(input_data_str, disable_per_channel=False):
   """Quantize `input_data_str` with calibration results.
 
   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
       calibration results).
+    disable_per_channel: Bool indicating whether to do per-channel or
+      per-tensor quantization
 
   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with
     floating-point inputs and outputs.
   """
-  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str)
+  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
+                                                      disable_per_channel)
 
 
 def mlir_sparsify(input_data_str):
diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py
index b8d3fc3c70b..3c1f98ff42d 100644
--- a/tensorflow/lite/python/wrap_toco.py
+++ b/tensorflow/lite/python/wrap_toco.py
@@ -43,9 +43,10 @@ def wrapped_get_potentially_supported_ops():
   return _pywrap_toco_api.TocoGetPotentiallySupportedOps()
 
 
-def wrapped_experimental_mlir_quantize(input_data_str):
+def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
   """Wraps experimental mlir quantize model."""
-  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str)
+  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
+                                                        disable_per_channel)
 
 
 def wrapped_experimental_mlir_sparsify(input_data_str):
diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc
index a19f5d26eed..aafd14f9da8 100644
--- a/tensorflow/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/lite/toco/python/toco_python_api.cc
@@ -228,7 +228,8 @@ PyObject* TocoGetPotentiallySupportedOps() {
   return list;
 }
 
-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -251,8 +252,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, {}, fully_quantize, &builder,
-      error_reporter.get());
+      tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
+      fully_quantize, &builder, error_reporter.get());
 
   if (status != kTfLiteOk) {
     error_reporter->exception();
diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h
index c7c7a3549a6..7afb097fd4a 100644
--- a/tensorflow/lite/toco/python/toco_python_api.h
+++ b/tensorflow/lite/toco/python/toco_python_api.h
@@ -43,7 +43,8 @@ PyObject* TocoGetPotentiallySupportedOps();
 // Quantize the model with calibration data. Throw errors if `fully_quantize`
 // is specified by the calibration data are not sufficient to quantize the
 // model.
-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize);
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize);
 
 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.
diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc
index 2c6cee5d54d..e6e0e111ec4 100644
--- a/tensorflow/python/lite/toco_python_api_wrapper.cc
+++ b/tensorflow/python/lite/toco_python_api_wrapper.cc
@@ -56,11 +56,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
       )pbdoc");
   m.def(
       "ExperimentalMlirQuantizeModel",
-      [](py::object input_contents_txt_raw, bool fully_quantize) {
+      [](py::object input_contents_txt_raw, bool disable_per_channel,
+         bool fully_quantize) {
         return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
-            input_contents_txt_raw.ptr(), fully_quantize));
+            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
       },
-      py::arg("input_contents_txt_raw"), py::arg("fully_quantize") = true,
+      py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
+      py::arg("fully_quantize") = true,
       R"pbdoc(
         Returns a quantized model.
       )pbdoc");
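
Note for reviewers: a minimal sketch of how the newly exposed flag is reached from Python once this patch is applied. It calls `convert.mlir_quantize`, the internal experimental helper modified above (not a public API); the model file names are hypothetical, and the input is assumed to be a TFLite flatbuffer that already carries calibration results.

    # Sketch only: exercise the disable_per_channel flag through the internal
    # helper touched by this patch. "calibrated.tflite" is a hypothetical model
    # that already contains calibration statistics.
    from tensorflow.lite.python import convert

    with open("calibrated.tflite", "rb") as f:
      calibrated_model = f.read()

    # Default behavior: per-channel quantization (disable_per_channel=False).
    per_channel_model = convert.mlir_quantize(calibrated_model)

    # New experimental behavior: force per-tensor quantization.
    per_tensor_model = convert.mlir_quantize(calibrated_model,
                                             disable_per_channel=True)

    with open("quantized_per_tensor.tflite", "wb") as f:
      f.write(per_tensor_model)

Both calls route through `wrap_toco.wrapped_experimental_mlir_quantize` to the pybind `ExperimentalMlirQuantizeModel` entry point, which now forwards `disable_per_channel` into `QuantizationSpecs` as shown in the diff.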