Expose disable_per_channel in the MLIR quantizer for experimental use by TFLite tooling

PiperOrigin-RevId: 310201122
Change-Id: I3fb460a182a23ae1cacb7f346d756a6e36eee748
A. Unique TensorFlower 2020-05-06 12:09:14 -07:00 committed by TensorFlower Gardener
parent dd7df2f89f
commit 5be613ef4f
10 changed files with 34 additions and 15 deletions


@@ -38,7 +38,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
+    const std::unordered_set<std::string>& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter) {
   // TODO(b/142502494): remove this restriction by improving the `emit_adaptor`
@@ -74,6 +75,7 @@ TfLiteStatus QuantizeModel(
   TFL::QuantizationSpecs quant_specs;
   quant_specs.inference_type = tensorflow::DT_QINT8;
   quant_specs.post_training_quantization = true;
+  quant_specs.disable_per_channel = disable_per_channel;

   bool emit_adaptor = false;
   auto input_tf_type = tflite::TflTypeToTfType(input_type);


@@ -31,7 +31,8 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
-    const std::unordered_set<std::string>& operator_names, bool fully_quantize,
+    const std::unordered_set<std::string>& operator_names,
+    bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
     tflite::ErrorReporter* error_reporter);
 }  // namespace lite


@@ -47,6 +47,7 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
   tflite::StderrReporter error_reporter;
   return mlir::lite::QuantizeModel(
       *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
+      /*disable_per_channel=*/false,
       /*fully_quantize=*/true, builder, &error_reporter);
 }


@@ -46,6 +46,12 @@ struct QuantizationSpecs {
   // post-training quantization. We need to deprecate the `weight_quantization`.
   bool post_training_quantization = false;

+  // When set to true, quantization will be done per-tensor. Currently, this
+  // option is only valid when the quantization parameters need to be created by
+  // scanning the constant content (post-training quantization or QAT without
+  // weight FakeQuant).
+  bool disable_per_channel = false;
+
   // The node type when the model is exported. Currently this is limited to
   // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
   // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used,

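For context on the new QuantizationSpecs field: per-channel quantization derives one scale per output channel by scanning each channel's constants, while disabling it falls back to a single scale for the whole tensor. A minimal numpy sketch (illustrative only, not part of this commit) of the two modes for symmetric int8 weights:

import numpy as np

# Toy convolution weights: 4 output channels, 8 weights per channel.
weights = np.random.randn(4, 8).astype(np.float32)

# Per-channel (default): one symmetric int8 scale per output channel.
per_channel_scales = np.abs(weights).max(axis=1) / 127.0  # shape (4,)

# Per-tensor (disable_per_channel = true): one scale for the whole constant.
per_tensor_scale = np.abs(weights).max() / 127.0  # scalar

# Quantize both ways; per-tensor rounding is coarser for channels whose
# range is far below the tensor-wide maximum.
q_per_channel = np.round(weights / per_channel_scales[:, None]).astype(np.int8)
q_per_tensor = np.round(weights / per_tensor_scale).astype(np.int8)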

@@ -273,8 +273,9 @@ void PrepareQuantizePass::runOnFunction() {

   // Finally, the quantization parameters can be propagated to the rest of the
   // values (tensors).
-  ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
-                                     GetOpQuantSpec);
+  ApplyQuantizationParamsPropagation(
+      func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
+      GetOpQuantSpec);

   ConvertMlirQuantOpsToTFLQuantOps(func);
 }


@@ -108,18 +108,21 @@ class ConverterError(Exception):
   pass


-def mlir_quantize(input_data_str):
+def mlir_quantize(input_data_str, disable_per_channel=False):
   """Quantize `input_data_str` with calibration results.

   Args:
     input_data_str: Input data in serialized form (e.g. a TFLITE model with
       calibration results).
+    disable_per_channel: Bool indicating whether to do per-channel or
+      per-tensor quantization

   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with floating-point
     inputs and outputs.
   """
-  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str)
+  return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
+                                                      disable_per_channel)


 def mlir_sparsify(input_data_str):

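A usage sketch for the updated Python hook; the file path and the calibrated-model input are assumptions for illustration, not part of this diff:

from tensorflow.lite.python import convert

# Hypothetical input: a serialized TFLite model that already contains
# calibration results.
with open("/tmp/calibrated_model.tflite", "rb") as f:
  calibrated_model = f.read()

# New experimental flag: force per-tensor quantization.
per_tensor_model = convert.mlir_quantize(
    calibrated_model, disable_per_channel=True)

# The default is unchanged, so existing callers keep per-channel behavior.
per_channel_model = convert.mlir_quantize(calibrated_model)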

@@ -43,9 +43,10 @@ def wrapped_get_potentially_supported_ops():
   return _pywrap_toco_api.TocoGetPotentiallySupportedOps()


-def wrapped_experimental_mlir_quantize(input_data_str):
+def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
   """Wraps experimental mlir quantize model."""
-  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str)
+  return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
+                                                        disable_per_channel)


 def wrapped_experimental_mlir_sparsify(input_data_str):


@@ -228,7 +228,8 @@ PyObject* TocoGetPotentiallySupportedOps() {
   return list;
 }

-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -251,8 +252,8 @@ PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize) {

   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, {}, fully_quantize, &builder,
-      error_reporter.get());
+      tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
+      fully_quantize, &builder, error_reporter.get());
   if (status != kTfLiteOk) {
     error_reporter->exception();


@@ -43,7 +43,8 @@ PyObject* TocoGetPotentiallySupportedOps();
 // Quantize the model with calibration data. Throw errors if `fully_quantize`
 // is specified by the calibration data are not sufficient to quantize the
 // model.
-PyObject* MlirQuantizeModel(PyObject* data, bool fully_quantize);
+PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
+                            bool fully_quantize);

 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.


@@ -56,11 +56,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
       )pbdoc");
   m.def(
       "ExperimentalMlirQuantizeModel",
-      [](py::object input_contents_txt_raw, bool fully_quantize) {
+      [](py::object input_contents_txt_raw, bool disable_per_channel,
+         bool fully_quantize) {
         return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
-            input_contents_txt_raw.ptr(), fully_quantize));
+            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
       },
-      py::arg("input_contents_txt_raw"), py::arg("fully_quantize") = true,
+      py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
+      py::arg("fully_quantize") = true,
       R"pbdoc(
         Returns a quantized model.
       )pbdoc");
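Because the binding declares named arguments with defaults, the flag is also reachable when calling the pybind11 module directly; a sketch, assuming the `_pywrap_toco_api` import path that wrap_toco.py uses and a `model_bytes` input that is not part of this diff:

from tensorflow.python import _pywrap_toco_api

# `model_bytes` is assumed to be a serialized, calibrated TFLite model.
quantized = _pywrap_toco_api.ExperimentalMlirQuantizeModel(
    model_bytes,
    disable_per_channel=True,  # new keyword, defaults to False
    fully_quantize=True)       # existing keyword, defaults to True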