Remove the redundant quantize op in back-to-back quantize op pairs

This usually happens when post-training quantizing a model that has custom ops on its outputs. The existing modify-io library produces back-to-back quantize op pairs, which causes an accuracy drop. This CL identifies the back-to-back quantize op pair pattern and removes the first quantize op in each pair.

PiperOrigin-RevId: 359452170
Change-Id: Ib169208078c0b92ee56d9eef7cc7ac178c9f81a3
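For intuition, a minimal sketch of the rewrite using toy stand-ins (the Op class, op names, and tensor indices below are illustrative assumptions, not the real TFLite flatbuffer schema):

# Toy illustration only -- not the TFLite schema.
from dataclasses import dataclass
from typing import List

@dataclass
class Op:
  name: str
  inputs: List[int]
  outputs: List[int]

# Before: tensor 0 (float32) -> QUANTIZE -> tensor 1 (int8) -> QUANTIZE ->
# tensor 2 (uint8). The second op requantizes the first one's output, which
# is where the accuracy drop comes from.
operators = [
    Op('QUANTIZE', inputs=[0], outputs=[1]),  # float32 -> int8
    Op('QUANTIZE', inputs=[1], outputs=[2]),  # int8 -> uint8 (requantize)
]

# The fix: rewire the requantize op to consume the float tensor directly,
# then drop the now-dead first quantize op.
first_quant, requantize = operators
requantize.inputs[0] = first_quant.inputs[0]
operators.remove(first_quant)
assert operators == [Op('QUANTIZE', inputs=[0], outputs=[2])]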
This commit is contained in:
parent
045c9963ec
commit
59b6ae1e28
@@ -32,6 +32,7 @@ from tensorflow import keras
 from tensorflow.lite.python import lite
 from tensorflow.lite.python import lite_constants
+from tensorflow.lite.python import util
 from tensorflow.lite.python.convert import ConverterError
 from tensorflow.lite.python.convert import mlir_quantize
 from tensorflow.lite.python.interpreter import Interpreter
@@ -1917,6 +1918,44 @@ class FromFrozenGraphObjectDetection(LiteTest):
                      output_details[3]['name'])
     self.assertAllEqual([1], output_details[3]['shape'])
 
+  def testModifyIOToUint8(self):
+    # Tests the object detection model that cannot be loaded in TensorFlow.
+    self._initObjectDetectionArgs()
+
+    def representative_dataset_gen():
+      for _ in range(2):
+        yield [np.random.uniform(low=0, high=1, size=(1, 300, 300, 3)).astype(
+            np.float32)]
+
+    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
+                                                       self._input_arrays,
+                                                       self._output_arrays,
+                                                       self._input_shapes)
+    converter.representative_dataset = representative_dataset_gen
+    converter.target_spec.supported_ops = {lite.OpsSet.TFLITE_BUILTINS_INT8}
+    converter.inference_type = dtypes.int8
+    converter.inference_input_type = dtypes.uint8
+    converter.inference_output_type = dtypes.uint8
+    converter.experimental_new_quantizer = True
+    converter.quantized_input_stats = {
+        'normalized_input_image_tensor': (0., 1.)}  # mean, std_dev
+    converter.allow_custom_ops = True
+    tflite_model = converter.convert()
+
+    self.assertIsNotNone(tflite_model)
+
+    model = util._convert_model_from_bytearray_to_object(tflite_model)
+    quant_opcode_idxs = util.get_quantize_opcode_idx(model)
+
+    subgraph = model.subgraphs[0]
+    tensors = subgraph.tensors
+    operators = subgraph.operators
+    for op in operators:
+      if op.opcodeIndex in quant_opcode_idxs:
+        input_type = util._convert_tflite_enum_type_to_tf_type(
+            tensors[op.inputs[0]].type)
+        if op.outputs[0] in subgraph.outputs:
+          self.assertEqual(input_type, dtypes.float32)
+
 
 class FromSavedModelTest(TestModels):
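The final assertion encodes the invariant this CL establishes: any quantize op whose output tensor is a graph output must consume a float32 tensor, i.e. no back-to-back (requantize) pair survives at the model outputs.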
@@ -558,6 +558,16 @@ def _convert_model_from_object_to_bytearray(model_object):
   return bytes(builder.Output())
 
 
+def get_quantize_opcode_idx(model):
+  """Returns the quantize op idx."""
+  quant_opcode_idxs = []
+  for idx, opcode in enumerate(model.operatorCodes):
+    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
+    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
+      quant_opcode_idxs.append(idx)
+  return quant_opcode_idxs
+
+
 def _remove_tensors_from_model(model, remove_tensors_idxs):
   """Remove tensors from model."""
   if not remove_tensors_idxs:
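A hedged usage sketch of the new helper, mirroring the test above (assumes `tflite_model` holds converted flatbuffer bytes):

from tensorflow.lite.python import util

# Deserialize the flatbuffer and collect the operators that are QUANTIZE ops.
model = util._convert_model_from_bytearray_to_object(tflite_model)
quant_opcode_idxs = util.get_quantize_opcode_idx(model)
quant_ops = [op for op in model.subgraphs[0].operators
             if op.opcodeIndex in quant_opcode_idxs]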
@@ -612,11 +622,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
   operators = subgraph.operators
 
   # Find all quantize operators
-  quant_opcode_idxs = []
-  for idx, opcode in enumerate(model.operatorCodes):
-    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
-    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
-      quant_opcode_idxs.append(idx)
+  quant_opcode_idxs = get_quantize_opcode_idx(model)
   if operators and not quant_opcode_idxs:
     for input_idx in subgraph.inputs:
       input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
@@ -803,6 +809,40 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
           get_tf_type_name(inference_output_type)))
 
 
+def _remove_redundant_quantize_ops(model):
+  """Finds back to back quantize ops and remove the first quantize op."""
+  subgraph = model.subgraphs[0]
+  tensors = subgraph.tensors
+  operators = subgraph.operators
+
+  # Find all quantize operators.
+  quant_opcode_idxs = get_quantize_opcode_idx(model)
+
+  # Find all redundant quant tensors.
+  redundant_quant_tensors = {}
+  all_quant_ops = []
+  for op in operators:
+    if op.opcodeIndex in quant_opcode_idxs:
+      all_quant_ops.append(op)
+      input_tensor = tensors[op.inputs[0]]
+      output_tensor = tensors[op.outputs[0]]
+      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
+      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
+      # This is a requantize op, so write down its input tensor index.
+      if input_type != dtypes.float32 and output_type != dtypes.float32:
+        redundant_quant_tensors[op.inputs[0]] = op
+
+  # Remove all the quant ops which produce the redundant quant tensors.
+  for op in all_quant_ops:
+    if op.opcodeIndex in quant_opcode_idxs:
+      output_tensor_idx = op.outputs[0]
+      if output_tensor_idx in redundant_quant_tensors:
+        requantize_op = redundant_quant_tensors[output_tensor_idx]
+        # Reset the input of the requantize op to the float input
+        requantize_op.inputs[0] = op.inputs[0]
+        operators.remove(op)
+
+
 def modify_model_io_type(
     model, inference_input_type=dtypes.float32,
     inference_output_type=dtypes.float32):
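One design point worth noting: the helper snapshots the quantize ops into `all_quant_ops` before mutating `operators`, because removing elements from a Python list while iterating over that same list silently skips entries. A minimal illustration of the hazard being avoided:

# Removing from a list while iterating over it skips every other element:
ops = ['quant_a', 'quant_b', 'quant_c']
for op in ops:
  ops.remove(op)
print(ops)  # ['quant_b'], not []

# Iterating over a separate snapshot, as all_quant_ops provides, is safe:
ops = ['quant_a', 'quant_b', 'quant_c']
for op in list(ops):
  ops.remove(op)
print(ops)  # []

Keeping the second op of each pair is also deliberate: its output tensor may be referenced in subgraph.outputs, so only its input needs rewiring while all downstream references stay valid.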
@@ -842,4 +882,6 @@ def modify_model_io_type(
 
   _modify_model_output_type(model_object, inference_output_type)
 
+  _remove_redundant_quantize_ops(model_object)
+
   return _convert_model_from_object_to_bytearray(model_object)
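Callers are unaffected: `modify_model_io_type` stays the single entry point, and the cleanup now runs automatically after the output-type rewrite. A hedged end-to-end sketch, assuming a converter configured as in the test above and that the function takes and returns flatbuffer bytes:

from tensorflow.lite.python import util
from tensorflow.python.framework import dtypes

tflite_model = converter.convert()  # assumed: a configured TFLiteConverter
tflite_model = util.modify_model_io_type(
    tflite_model,
    inference_input_type=dtypes.uint8,
    inference_output_type=dtypes.uint8)
# Any back-to-back quantize pair introduced by the I/O rewrite is now gone.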