Update the modify model interface (Python) to support int16x8 quantized models

PiperOrigin-RevId: 337204166
Change-Id: I7c3b77403cb9437270f907034f0e7918e8338e11

parent bd7dc164d8
commit 8822574f19

Changes under tensorflow/lite/python:
tensorflow/lite/python/util.py

@@ -76,7 +76,10 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
 
 _TFLITE_FILE_IDENTIFIER = b"TFL3"
 
-_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (dtypes.float32, dtypes.int8, dtypes.uint8)
+_MAP_QUANT_TO_IO_TYPES = {
+    dtypes.int8: {dtypes.int8, dtypes.uint8},
+    dtypes.int16: {dtypes.int16},
+}
 
 
 def convert_dtype_to_tflite_type(tf_dtype):
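Note: the new _MAP_QUANT_TO_IO_TYPES table replaces the flat _TFLITE_MODEL_INPUT_OUTPUT_TYPES tuple, so the set of legal input/output dtypes now depends on how the model was quantized. A minimal standalone sketch of the lookup this enables (the _check_io_type helper is hypothetical, added only for illustration; the real checks live inside _modify_model_input_type and _modify_model_output_type below):

    from tensorflow.python.framework import dtypes

    # Mirrors the table added in util.py: allowed final IO dtypes per quantized dtype.
    _MAP_QUANT_TO_IO_TYPES = {
        dtypes.int8: {dtypes.int8, dtypes.uint8},
        dtypes.int16: {dtypes.int16},
    }

    def _check_io_type(quant_type, requested_type):
      """Hypothetical helper: is `requested_type` legal for a `quant_type` model?"""
      # float32 is always legal: the boundary quantize/dequantize ops are kept.
      if requested_type == dtypes.float32:
        return True
      return requested_type in _MAP_QUANT_TO_IO_TYPES.get(quant_type, set())

    assert _check_io_type(dtypes.int8, dtypes.uint8)       # int8 models may expose uint8 IO
    assert not _check_io_type(dtypes.int16, dtypes.uint8)  # int16 models may not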
@@ -631,13 +634,6 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
   if inference_input_type == dtypes.float32:
     return
 
-  if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
-    raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_input_type)))
-
   subgraph = model.subgraphs[0]
   tensors = subgraph.tensors
   operators = subgraph.operators
@@ -650,25 +646,38 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
   if not quant_opcode_idxs:
     raise ValueError("Model input is not quantized.")
 
-  # Ensure that the model input is quantized
+  # Validate that the model input is quantized
   input_quant_ops = []
   for op in operators:
-    # Check if the operator quantizes an input
+    # Find operators that quantize model input
     if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
-      # If found, validate the operator input/output tensor types
-      float_tensor, int_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
-      if float_tensor.type != schema_fb.TensorType.FLOAT32:
+      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
+      # If found, validate that the operator's input type is float
+      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
+      if float_type != dtypes.float32:
         raise ValueError(
-            "Model input type must be tf.float32. Expected type for tensor "
-            "with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(float_tensor.type))))
-      if int_tensor.type != schema_fb.TensorType.INT8:
+            "Initial model input type must be tf.float32. Expected type for "
+            "tensor with name '{}' is tf.float32, instead type is {}".format(
+                float_tensor.name, _get_tf_type_name(float_type)))
+      # If found, validate that the operator output is quantized and compatible
+      # with the final model input type
+      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
+      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
         raise ValueError(
-            "Model input is not quantized. Expected type for tensor "
-            "with name '{}' is tf.int8, instead type is {}".format(
-                int_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(int_tensor.type))))
+            "Initial model input is not quantized. Expected type for "
+            "tensor with name '{}' should be in {}, instead type is {}".format(
+                quant_tensor.name,
+                tuple(_get_tf_type_name(t) for t in
+                      _MAP_QUANT_TO_IO_TYPES.keys()),
+                _get_tf_type_name(quant_type)))
+      else:
+        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
+        if inference_input_type not in inference_io_types:
+          raise ValueError(
+              "Unsupported `inference_input_type` value. Expected to be in "
+              "{}, instead got {}.".format(
+                  tuple(_get_tf_type_name(t) for t in inference_io_types),
+                  _get_tf_type_name(inference_input_type)))
       input_quant_ops.append(op)
 
   if len(subgraph.inputs) != len(input_quant_ops):
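Note: because _MAP_QUANT_TO_IO_TYPES[dtypes.int16] is {dtypes.int16}, an int16x8-quantized model's input can now only be left as tf.float32 (the leading QUANTIZE op is kept) or switched to tf.int16 (the op is stripped); any other request hits the "Unsupported `inference_input_type`" error above. A hedged sketch, where model_obj is assumed to be an already-deserialized schema object for such a model:

    import tensorflow as tf
    from tensorflow.lite.python import util

    # `model_obj` is an assumed stand-in for a deserialized int16x8 model object.
    try:
      util._modify_model_input_type(model_obj, inference_input_type=tf.uint8)
    except ValueError as e:
      # tf.uint8 is not in _MAP_QUANT_TO_IO_TYPES[tf.int16], so this raises.
      print("rejected as expected:", e)

    # Accepted: the float32->int16 QUANTIZE op is removed and the model input
    # tensor becomes int16.
    util._modify_model_input_type(model_obj, inference_input_type=tf.int16)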
@@ -684,7 +693,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
       uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
       tensors[op.inputs[0]].quantization = uint8_quantization
       tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
-  elif inference_input_type == dtypes.int8:
+  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
     # Remove the inputs and the quant operator
     remove_tensors_idxs = set()
     for op in input_quant_ops:
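Note: the unchanged uint8 branch above re-targets an int8 input tensor to uint8 by keeping the scale and shifting the zero point by +128, which is exact because the int8 and uint8 value ranges are offset by exactly 128. A small self-contained check of that identity (illustrative numbers, not values from the patch):

    import numpy as np

    scale, int8_zero_point = 0.047, -13        # example quantization parameters
    uint8_zero_point = int8_zero_point + 128   # the shift applied above

    x_int8 = np.array([-128, -13, 42, 127], dtype=np.int8)
    x_uint8 = (x_int8.astype(np.int32) + 128).astype(np.uint8)

    # Both encodings dequantize to the same real values.
    real_from_int8 = scale * (x_int8.astype(np.float32) - int8_zero_point)
    real_from_uint8 = scale * (x_uint8.astype(np.float32) - uint8_zero_point)
    assert np.allclose(real_from_int8, real_from_uint8)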
@@ -695,10 +704,8 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
     _remove_tensors_from_model(model, remove_tensors_idxs)
   else:
     raise ValueError(
-        "Unsupported `inference_input_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_input_type)))
+        "Unsupported `inference_input_type` value {}.".format(
+            _get_tf_type_name(inference_input_type)))
 
 
 def _modify_model_output_type(model, inference_output_type=dtypes.float32):
@@ -707,13 +714,6 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
   if inference_output_type == dtypes.float32:
     return
 
-  if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
-    raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_output_type)))
-
   subgraph = model.subgraphs[0]
   tensors = subgraph.tensors
   operators = subgraph.operators
@@ -726,26 +726,39 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
   if not dequant_opcode_idxs:
     raise ValueError("Model output is not dequantized.")
 
-  # Ensure that the model output is dequantized
+  # Validate that the model output is dequantized
   output_dequant_ops = []
   for op in operators:
-    # Check if the operator dequantizes an output
+    # Find operators that dequantize model output
     if op.opcodeIndex in dequant_opcode_idxs and \
         op.outputs[0] in subgraph.outputs:
-      # If found, validate the operator input/output tensor types
-      int_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
-      if float_tensor.type != schema_fb.TensorType.FLOAT32:
+      # If found, validate that the operator's output type is float
+      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
+      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
+      if float_type != dtypes.float32:
         raise ValueError(
-            "Model output type must be tf.float32. Expected type for tensor "
-            "with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(float_tensor.type))))
-      if int_tensor.type != schema_fb.TensorType.INT8:
+            "Initial model output type must be tf.float32. Expected type for "
+            "tensor with name '{}' is tf.float32, instead type is {}".format(
+                float_tensor.name, _get_tf_type_name(float_type)))
+      # If found, validate that the operator input is quantized and compatible
+      # with the final model output type
+      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
+      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
         raise ValueError(
-            "Model output is not dequantized. Expected type for tensor "
-            "with name '{}' is tf.int8, instead type is {}".format(
-                int_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(int_tensor.type))))
+            "Initial model output is not dequantized. Expected type for "
+            "tensor with name '{}' should be in {}, instead type is {}".format(
+                quant_tensor.name,
+                tuple(_get_tf_type_name(t) for t in
+                      _MAP_QUANT_TO_IO_TYPES.keys()),
+                _get_tf_type_name(quant_type)))
+      else:
+        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
+        if inference_output_type not in inference_io_types:
+          raise ValueError(
+              "Unsupported `inference_output_type` value. Expected to be in "
+              "{}, instead got {}.".format(
+                  tuple(_get_tf_type_name(t) for t in inference_io_types),
+                  _get_tf_type_name(inference_output_type)))
       output_dequant_ops.append(op)
 
   if len(subgraph.outputs) != len(output_dequant_ops):
@@ -775,7 +788,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
       uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
       tensors[op.outputs[0]].quantization = uint8_quantization
       tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
-  elif inference_output_type == dtypes.int8:
+  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
     # Remove the outputs and the dequant operator
     remove_tensors_idxs = set()
    for op in output_dequant_ops:
@@ -786,10 +799,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
     _remove_tensors_from_model(model, remove_tensors_idxs)
   else:
     raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_output_type)))
+        "Unsupported `inference_output_type` value {}.".format(
+            _get_tf_type_name(inference_output_type)))
 
 
 def modify_model_io_type(
@@ -801,11 +812,12 @@ def modify_model_io_type(
     model: A tflite model.
     inference_input_type: tf.DType representing modified input type.
       (default tf.float32. If model input is int8 quantized, it must be in
-      {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
+      {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized,
+      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
     inference_output_type: tf.DType representing modified output type.
       (default tf.float32. If model output is int8 dequantized, it must be in
-      {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
-
+      {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
+      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
   Returns:
     A tflite model with modified input/output type.
 
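Note: modify_model_io_type is the public entry point described by the docstring above. A hedged usage sketch for giving an int16x8 model native int16 IO, assuming (as elsewhere in this module) that the function accepts a serialized flatbuffer and returns a new one; tflite_bytes is a placeholder for a converter output:

    import tensorflow as tf
    from tensorflow.lite.python import util

    # `tflite_bytes` is assumed to be an int16x8 post-training quantized model
    # produced by TFLiteConverter (see the test generator below).
    tflite_bytes_int16_io = util.modify_model_io_type(
        tflite_bytes,
        inference_input_type=tf.int16,
        inference_output_type=tf.int16)

    # The returned model runs with tf.lite.Interpreter as usual.
    interpreter = tf.lite.Interpreter(model_content=bytes(tflite_bytes_int16_io))
    interpreter.allocate_tensors()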
tensorflow/lite/python/util_test.py

@@ -230,7 +230,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
     self.assertAllEqual([None, 3, 5], tensor.shape)
 
 
-def _generate_integer_tflite_model():
+def _generate_integer_tflite_model(quantization_type=dtypes.int8):
   """Define an integer post-training quantized tflite model."""
   # Load MNIST dataset
   n = 10  # Number of samples
@@ -276,7 +276,13 @@ def _generate_integer_tflite_model():
              np.float32)
      ]
   converter.representative_dataset = representative_dataset_gen
-  converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
+  if quantization_type == dtypes.int8:
+    converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
+  else:
+    converter.target_spec.supported_ops = {
+        tf.lite.OpsSet
+        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+    }
   tflite_model = converter.convert()
 
   return tflite_model
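Note: the else branch above selects 16-bit activations with 8-bit weights ("int16x8") for post-training quantization. A hedged, self-contained sketch of such a converter setup (a tiny stand-in model, not the MNIST model the test actually builds):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(2),
    ])

    def representative_dataset_gen():
      for _ in range(10):
        yield [np.random.uniform(low=0, high=1, size=(1, 8)).astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset_gen
    # 16-bit activations, 8-bit weights, matching the else branch above.
    converter.target_spec.supported_ops = {
        tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    }
    tflite_bytes = converter.convert()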
@@ -285,22 +291,24 @@ def _generate_integer_tflite_model():
 
 def _test_param_modify_integer_model_io_type():
   """Function to generate parameterized inputs for testing."""
   params = []
-  str_template = "_{}{}{}"
+  str_template = "_{}{}{}{}"
   map_model_type = {
       "PostTraining": True,
       # "DuringTraining": False,
   }
-  map_types = {
-      "": dtypes.float32,
-      "INT8": dtypes.int8,
-      "UINT8": dtypes.uint8,
+  map_quantize_type_to_io_types = {
+      tf.int8: {tf.float32, tf.int8, tf.uint8},
+      tf.int16: {tf.float32, tf.int16}
   }
   for k1, v1 in map_model_type.items():
-    for k2, v2 in map_types.items():
-      istr = "_Input{}".format(k2) if k2 else ""
-      for k3, v3 in map_types.items():
-        ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
-        params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
+    for qtype, v2 in map_quantize_type_to_io_types.items():
+      qstr = "_IntegerQuantize{}".format(qtype.name.capitalize())
+      for itype in v2:
+        istr = "_Input{}".format(itype.name.capitalize())
+        for otype in v2:
+          ostr = "_Output{}".format(otype.name.capitalize())
+          params.append((str_template.format(k1, qstr, istr, ostr),
+                         v1, qtype, itype, otype))
   return params
 
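Note: with the extra slot in str_template, the generated test names now encode the quantization type as well as the requested IO types. The naming can be reproduced directly from the code above:

    import tensorflow as tf

    str_template = "_{}{}{}{}"
    qstr = "_IntegerQuantize{}".format(tf.int16.name.capitalize())  # "_IntegerQuantizeInt16"
    istr = "_Input{}".format(tf.float32.name.capitalize())          # "_InputFloat32"
    ostr = "_Output{}".format(tf.int16.name.capitalize())           # "_OutputInt16"
    print(str_template.format("PostTraining", qstr, istr, ostr))
    # -> "_PostTraining_IntegerQuantizeInt16_InputFloat32_OutputInt16"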
@@ -311,10 +319,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
   @classmethod
   def setUpClass(cls):
     super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
-    cls.post_train_integer_model = _generate_integer_tflite_model()
+    cls.post_train_int8_model = _generate_integer_tflite_model()
+    cls.post_train_int16_model = _generate_integer_tflite_model(
+        quantization_type=dtypes.int16)
 
   @parameterized.named_parameters(_test_param_modify_integer_model_io_type())
-  def test(self, is_post_train, in_tftype, out_tftype):
+  def test(self, is_post_train, quantization_type, in_tftype, out_tftype):
     """Modify the float input/output type of an integer quantized model."""
 
     def _run_tflite_inference(model, in_tftype, out_tftype):
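Note: each tuple produced by _test_param_modify_integer_model_io_type() is (name_suffix, is_post_train, quantization_type, in_tftype, out_tftype); parameterized.named_parameters uses the first element as the test-name suffix and passes the rest positionally, which is why test() gains the quantization_type argument. A minimal standalone illustration (one hand-written tuple, not the generated set):

    from absl.testing import parameterized
    import tensorflow as tf

    class _Demo(parameterized.TestCase):

      @parameterized.named_parameters(
          ("_PostTraining_IntegerQuantizeInt16_InputInt16_OutputFloat32",
           True, tf.int16, tf.int16, tf.float32),
      )
      def test(self, is_post_train, quantization_type, in_tftype, out_tftype):
        # The name suffix is consumed by named_parameters; the remaining four
        # values arrive here as positional arguments.
        self.assertTrue(is_post_train)
        self.assertEqual(quantization_type, tf.int16)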
@@ -353,7 +363,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
 
       return output_data
 
-    model = self.__class__.post_train_integer_model if is_post_train else None
+    if is_post_train and quantization_type == tf.int8:
+      model = self.__class__.post_train_int8_model
+    elif is_post_train and quantization_type == tf.int16:
+      model = self.__class__.post_train_int16_model
+    else:
+      model = None
     # Run model inference with float input output type
     output_data = _run_tflite_inference(model, tf.float32, tf.float32)
     # Run model inference with modified integer input output type
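Note: once a model's IO has been switched to int16, callers quantize inputs and dequantize outputs at the application boundary themselves, using each tensor's (scale, zero_point). A hedged sketch of that flow (not the test's _run_tflite_inference, whose body is not shown in full here); x is assumed to be a float32 NumPy array of the input shape and tflite_bytes_int16_io a model with int16 IO:

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_content=bytes(tflite_bytes_int16_io))
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    # Quantize the float input; for int16x8 activations the zero point is
    # expected to be 0 (symmetric quantization).
    scale, zero_point = inp["quantization"]
    x_q = np.round(x / scale + zero_point).astype(np.int16)
    interpreter.set_tensor(inp["index"], x_q)
    interpreter.invoke()

    # Dequantize the int16 output back to float.
    y_q = interpreter.get_tensor(out["index"])
    scale, zero_point = out["quantization"]
    y = (y_q.astype(np.float32) - zero_point) * scale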