Expose inference type in the mlir quantizer

This is to prepare the 16 bits activation quantization release. The data type
specified by this flag is only applied on the activations.

PiperOrigin-RevId: 311478782
Change-Id: I5f63f0508011cc0b1b47a0debb35c17d3284eae9
This commit is contained in:
Feng Liu 2020-05-13 23:48:03 -07:00 committed by TensorFlower Gardener
parent 03f3e8153c
commit d33cb73389
11 changed files with 84 additions and 19 deletions

View File

@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
}];
let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
TFL_TensorOfOrNone<[F32, I32]>:$bias,
TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function,
@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
I32Attr:$stride_w
);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
let hasOptions = 0b1;
}

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
namespace lite {
@ -38,6 +39,7 @@ namespace lite {
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes
PassManager pm(module->getContext());
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true;
quant_specs.disable_per_channel = disable_per_channel;
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) {
quant_specs.inference_type = tensorflow::DT_QUINT8;
} else if (input_tf_type == tensorflow::DT_UINT8 ||
input_tf_type == tensorflow::DT_INT8 ||
input_tf_type == tensorflow::DT_INT16) {
quant_specs.inference_type = input_tf_type;
}
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));

View File

@ -26,11 +26,13 @@ namespace mlir {
namespace lite {
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8.
// The `input_type`, `output_type` and `inference_type` can be
// float32/qint8/int8/int16.
// Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,

View File

@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
tflite::StderrReporter error_reporter;
return mlir::lite::QuantizeModel(
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
tflite::TensorType_INT8, {},
/*disable_per_channel=*/false,
/*fully_quantize=*/true, builder, &error_reporter);
}

View File

@ -108,7 +108,8 @@ class ConverterError(Exception):
pass
def mlir_quantize(input_data_str, disable_per_channel=False):
def mlir_quantize(input_data_str, disable_per_channel=False,
inference_type=_types_pb2.INT8):
"""Quantize `input_data_str` with calibration results.
Args:
@ -116,13 +117,15 @@ def mlir_quantize(input_data_str, disable_per_channel=False):
calibration results).
disable_per_channel: Bool indicating whether to do per-channel or
per-tensor quantization
inference_type: Data type for the activations. The default value is int8.
Returns:
Quantized model in serialized form (e.g. a TFLITE model) with floating-point
inputs and outputs.
"""
return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
disable_per_channel)
disable_per_channel,
inference_type)
def mlir_sparsify(input_data_str):

View File

@ -29,7 +29,9 @@ import tensorflow as tf
from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import recurrent
@ -204,6 +206,40 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def testCalibrateAndQuantizeBuiltinInt16(self):
func, calibration_gen = self._getCalibrationQuantizeModel()
# Convert float model.
float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
# TODO(b/156309549): We should add INT16 to the builtin types.
converter.target_spec.supported_ops = [
lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.representative_dataset = calibration_gen
converter._experimental_calibrate_only = True
calibrated_tflite = converter.convert()
quantized_tflite = mlir_quantize(calibrated_tflite,
inference_type=_types_pb2.QUANTIZED_INT16)
self.assertTrue(quantized_tflite)
# The default input and output types should be float.
interpreter = Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual(np.float32, input_details[0]['dtype'])
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual(np.float32, output_details[0]['dtype'])
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def _getTrainingTimeQuantizedModel(self):
class QLinear(tf.keras.layers.Layer):

View File

@ -43,10 +43,12 @@ def wrapped_get_potentially_supported_ops():
return _pywrap_toco_api.TocoGetPotentiallySupportedOps()
def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
inference_type):
"""Wraps experimental mlir quantize model."""
return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
disable_per_channel)
disable_per_channel,
inference_type)
def wrapped_experimental_mlir_sparsify(input_data_str):

View File

@ -54,6 +54,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
"//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
"//tensorflow/lite/toco:types_proto_cc",
] + select({
# This is required when running `tflite_convert` from `bazel`.
# It requires to link with TensorFlow Ops to get the op definitions.

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/lite/toco/toco_tooling.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/lite/toco/types.pb.h"
namespace toco {
@ -229,7 +230,7 @@ PyObject* TocoGetPotentiallySupportedOps() {
}
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
bool fully_quantize) {
bool fully_quantize, int inference_type) {
using tflite::interpreter_wrapper::PythonErrorReporter;
char* buf = nullptr;
Py_ssize_t length;
@ -249,11 +250,25 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
auto tflite_model = absl::make_unique<tflite::ModelT>();
model->GetModel()->UnPackTo(tflite_model.get(), nullptr);
tflite::TensorType inference_tensor_type;
switch (inference_type) {
case toco::IODataType::QUANTIZED_INT16:
inference_tensor_type = tflite::TensorType_INT16;
break;
case toco::IODataType::QUANTIZED_UINT8:
inference_tensor_type = tflite::TensorType_UINT8;
break;
case toco::IODataType::INT8:
inference_tensor_type = tflite::TensorType_INT8;
break;
default:
return nullptr;
}
flatbuffers::FlatBufferBuilder builder;
auto status = mlir::lite::QuantizeModel(
*tflite_model, tflite::TensorType::TensorType_FLOAT32,
tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
fully_quantize, &builder, error_reporter.get());
tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
disable_per_channel, fully_quantize, &builder, error_reporter.get());
if (status != kTfLiteOk) {
error_reporter->exception();

View File

@ -44,7 +44,7 @@ PyObject* TocoGetPotentiallySupportedOps();
// is specified by the calibration data are not sufficient to quantize the
// model.
PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
bool fully_quantize);
bool fully_quantize, int inference_type);
// Sparsifies model to encode sparse tensors with proper format. Throws error if
// sparsification fails.

View File

@ -57,12 +57,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
m.def(
"ExperimentalMlirQuantizeModel",
[](py::object input_contents_txt_raw, bool disable_per_channel,
bool fully_quantize) {
bool fully_quantize, int inference_type) {
return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize,
inference_type));
},
py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
py::arg("fully_quantize") = true,
py::arg("fully_quantize") = true, py::arg("inference_type") = 9,
R"pbdoc(
Returns a quantized model.
)pbdoc");