Expose inference type in the mlir quantizer

This is to prepare the 16-bit activation quantization release. The data type
specified by this flag is only applied to the activations.

PiperOrigin-RevId: 311478782
Change-Id: I5f63f0508011cc0b1b47a0debb35c17d3284eae9
Author: Feng Liu, 2020-05-13 23:48:03 -07:00 (committed by TensorFlower Gardener)
Parent: 03f3e8153c
Commit: d33cb73389
11 changed files with 84 additions and 19 deletions
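For context, the new `inference_type` flag is exercised end to end by the
`testCalibrateAndQuantizeBuiltinInt16` test added below. The following is a
minimal sketch of the intended Python-level usage; `toy_model` and
`calibration_gen` are illustrative placeholders, not part of this change:

import numpy as np
import tensorflow as tf

from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.toco import types_pb2 as _types_pb2


# Illustrative stand-ins for the model and representative data being quantized.
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
def toy_model(x):
  return tf.matmul(x, tf.constant(np.random.rand(4, 2).astype(np.float32)))


def calibration_gen():
  for _ in range(10):
    yield [np.random.rand(1, 4).astype(np.float32)]


# Calibrate only; the MLIR quantizer performs the actual quantization below.
converter = lite.TFLiteConverterV2.from_concrete_functions(
    [toy_model.get_concrete_function()])
converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = calibration_gen
converter._experimental_calibrate_only = True
calibrated_tflite = converter.convert()

# `inference_type` selects the activation data type only; QUANTIZED_INT16
# requests the 16-bit activation path introduced by this change.
quantized_tflite = mlir_quantize(
    calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)

The default remains `_types_pb2.INT8`, so existing callers are unaffected.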


@@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
     TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
-    TFL_TensorOfOrNone<[F32, I32]>:$bias,
+    TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
     I32Attr:$dilation_h_factor,
     I32Attr:$dilation_w_factor,
     TFL_AFAttr:$fused_activation_function,
@@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
     I32Attr:$stride_w
   );

-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);

   let hasOptions = 0b1;
 }


@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/lite/schema/schema_generated.h"

 namespace mlir {
 namespace lite {
@@ -38,6 +39,7 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
+    const tflite::TensorType& inference_type,
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
@@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
   // Apply quantization passes
   PassManager pm(module->getContext());
   TFL::QuantizationSpecs quant_specs;
-  quant_specs.inference_type = tensorflow::DT_QINT8;
+  quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
   quant_specs.post_training_quantization = true;
   quant_specs.disable_per_channel = disable_per_channel;
@@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
   auto input_tf_type = tflite::TflTypeToTfType(input_type);
   if (input_tf_type == tensorflow::DT_FLOAT) {
     emit_adaptor = true;
-  } else if (input_tf_type == tensorflow::DT_UINT8) {
-    quant_specs.inference_type = tensorflow::DT_QUINT8;
+  } else if (input_tf_type == tensorflow::DT_UINT8 ||
+             input_tf_type == tensorflow::DT_INT8 ||
+             input_tf_type == tensorflow::DT_INT16) {
+    quant_specs.inference_type = input_tf_type;
   }

   pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));


@@ -26,11 +26,13 @@ namespace mlir {
 namespace lite {

 // Quantize the `input_model` and write the result to a flatbuffer `builder`.
-// The `input_type` and `output_type` can be float32/qint8/int8.
+// The `input_type`, `output_type` and `inference_type` can be
+// float32/qint8/int8/int16.
 // Return partially quantized model if `fully_quantize` is false.
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
+    const tflite::TensorType& inference_type,
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,


@@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
   tflite::StderrReporter error_reporter;

   return mlir::lite::QuantizeModel(
-      *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
+      *model, tflite::TensorType_INT8, tflite::TensorType_INT8,
+      tflite::TensorType_INT8, {},
       /*disable_per_channel=*/false,
       /*fully_quantize=*/true, builder, &error_reporter);
 }


@@ -108,7 +108,8 @@ class ConverterError(Exception):
   pass


-def mlir_quantize(input_data_str, disable_per_channel=False):
+def mlir_quantize(input_data_str, disable_per_channel=False,
+                  inference_type=_types_pb2.INT8):
   """Quantize `input_data_str` with calibration results.

   Args:
@@ -116,13 +117,15 @@ def mlir_quantize(input_data_str, disable_per_channel=False):
       calibration results).
     disable_per_channel: Bool indicating whether to do per-channel or
       per-tensor quantization
+    inference_type: Data type for the activations. The default value is int8.

   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with floating-point
     inputs and outputs.
   """
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
-                                                      disable_per_channel)
+                                                      disable_per_channel,
+                                                      inference_type)


 def mlir_sparsify(input_data_str):


@@ -29,7 +29,9 @@ import tensorflow as tf

 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_v2_test_util
+from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
+from tensorflow.lite.toco import types_pb2 as _types_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import recurrent
@@ -204,6 +206,40 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))

+  def testCalibrateAndQuantizeBuiltinInt16(self):
+    func, calibration_gen = self._getCalibrationQuantizeModel()
+    # Convert float model.
+    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    # TODO(b/156309549): We should add INT16 to the builtin types.
+    converter.target_spec.supported_ops = [
+        lite.OpsSet.TFLITE_BUILTINS_INT8
+    ]
+    converter.representative_dataset = calibration_gen
+    converter._experimental_calibrate_only = True
+    calibrated_tflite = converter.convert()
+    quantized_tflite = mlir_quantize(calibrated_tflite,
+                                     inference_type=_types_pb2.QUANTIZED_INT16)
+
+    self.assertTrue(quantized_tflite)
+
+    # The default input and output types should be float.
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
   def _getTrainingTimeQuantizedModel(self):

     class QLinear(tf.keras.layers.Layer):


@@ -43,10 +43,12 @@ def wrapped_get_potentially_supported_ops():
   return _pywrap_toco_api.TocoGetPotentiallySupportedOps()


-def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
+def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
+                                       inference_type):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
-                                                        disable_per_channel)
+                                                        disable_per_channel,
+                                                        inference_type)


 def wrapped_experimental_mlir_sparsify(input_data_str):


@@ -54,6 +54,7 @@ cc_library(
         "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
         "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
         "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
+        "//tensorflow/lite/toco:types_proto_cc",
     ] + select({
         # This is required when running `tflite_convert` from `bazel`.
         # It requires to link with TensorFlow Ops to get the op definitions.


@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/lite/toco/toco_tooling.h"
 #include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
+#include "tensorflow/lite/toco/types.pb.h"

 namespace toco {
@@ -229,7 +230,7 @@ PyObject* TocoGetPotentiallySupportedOps() {
 }

 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize) {
+                            bool fully_quantize, int inference_type) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -249,11 +250,25 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
   auto tflite_model = absl::make_unique<tflite::ModelT>();
   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

+  tflite::TensorType inference_tensor_type;
+  switch (inference_type) {
+    case toco::IODataType::QUANTIZED_INT16:
+      inference_tensor_type = tflite::TensorType_INT16;
+      break;
+    case toco::IODataType::QUANTIZED_UINT8:
+      inference_tensor_type = tflite::TensorType_UINT8;
+      break;
+    case toco::IODataType::INT8:
+      inference_tensor_type = tflite::TensorType_INT8;
+      break;
+    default:
+      return nullptr;
+  }
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
-      fully_quantize, &builder, error_reporter.get());
+      tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
+      disable_per_channel, fully_quantize, &builder, error_reporter.get());
   if (status != kTfLiteOk) {
     error_reporter->exception();


@@ -44,7 +44,7 @@ PyObject* TocoGetPotentiallySupportedOps();
 // is specified by the calibration data are not sufficient to quantize the
 // model.
 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize);
+                            bool fully_quantize, int inference_type);

 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.


@@ -57,12 +57,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
   m.def(
       "ExperimentalMlirQuantizeModel",
       [](py::object input_contents_txt_raw, bool disable_per_channel,
-         bool fully_quantize) {
+         bool fully_quantize, int inference_type) {
         return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
-            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
+            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize,
+            inference_type));
       },
       py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
-      py::arg("fully_quantize") = true,
+      py::arg("fully_quantize") = true, py::arg("inference_type") = 9,
       R"pbdoc(
        Returns a quantized model.
      )pbdoc");