Expose inference type in the mlir quantizer
This prepares the 16-bit activation quantization release. The data type specified by this flag is applied only to the activations.
PiperOrigin-RevId: 311478782
Change-Id: I5f63f0508011cc0b1b47a0debb35c17d3284eae9
parent 03f3e8153c · commit d33cb73389
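
For context, a minimal sketch of how the new flag is exercised from Python, modeled on the test added in this change. The concrete function and representative dataset are placeholders (build_model_and_calibration_data is hypothetical), and the calibrate-only path relies on the experimental _experimental_calibrate_only converter attribute used in the test below:

from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.toco import types_pb2 as _types_pb2

# Placeholder: any concrete function plus a representative dataset generator.
func, calibration_gen = build_model_and_calibration_data()

# Calibrate only: record activation ranges without quantizing yet.
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = calibration_gen
converter._experimental_calibrate_only = True
calibrated_tflite = converter.convert()

# Quantize with 16-bit activations; the flag applies only to activations,
# and model inputs/outputs stay float by default.
quantized_tflite = mlir_quantize(
    calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)
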
@@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
+    ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
     TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
-    TFL_TensorOfOrNone<[F32, I32]>:$bias,
+    TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
     I32Attr:$dilation_h_factor,
     I32Attr:$dilation_w_factor,
     TFL_AFAttr:$fused_activation_function,
@@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
     I32Attr:$stride_w
   );

-  let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
+  let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);

   let hasOptions = 0b1;
 }
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/lite/schema/schema_generated.h"

 namespace mlir {
 namespace lite {
@@ -38,6 +39,7 @@ namespace lite {
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
+    const tflite::TensorType& inference_type,
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
@@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
   // Apply quantization passes
   PassManager pm(module->getContext());
   TFL::QuantizationSpecs quant_specs;
-  quant_specs.inference_type = tensorflow::DT_QINT8;
+  quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
   quant_specs.post_training_quantization = true;
   quant_specs.disable_per_channel = disable_per_channel;

@@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
   auto input_tf_type = tflite::TflTypeToTfType(input_type);
   if (input_tf_type == tensorflow::DT_FLOAT) {
     emit_adaptor = true;
-  } else if (input_tf_type == tensorflow::DT_UINT8) {
-    quant_specs.inference_type = tensorflow::DT_QUINT8;
+  } else if (input_tf_type == tensorflow::DT_UINT8 ||
+             input_tf_type == tensorflow::DT_INT8 ||
+             input_tf_type == tensorflow::DT_INT16) {
+    quant_specs.inference_type = input_tf_type;
   }

   pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));
@@ -26,11 +26,13 @@ namespace mlir {
 namespace lite {

 // Quantize the `input_model` and write the result to a flatbuffer `builder`.
-// The `input_type` and `output_type` can be float32/qint8/int8.
+// The `input_type`, `output_type` and `inference_type` can be
+// float32/qint8/int8/int16.
 // Return partially quantized model if `fully_quantize` is false.
 TfLiteStatus QuantizeModel(
     const tflite::ModelT& input_model, const tflite::TensorType& input_type,
     const tflite::TensorType& output_type,
+    const tflite::TensorType& inference_type,
     const std::unordered_set<std::string>& operator_names,
     bool disable_per_channel, bool fully_quantize,
     flatbuffers::FlatBufferBuilder* builder,
@@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,

   tflite::StderrReporter error_reporter;
   return mlir::lite::QuantizeModel(
-      *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
+      *model, tflite::TensorType_INT8, tflite::TensorType_INT8,
+      tflite::TensorType_INT8, {},
       /*disable_per_channel=*/false,
       /*fully_quantize=*/true, builder, &error_reporter);
 }
@@ -108,7 +108,8 @@ class ConverterError(Exception):
   pass


-def mlir_quantize(input_data_str, disable_per_channel=False):
+def mlir_quantize(input_data_str, disable_per_channel=False,
+                  inference_type=_types_pb2.INT8):
   """Quantize `input_data_str` with calibration results.

   Args:
@@ -116,13 +117,15 @@ def mlir_quantize(input_data_str, disable_per_channel=False):
       calibration results).
     disable_per_channel: Bool indicating whether to do per-channel or
       per-tensor quantization
+    inference_type: Data type for the activations. The default value is int8.

   Returns:
     Quantized model in serialized form (e.g. a TFLITE model) with floating-point
     inputs and outputs.
   """
   return wrap_toco.wrapped_experimental_mlir_quantize(input_data_str,
-                                                      disable_per_channel)
+                                                      disable_per_channel,
+                                                      inference_type)


 def mlir_sparsify(input_data_str):
@@ -29,7 +29,9 @@ import tensorflow as tf

 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_v2_test_util
+from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
+from tensorflow.lite.toco import types_pb2 as _types_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.layers import recurrent
@@ -204,6 +206,40 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))

+  def testCalibrateAndQuantizeBuiltinInt16(self):
+    func, calibration_gen = self._getCalibrationQuantizeModel()
+
+    # Convert float model.
+    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
+    # TODO(b/156309549): We should add INT16 to the builtin types.
+    converter.target_spec.supported_ops = [
+        lite.OpsSet.TFLITE_BUILTINS_INT8
+    ]
+    converter.representative_dataset = calibration_gen
+    converter._experimental_calibrate_only = True
+    calibrated_tflite = converter.convert()
+    quantized_tflite = mlir_quantize(calibrated_tflite,
+                                     inference_type=_types_pb2.QUANTIZED_INT16)
+
+    self.assertTrue(quantized_tflite)
+
+    # The default input and output types should be float.
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
   def _getTrainingTimeQuantizedModel(self):

     class QLinear(tf.keras.layers.Layer):
@@ -43,10 +43,12 @@ def wrapped_get_potentially_supported_ops():
   return _pywrap_toco_api.TocoGetPotentiallySupportedOps()


-def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel):
+def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel,
+                                       inference_type):
   """Wraps experimental mlir quantize model."""
   return _pywrap_toco_api.ExperimentalMlirQuantizeModel(input_data_str,
-                                                        disable_per_channel)
+                                                        disable_per_channel,
+                                                        inference_type)


 def wrapped_experimental_mlir_sparsify(input_data_str):
@@ -54,6 +54,7 @@ cc_library(
         "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer",
         "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
         "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
+        "//tensorflow/lite/toco:types_proto_cc",
     ] + select({
         # This is required when running `tflite_convert` from `bazel`.
         # It requires to link with TensorFlow Ops to get the op definitions.
@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/lite/toco/toco_tooling.h"
 #include "tensorflow/lite/toco/toco_types.h"
 #include "tensorflow/lite/toco/tooling_util.h"
+#include "tensorflow/lite/toco/types.pb.h"

 namespace toco {

@@ -229,7 +230,7 @@ PyObject* TocoGetPotentiallySupportedOps() {
 }

 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize) {
+                            bool fully_quantize, int inference_type) {
   using tflite::interpreter_wrapper::PythonErrorReporter;
   char* buf = nullptr;
   Py_ssize_t length;
@@ -249,11 +250,25 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
   auto tflite_model = absl::make_unique<tflite::ModelT>();
   model->GetModel()->UnPackTo(tflite_model.get(), nullptr);

+  tflite::TensorType inference_tensor_type;
+  switch (inference_type) {
+    case toco::IODataType::QUANTIZED_INT16:
+      inference_tensor_type = tflite::TensorType_INT16;
+      break;
+    case toco::IODataType::QUANTIZED_UINT8:
+      inference_tensor_type = tflite::TensorType_UINT8;
+      break;
+    case toco::IODataType::INT8:
+      inference_tensor_type = tflite::TensorType_INT8;
+      break;
+    default:
+      return nullptr;
+  }
   flatbuffers::FlatBufferBuilder builder;
   auto status = mlir::lite::QuantizeModel(
       *tflite_model, tflite::TensorType::TensorType_FLOAT32,
-      tflite::TensorType::TensorType_FLOAT32, {}, disable_per_channel,
-      fully_quantize, &builder, error_reporter.get());
+      tflite::TensorType::TensorType_FLOAT32, inference_tensor_type, {},
+      disable_per_channel, fully_quantize, &builder, error_reporter.get());

   if (status != kTfLiteOk) {
     error_reporter->exception();
@@ -44,7 +44,7 @@ PyObject* TocoGetPotentiallySupportedOps();
 // is specified by the calibration data are not sufficient to quantize the
 // model.
 PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel,
-                            bool fully_quantize);
+                            bool fully_quantize, int inference_type);

 // Sparsifies model to encode sparse tensors with proper format. Throws error if
 // sparsification fails.
@@ -57,12 +57,13 @@ PYBIND11_MODULE(_pywrap_toco_api, m) {
   m.def(
       "ExperimentalMlirQuantizeModel",
       [](py::object input_contents_txt_raw, bool disable_per_channel,
-         bool fully_quantize) {
+         bool fully_quantize, int inference_type) {
        return tensorflow::PyoOrThrow(toco::MlirQuantizeModel(
-            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize));
+            input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize,
+            inference_type));
       },
       py::arg("input_contents_txt_raw"), py::arg("disable_per_channel") = false,
-      py::arg("fully_quantize") = true,
+      py::arg("fully_quantize") = true, py::arg("inference_type") = 9,
       R"pbdoc(
        Returns a quantized model.
      )pbdoc");
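
A note on the pybind default above: py::arg("inference_type") = 9 presumably corresponds to INT8 in toco's IODataType enum, i.e. the same default as _types_pb2.INT8 on the Python side. A small Python-side sketch of the accepted values and the TFLite tensor type each one selects, mirroring the C++ switch in MlirQuantizeModel (enum names assumed to come from the toco types proto):

from tensorflow.lite.toco import types_pb2 as _types_pb2

# Accepted inference types and the activation tensor type each one selects.
# Any other value makes MlirQuantizeModel return an error (nullptr).
ACCEPTED_INFERENCE_TYPES = {
    _types_pb2.QUANTIZED_INT16: "TensorType_INT16",
    _types_pb2.QUANTIZED_UINT8: "TensorType_UINT8",
    _types_pb2.INT8: "TensorType_INT8",  # assumed to equal 9, the pybind default
}
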