Update the modify model interface (python) to support int16x8 quantized models

PiperOrigin-RevId: 337204166
Change-Id: I7c3b77403cb9437270f907034f0e7918e8338e11
Meghna Natraj 2020-10-14 17:04:08 -07:00 committed by TensorFlower Gardener
parent bd7dc164d8
commit 8822574f19
2 changed files with 98 additions and 71 deletions
tensorflow/lite/python


@@ -76,7 +76,10 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
_TFLITE_FILE_IDENTIFIER = b"TFL3"
_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (dtypes.float32, dtypes.int8, dtypes.uint8)
_MAP_QUANT_TO_IO_TYPES = {
dtypes.int8: {dtypes.int8, dtypes.uint8},
dtypes.int16: {dtypes.int16},
}
def convert_dtype_to_tflite_type(tf_dtype):
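For illustration (not part of the diff): the new _MAP_QUANT_TO_IO_TYPES table gates which integer I/O types a quantized model may expose, while float32 is always allowed and is handled before this lookup. A minimal, self-contained sketch of the lookup:

from tensorflow.python.framework import dtypes

# Restated here so the sketch stands alone; mirrors the mapping added above.
_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}

quant_type = dtypes.int16                      # type produced by the model's quantize op
allowed = _MAP_QUANT_TO_IO_TYPES[quant_type]   # {tf.int16}
assert dtypes.int16 in allowed
assert dtypes.uint8 not in allowed             # uint8 I/O is only valid for int8-quantized models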
@@ -631,13 +634,6 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
if inference_input_type == dtypes.float32:
return
if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in {}, "
"instead got {}.".format(tuple(_get_tf_type_name(t) for t in
_TFLITE_MODEL_INPUT_OUTPUT_TYPES),
_get_tf_type_name(inference_input_type)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
@@ -650,25 +646,38 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
if not quant_opcode_idxs:
raise ValueError("Model input is not quantized.")
# Ensure that the model input is quantized
# Validate that the model input is quantized
input_quant_ops = []
for op in operators:
# Check if the operator quantizes an input
# Find operators that quantize model input
if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
# If found, validate the operator input/output tensor types
float_tensor, int_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
if float_tensor.type != schema_fb.TensorType.FLOAT32:
float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
# If found, validate that the operator's input type is float
float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
if float_type != dtypes.float32:
raise ValueError(
"Model input type must be tf.float32. Expected type for tensor "
"with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(
_convert_tflite_enum_type_to_tf_type(float_tensor.type))))
if int_tensor.type != schema_fb.TensorType.INT8:
"Initial model input type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type)))
# If found, validate that the operator output is quantized and compatible
# with the final model input type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
if quant_type not in _MAP_QUANT_TO_IO_TYPES:
raise ValueError(
"Model input is not quantized. Expected type for tensor "
"with name '{}' is tf.int8, instead type is {}".format(
int_tensor.name, _get_tf_type_name(
_convert_tflite_enum_type_to_tf_type(int_tensor.type))))
"Initial model input is not quantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name,
tuple(_get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type)))
else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_input_type not in inference_io_types:
raise ValueError(
"Unsupported `inference_input_type` value. Expected to be in "
"{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_input_type)))
input_quant_ops.append(op)
if len(subgraph.inputs) != len(input_quant_ops):
@@ -684,7 +693,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.inputs[0]].quantization = uint8_quantization
tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
elif inference_input_type == dtypes.int8:
elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
# Remove the inputs and the quant operator
remove_tensors_idxs = set()
for op in input_quant_ops:
@@ -695,10 +704,8 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
_remove_tensors_from_model(model, remove_tensors_idxs)
else:
raise ValueError(
"Unsupported `inference_input_type` value. Expected to be in {}, "
"instead got {}.".format(tuple(_get_tf_type_name(t) for t in
_TFLITE_MODEL_INPUT_OUTPUT_TYPES),
_get_tf_type_name(inference_input_type)))
"Unsupported `inference_input_type` value {}.".format(
_get_tf_type_name(inference_input_type)))
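To illustrate the effect of the rewritten input validation, here is a sketch (assuming modify_model_io_type is called with the serialized model bytes returned by converter.convert(), as the tests below do): a 16x8-quantized model accepts only tf.float32 or tf.int16 input, so requesting uint8 is rejected inside the loop above.

from tensorflow.lite.python import util
from tensorflow.python.framework import dtypes

# int16_model is assumed to be the bytes of a 16x8-quantized model.
ok_model = util.modify_model_io_type(int16_model, inference_input_type=dtypes.int16)

try:
  util.modify_model_io_type(int16_model, inference_input_type=dtypes.uint8)
except ValueError:
  pass  # uint8 I/O is only supported for int8-quantized models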
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
@@ -707,13 +714,6 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
if inference_output_type == dtypes.float32:
return
if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in {}, "
"instead got {}.".format(tuple(_get_tf_type_name(t) for t in
_TFLITE_MODEL_INPUT_OUTPUT_TYPES),
_get_tf_type_name(inference_output_type)))
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
@@ -726,26 +726,39 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
if not dequant_opcode_idxs:
raise ValueError("Model output is not dequantized.")
# Ensure that the model output is dequantized
# Validate that the model output is dequantized
output_dequant_ops = []
for op in operators:
# Check if the operator dequantizes an output
# Find operators that dequantize model output
if op.opcodeIndex in dequant_opcode_idxs and \
op.outputs[0] in subgraph.outputs:
# If found, validate the operator input/output tensor types
int_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
if float_tensor.type != schema_fb.TensorType.FLOAT32:
# If found, validate that the operator's output type is float
quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
if float_type != dtypes.float32:
raise ValueError(
"Model output type must be tf.float32. Expected type for tensor "
"with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(
_convert_tflite_enum_type_to_tf_type(float_tensor.type))))
if int_tensor.type != schema_fb.TensorType.INT8:
"Initial model output type must be tf.float32. Expected type for "
"tensor with name '{}' is tf.float32, instead type is {}".format(
float_tensor.name, _get_tf_type_name(float_type)))
# If found, validate that the operator input is quantized and compatible
# with the final model output type
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
if quant_type not in _MAP_QUANT_TO_IO_TYPES:
raise ValueError(
"Model output is not dequantized. Expected type for tensor "
"with name '{}' is tf.int8, instead type is {}".format(
int_tensor.name, _get_tf_type_name(
_convert_tflite_enum_type_to_tf_type(int_tensor.type))))
"Initial model output is not dequantized. Expected type for "
"tensor with name '{}' should be in {}, instead type is {}".format(
quant_tensor.name,
tuple(_get_tf_type_name(t) for t in
_MAP_QUANT_TO_IO_TYPES.keys()),
_get_tf_type_name(quant_type)))
else:
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
if inference_output_type not in inference_io_types:
raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in "
"{}, instead got {}.".format(
tuple(_get_tf_type_name(t) for t in inference_io_types),
_get_tf_type_name(inference_output_type)))
output_dequant_ops.append(op)
if len(subgraph.outputs) != len(output_dequant_ops):
@@ -775,7 +788,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
tensors[op.outputs[0]].quantization = uint8_quantization
tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
elif inference_output_type == dtypes.int8:
elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
# Remove the outputs and the dequant operator
remove_tensors_idxs = set()
for op in output_dequant_ops:
@@ -786,10 +799,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
_remove_tensors_from_model(model, remove_tensors_idxs)
else:
raise ValueError(
"Unsupported `inference_output_type` value. Expected to be in {}, "
"instead got {}.".format(tuple(_get_tf_type_name(t) for t in
_TFLITE_MODEL_INPUT_OUTPUT_TYPES),
_get_tf_type_name(inference_output_type)))
"Unsupported `inference_output_type` value {}.".format(
_get_tf_type_name(inference_output_type)))
def modify_model_io_type(
@@ -801,11 +812,12 @@ def modify_model_io_type(
model: A tflite model.
inference_input_type: tf.DType representing modified input type.
(default tf.float32. If model input is int8 quantized, it must be in
{tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
{tf.float32, tf.int8, tf.uint8}, else if model input is int16 quantized,
it must be in {tf.float32, tf.int16}, else it must be tf.float32)
inference_output_type: tf.DType representing modified output type.
(default tf.float32. If model output is int8 dequantized, it must be in
{tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
{tf.float32, tf.int8, tf.uint8}, else if model output is int16 dequantized,
it must be in {tf.float32, tf.int16}, else it must be tf.float32)
Returns:
A tflite model with modified input/output type.
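As a usage sketch (assumptions: modify_model_io_type consumes and returns the serialized flatbuffer, as in the tests in this change, and keras_model / representative_dataset_gen are placeholder names), a 16x8-quantized model can be converted with float32 I/O and then rewritten to expose int16 I/O:

import tensorflow as tf
from tensorflow.lite.python import util
from tensorflow.python.framework import dtypes

converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)  # placeholder model
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen      # placeholder dataset fn
converter.target_spec.supported_ops = {
    tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
}
tflite_model = converter.convert()  # quantized internals, float32 I/O

# Rewrite the model so its input and output tensors are tf.int16.
int16_io_model = util.modify_model_io_type(
    tflite_model,
    inference_input_type=dtypes.int16,
    inference_output_type=dtypes.int16)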


@@ -230,7 +230,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([None, 3, 5], tensor.shape)
def _generate_integer_tflite_model():
def _generate_integer_tflite_model(quantization_type=dtypes.int8):
"""Define an integer post-training quantized tflite model."""
# Load MNIST dataset
n = 10 # Number of samples
@@ -276,7 +276,13 @@ def _generate_integer_tflite_model():
np.float32)
]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
if quantization_type == dtypes.int8:
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
else:
converter.target_spec.supported_ops = {
tf.lite.OpsSet
.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
}
tflite_model = converter.convert()
return tflite_model
@@ -285,22 +291,24 @@ def _generate_integer_tflite_model():
def _test_param_modify_integer_model_io_type():
"""Function to generate parameterized inputs for testing."""
params = []
str_template = "_{}{}{}"
str_template = "_{}{}{}{}"
map_model_type = {
"PostTraining": True,
# "DuringTraining": False,
}
map_types = {
"": dtypes.float32,
"INT8": dtypes.int8,
"UINT8": dtypes.uint8,
map_quantize_type_to_io_types = {
tf.int8: {tf.float32, tf.int8, tf.uint8},
tf.int16: {tf.float32, tf.int16}
}
for k1, v1 in map_model_type.items():
for k2, v2 in map_types.items():
istr = "_Input{}".format(k2) if k2 else ""
for k3, v3 in map_types.items():
ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
for qtype, v2 in map_quantize_type_to_io_types.items():
qstr = "_IntegerQuantize{}".format(qtype.name.capitalize())
for itype in v2:
istr = "_Input{}".format(itype.name.capitalize())
for otype in v2:
ostr = "_Output{}".format(otype.name.capitalize())
params.append((str_template.format(k1, qstr, istr, ostr),
v1, qtype, itype, otype))
return params
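For reference, a quick sketch (assumed to run inside util_test.py, where tf and the dtypes module are already imported) of what the updated generator produces: the int8 case enumerates 3x3 I/O combinations and the int16 case 2x2, and the name suffix now embeds the quantization type.

params = _test_param_modify_integer_model_io_type()
assert len(params) == 13  # 9 int8 combinations + 4 int16 combinations
names = [p[0] for p in params]
assert "_PostTraining_IntegerQuantizeInt16_InputInt16_OutputFloat32" in names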
@@ -311,10 +319,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
@classmethod
def setUpClass(cls):
super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
cls.post_train_integer_model = _generate_integer_tflite_model()
cls.post_train_int8_model = _generate_integer_tflite_model()
cls.post_train_int16_model = _generate_integer_tflite_model(
quantization_type=dtypes.int16)
@parameterized.named_parameters(_test_param_modify_integer_model_io_type())
def test(self, is_post_train, in_tftype, out_tftype):
def test(self, is_post_train, quantization_type, in_tftype, out_tftype):
"""Modify the float input/output type of an integer quantized model."""
def _run_tflite_inference(model, in_tftype, out_tftype):
@@ -353,7 +363,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
return output_data
model = self.__class__.post_train_integer_model if is_post_train else None
if is_post_train and quantization_type == tf.int8:
model = self.__class__.post_train_int8_model
elif is_post_train and quantization_type == tf.int16:
model = self.__class__.post_train_int16_model
else:
model = None
# Run model inference with float input output type
output_data = _run_tflite_inference(model, tf.float32, tf.float32)
# Run model inference with modified integer input output type