Expose disable_per_channel in MLIR to be used experimentally by tflite tooling
PiperOrigin-RevId: 310201122
Change-Id: I3fb460a182a23ae1cacb7f346d756a6e36eee748
commit 5be613ef4f (parent dd7df2f89f)
Changed paths: tensorflow/compiler/mlir/lite (quantization, transforms); tensorflow/lite (python/lite)
@@ -38,7 +38,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
+    const std::unordered_set<std::string>& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter) {
   // TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
@@ -74,6 +75,7 @@ TfLiteStatus QuantizeModel(
   TFL::QuantizationSpecs quant_specs;
   quant_specs.inference_type = tensorflow::DT_QINT8;
   quant_specs.post_training_quantization = true;
+  quant_specs.disable_per_channel = disable_per_channel;

   bool emit_adaptor = false;
   auto input_tf_type = tflite::TflTypeToTfType(input_type);
@@ -31,7 +31,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
+    const std::unordered_set<std::string>& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter);
 }  // namespace lite
@@ -47,6 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
   tflite::StderrReporter error_reporter;
   return mlir::lite::QuantizeModel(
       *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
+      /*disable_per_channel=*/false,
       /*fully_quantize=*/true, builder, &error_reporter);
 }

@@ -46,6 +46,12 @@ struct QuantizationSpecs {
   // post-training quantization. We need to deprecate the `weight_quantization`.
   bool post_training_quantization = false;

+  // When set to true, quantization will be done per-tensor. Currently, this
+  // option is only valid when the quantization parameters need to be created by
+  // scanning the constant content (post-training quantization or QAT without
+  // weight FakeQuant).
+  bool disable_per_channel = false;
+
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
   // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is used,
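For context on what this flag trades off: per-channel quantization gives each output channel of a weight tensor its own scale, while per-tensor quantization uses a single scale for the whole tensor, which costs precision when channel ranges differ widely. A minimal illustrative sketch in Python (not the converter's actual scale computation; symmetric int8 scaling and the names below are assumptions):

import numpy as np

# Toy conv weights: two output channels with very different ranges.
w = np.array([[0.01, -0.02, 0.015],
              [1.50, -2.00, 1.750]])

# Per-channel: one symmetric int8 scale per output channel (axis 0).
per_channel_scales = np.abs(w).max(axis=1) / 127.0

# Per-tensor (disable_per_channel = true): one scale for the whole tensor.
per_tensor_scale = np.abs(w).max() / 127.0

# Round-trip channel 0 both ways; the small-range channel loses far more
# precision under the shared per-tensor scale.
q_pc = np.round(w[0] / per_channel_scales[0]) * per_channel_scales[0]
q_pt = np.round(w[0] / per_tensor_scale) * per_tensor_scale
print(np.abs(w[0] - q_pc).max())  # ~8e-5
print(np.abs(w[0] - q_pt).max())  # ~6e-3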
@@ -273,8 +273,9 @@ void PrepareQuantizePass::runOnFunction() {

   // Finally, the quantization parameters can be propagated to the rest of the
   // values (tensors).
-  ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
-                                     GetOpQuantSpec);
+  ApplyQuantizationParamsPropagation(
+      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
+      GetOpQuantSpec);

   ConvertMlirQuantOpsToTFLQuantOps(func);
 }
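The new condition ORs the pass's existing disable_per_channel option with the QuantizationSpecs field, so per-tensor mode applies if either is set. A toy Python restatement of that condition (names are illustrative, not from the codebase):

def effective_disable_per_channel(pass_option, spec_field):
  """Per-tensor quantization applies if either flag requests it."""
  return pass_option or spec_field

assert effective_disable_per_channel(False, True)       # new QuantizationSpecs path
assert effective_disable_per_channel(True, False)       # existing pass option path
assert not effective_disable_per_channel(False, False)  # default: per-channel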
@@ -108,18 +108,21 @@ class ConverterError(Exception):
   pass


-def mlir_quantize(input_data_str):
+def mlir_quantize(input_data_str, disable_per_channel=False):
   """Quantize `input_data_str` with calibration results.

   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
       calibration results).
+    disable_per_channel: Bool indicating whether to do per-channel or
+      per-tensor quantization

   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with floating-point
     inputs and outputs.
   """
-  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str)
+  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
+                                                      disable_per_channel)


 def mlir_sparsify(input_data_str):
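A sketch of how tooling can drive the new keyword (hedged: calibrated_model is a placeholder for a serialized TFLite flatbuffer that already contains calibration results, and this private module is assumed importable as shown):

from tensorflow.lite.python import lite

# Default behavior is unchanged: per-channel quantization.
quantized = lite.mlir_quantize(calibrated_model)

# Experimental: force per-tensor quantization instead.
quantized_per_tensor = lite.mlir_quantize(
    calibrated_model, disable_per_channel=True)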
@@ -43,9 +43,10 @@ def wrapped_get_potentially_supported_ops():
   return _pywrap_toco_api.TocoGetPotentiallySupportedOps()


-def wrapped_experimental_mlir_quantize(input_data_str):
+def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
   """Wraps experimental mlir quantize model."""
-  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str)
+  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
+                                                        disable_per_channel)


 def wrapped_experimental_mlir_sparsify(input_data_str):
@@ -228,7 +228,8 @@ PyObject* TocoGetPotentiallySupportedOps() {
   return list;
 }

-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -251,8 +252,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, {}, fully_quantize, &builder,
-      error_reporter.get());
+      tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
+      fully_quantize, &builder, error_reporter.get());

   if (status != kTfLiteOk) {
     error_reporter->exception();
@@ -43,7 +43,8 @@ PyObject* TocoGetPotentiallySupportedOps();
 // Quantize the model with calibration data. Throw errors if `fully_quantize`
 // is specified but the calibration data are not sufficient to quantize the
 // model.
-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize);
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize);

 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.
@@ -56,11 +56,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
       )pbdoc");
   m.def(
       "ExperimentalMlirQuantizeModel",
-      [](py::object input_contents_txt_raw, bool fully_quantize) {
+      [](py::object input_contents_txt_raw, bool disable_per_channel,
+         bool fully_quantize) {
         return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
-            input_contents_txt_raw.ptr(), fully_quantize));
+            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
       },
-      py::arg("input_contents_txt_raw"), py::arg("fully_quantize") = true,
+      py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
+      py::arg("fully_quantize") = true,
       R"pbdoc(
         Returns a quantized model.
       )pbdoc");