Update modify model interface (python) to support int16x8 quantized models
PiperOrigin-RevId: 337204166
Change-Id: I7c3b77403cb9437270f907034f0e7918e8338e11
parent bd7dc164d8
commit 8822574f19
@@ -76,7 +76,10 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
 
 _TFLITE_FILE_IDENTIFIER = b"TFL3"
 
-_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (dtypes.float32, dtypes.int8, dtypes.uint8)
+_MAP_QUANT_TO_IO_TYPES = {
+    dtypes.int8: {dtypes.int8, dtypes.uint8},
+    dtypes.int16: {dtypes.int16},
+}
 
 
 def convert_dtype_to_tflite_type(tf_dtype):
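Not part of the diff: the new table replaces the flat `_TFLITE_MODEL_INPUT_OUTPUT_TYPES` tuple by keying the user-selectable I/O types off the type the model was actually quantized to (int8 models may expose int8 or uint8 I/O, int16 models only int16; float32 is always accepted and short-circuits earlier). A minimal, illustrative lookup built on the same table; the `_allowed_io_types` helper below is hypothetical:

import tensorflow as tf  # only for printing tf.DType names
from tensorflow.python.framework import dtypes

# Mirrors the table added above: quantized tensor type -> permitted
# integer inference input/output types.
_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}


def _allowed_io_types(quant_type):
  """Hypothetical helper: I/O types a caller may request for this model."""
  if quant_type not in _MAP_QUANT_TO_IO_TYPES:
    raise ValueError("Model is not int8 or int16 quantized.")
  # float32 is always allowed; the real callers handle it via an early return.
  return {dtypes.float32} | _MAP_QUANT_TO_IO_TYPES[quant_type]


print(_allowed_io_types(dtypes.int16))  # {tf.float32, tf.int16}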
@@ -631,13 +634,6 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
   if inference_input_type == dtypes.float32:
     return
 
-  if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
-    raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_input_type)))
-
   subgraph = model.subgraphs[0]
   tensors = subgraph.tensors
   operators = subgraph.operators
@@ -650,25 +646,38 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
   if not quant_opcode_idxs:
     raise ValueError("Model input is not quantized.")
 
-  # Ensure that the model input is quantized
+  # Validate that the model input is quantized
   input_quant_ops = []
   for op in operators:
-    # Check if the operator quantizes an input
+    # Find operators that quantize model input
     if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
-      # If found, validate the operator input/output tensor types
-      float_tensor, int_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
-      if float_tensor.type != schema_fb.TensorType.FLOAT32:
-        raise ValueError(
-            "Model input type must be tf.float32. Expected type for tensor "
-            "with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(float_tensor.type))))
-      if int_tensor.type != schema_fb.TensorType.INT8:
-        raise ValueError(
-            "Model input is not quantized. Expected type for tensor "
-            "with name '{}' is tf.int8, instead type is {}".format(
-                int_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(int_tensor.type))))
+      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
+      # If found, validate that the operator's input type is float
+      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
+      if float_type != dtypes.float32:
+        raise ValueError(
+            "Initial model input type must be tf.float32. Expected type for "
+            "tensor with name '{}' is tf.float32, instead type is {}".format(
+                float_tensor.name, _get_tf_type_name(float_type)))
+      # If found, validate that the operator output is quantized and compatible
+      # with the final model input type
+      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
+      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
+        raise ValueError(
+            "Initial model input is not quantized. Expected type for "
+            "tensor with name '{}' should be in {}, instead type is {}".format(
+                quant_tensor.name,
+                tuple(_get_tf_type_name(t) for t in
+                      _MAP_QUANT_TO_IO_TYPES.keys()),
+                _get_tf_type_name(quant_type)))
+      else:
+        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
+        if inference_input_type not in inference_io_types:
+          raise ValueError(
+              "Unsupported `inference_input_type` value. Expected to be in "
+              "{}, instead got {}.".format(
+                  tuple(_get_tf_type_name(t) for t in inference_io_types),
+                  _get_tf_type_name(inference_input_type)))
       input_quant_ops.append(op)
 
   if len(subgraph.inputs) != len(input_quant_ops):
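For orientation (not part of the diff): an integer-quantized TFLite graph keeps float32 model inputs and places a QUANTIZE op directly behind each of them; the loop above finds those ops and now checks the quantized side against `_MAP_QUANT_TO_IO_TYPES` instead of hard-coding int8. A small sketch, assuming `tflite_model` holds the serialized bytes of such a model, showing that the unmodified model still exposes float I/O through the public interpreter:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=tflite_model)  # assumed bytes
interpreter.allocate_tensors()

# Before modify_model_io_type() is applied, the exposed input/output tensors
# are the float32 ones; the int8/int16 tensors sit just behind the QUANTIZE
# and DEQUANTIZE ops that these helpers inspect.
print(interpreter.get_input_details()[0]["dtype"])   # <class 'numpy.float32'>
print(interpreter.get_output_details()[0]["dtype"])  # <class 'numpy.float32'>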
@@ -684,7 +693,7 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
       uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
       tensors[op.inputs[0]].quantization = uint8_quantization
       tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
-  elif inference_input_type == dtypes.int8:
+  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
     # Remove the inputs and the quant operator
     remove_tensors_idxs = set()
     for op in input_quant_ops:
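The unchanged uint8 branch in the context above works because int8 and uint8 share the same affine quantization up to a zero-point shift of 128: with real = scale * (q - zero_point), adding 128 to both the stored value and the zero point leaves the represented real value intact. A quick numeric check with made-up values (not part of the diff):

scale, int8_zero_point = 0.05, -3
int8_value = 42

uint8_zero_point = int8_zero_point + 128  # what the code above writes back
uint8_value = int8_value + 128            # how uint8 input data is interpreted

real_from_int8 = scale * (int8_value - int8_zero_point)
real_from_uint8 = scale * (uint8_value - uint8_zero_point)
assert abs(real_from_int8 - real_from_uint8) < 1e-9  # identical real value

Note that this branch is only reachable for int8-quantized models: uint8 is not listed under the int16 entry of the new table.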
@@ -695,10 +704,8 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
     _remove_tensors_from_model(model, remove_tensors_idxs)
   else:
     raise ValueError(
-        "Unsupported `inference_input_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_input_type)))
+        "Unsupported `inference_input_type` value {}.".format(
+            _get_tf_type_name(inference_input_type)))
 
 
 def _modify_model_output_type(model, inference_output_type=dtypes.float32):
@@ -707,13 +714,6 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
   if inference_output_type == dtypes.float32:
     return
 
-  if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
-    raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_output_type)))
-
   subgraph = model.subgraphs[0]
   tensors = subgraph.tensors
   operators = subgraph.operators
@@ -726,26 +726,39 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
   if not dequant_opcode_idxs:
     raise ValueError("Model output is not dequantized.")
 
-  # Ensure that the model output is dequantized
+  # Validate that the model output is dequantized
   output_dequant_ops = []
   for op in operators:
-    # Check if the operator dequantizes an output
+    # Find operators that dequantize model output
     if op.opcodeIndex in dequant_opcode_idxs and \
         op.outputs[0] in subgraph.outputs:
-      # If found, validate the operator input/output tensor types
-      int_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
-      if float_tensor.type != schema_fb.TensorType.FLOAT32:
-        raise ValueError(
-            "Model output type must be tf.float32. Expected type for tensor "
-            "with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(float_tensor.type))))
-      if int_tensor.type != schema_fb.TensorType.INT8:
-        raise ValueError(
-            "Model output is not dequantized. Expected type for tensor "
-            "with name '{}' is tf.int8, instead type is {}".format(
-                int_tensor.name, _get_tf_type_name(
-                    _convert_tflite_enum_type_to_tf_type(int_tensor.type))))
+      # If found, validate that the operator's output type is float
+      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
+      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
+      if float_type != dtypes.float32:
+        raise ValueError(
+            "Initial model output type must be tf.float32. Expected type for "
+            "tensor with name '{}' is tf.float32, instead type is {}".format(
+                float_tensor.name, _get_tf_type_name(float_type)))
+      # If found, validate that the operator input is quantized and compatible
+      # with the final model output type
+      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
+      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
+        raise ValueError(
+            "Initial model output is not dequantized. Expected type for "
+            "tensor with name '{}' should be in {}, instead type is {}".format(
+                quant_tensor.name,
+                tuple(_get_tf_type_name(t) for t in
+                      _MAP_QUANT_TO_IO_TYPES.keys()),
+                _get_tf_type_name(quant_type)))
+      else:
+        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
+        if inference_output_type not in inference_io_types:
+          raise ValueError(
+              "Unsupported `inference_output_type` value. Expected to be in "
+              "{}, instead got {}.".format(
+                  tuple(_get_tf_type_name(t) for t in inference_io_types),
+                  _get_tf_type_name(inference_output_type)))
       output_dequant_ops.append(op)
 
   if len(subgraph.outputs) != len(output_dequant_ops):
@@ -775,7 +788,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
       uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
       tensors[op.outputs[0]].quantization = uint8_quantization
       tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
-  elif inference_output_type == dtypes.int8:
+  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
     # Remove the outputs and the dequant operator
     remove_tensors_idxs = set()
     for op in output_dequant_ops:
@@ -786,10 +799,8 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
     _remove_tensors_from_model(model, remove_tensors_idxs)
   else:
     raise ValueError(
-        "Unsupported `inference_output_type` value. Expected to be in {}, "
-        "instead got {}.".format(tuple(_get_tf_type_name(t) for t in
-                                       _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
-                                 _get_tf_type_name(inference_output_type)))
+        "Unsupported `inference_output_type` value {}.".format(
+            _get_tf_type_name(inference_output_type)))
 
 
 def modify_model_io_type(
@@ -801,11 +812,12 @@ def modify_model_io_type(
     model: A tflite model.
     inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
-     {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
+     {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized,
+     it must be in {tf.float32, tf.int16}, else it must be tf.float32)
     inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
-     {tf.float32, tf.int8, tf.uint8}, else it must be tf.float32)
+     {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
+     it must be in {tf.float32, tf.int16}, else it must be tf.float32)
 
   Returns:
     A tflite model with modified input/output type.
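Putting it together (not part of the diff), a usage sketch based on the unit-test flow further down: `tflite_model` is assumed to hold the serialized output of a 16x8-quantized conversion, and `modify_model_io_type` returns a new serialized model whose exposed I/O tensors are int16. Requesting a type outside the table, e.g. tf.uint8 on an int16-quantized model, raises the ValueError added above.

import tensorflow as tf
from tensorflow.lite.python import util

int16_io_model = util.modify_model_io_type(
    tflite_model,                      # assumed: 16x8-quantized model bytes
    inference_input_type=tf.int16,
    inference_output_type=tf.int16)

interpreter = tf.lite.Interpreter(model_content=int16_io_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]["dtype"])   # <class 'numpy.int16'>
print(interpreter.get_output_details()[0]["dtype"])  # <class 'numpy.int16'>

try:
  util.modify_model_io_type(tflite_model, inference_input_type=tf.uint8)
except ValueError as e:
  print(e)  # Unsupported `inference_input_type` value. Expected to be in ...

The remaining hunks below are from the accompanying unit test, which now parameterizes the helper model over the quantization scheme.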
@@ -230,7 +230,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
     self.assertAllEqual([None, 3, 5], tensor.shape)
 
 
-def _generate_integer_tflite_model():
+def _generate_integer_tflite_model(quantization_type=dtypes.int8):
   """Define an integer post-training quantized tflite model."""
   # Load MNIST dataset
   n = 10  # Number of samples
@@ -276,7 +276,13 @@ def _generate_integer_tflite_model():
               np.float32)
       ]
   converter.representative_dataset = representative_dataset_gen
-  converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
+  if quantization_type == dtypes.int8:
+    converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
+  else:
+    converter.target_spec.supported_ops = {
+        tf.lite.OpsSet
+        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+    }
   tflite_model = converter.convert()
 
   return tflite_model
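For readers unfamiliar with the 16x8 scheme (not part of the diff): the test helper above switches between the default int8 path and the experimental int16-activations/int8-weights path purely through `target_spec.supported_ops`. A condensed sketch of that converter setup; the Keras model, the sample data, and the use of `tf.lite.Optimize.DEFAULT` are assumptions based on the usual post-training quantization recipe:

import numpy as np
import tensorflow as tf


def convert_int16x8(keras_model, sample_inputs):
  """Sketch: post-training 16x8 quantization of a float Keras model."""
  converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # Representative data drives activation range calibration.
  converter.representative_dataset = lambda: (
      [np.asarray(x, dtype=np.float32)] for x in sample_inputs)
  converter.target_spec.supported_ops = {
      tf.lite.OpsSet
      .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
  }
  return converter.convert()  # float32 I/O, int16 activations inside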
@@ -285,22 +291,24 @@ def _generate_integer_tflite_model():
 def _test_param_modify_integer_model_io_type():
   """Function to generate parameterized inputs for testing."""
   params = []
-  str_template = "_{}{}{}"
+  str_template = "_{}{}{}{}"
   map_model_type = {
       "PostTraining": True,
       # "DuringTraining": False,
   }
-  map_types = {
-      "": dtypes.float32,
-      "INT8": dtypes.int8,
-      "UINT8": dtypes.uint8,
+  map_quantize_type_to_io_types = {
+      tf.int8: {tf.float32, tf.int8, tf.uint8},
+      tf.int16: {tf.float32, tf.int16}
   }
   for k1, v1 in map_model_type.items():
-    for k2, v2 in map_types.items():
-      istr = "_Input{}".format(k2) if k2 else ""
-      for k3, v3 in map_types.items():
-        ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
-        params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
+    for qtype, v2 in map_quantize_type_to_io_types.items():
+      qstr = "_IntegerQuantize{}".format(qtype.name.capitalize())
+      for itype in v2:
+        istr = "_Input{}".format(itype.name.capitalize())
+        for otype in v2:
+          ostr = "_Output{}".format(otype.name.capitalize())
+          params.append((str_template.format(k1, qstr, istr, ostr),
+                         v1, qtype, itype, otype))
   return params
 
 
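With the extra slot in `str_template`, each generated case now also encodes the quantization scheme. Purely illustrative (not part of the diff, and modulo set iteration order), the names the updated helper produces look like this:

for name, *_ in _test_param_modify_integer_model_io_type()[:3]:
  print(name)
# e.g. _PostTraining_IntegerQuantizeInt8_InputUint8_OutputFloat32
#      _PostTraining_IntegerQuantizeInt16_InputInt16_OutputFloat32
#      _PostTraining_IntegerQuantizeInt16_InputFloat32_OutputInt16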
@@ -311,10 +319,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
   @classmethod
   def setUpClass(cls):
     super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
-    cls.post_train_integer_model = _generate_integer_tflite_model()
+    cls.post_train_int8_model = _generate_integer_tflite_model()
+    cls.post_train_int16_model = _generate_integer_tflite_model(
+        quantization_type=dtypes.int16)
 
   @parameterized.named_parameters(_test_param_modify_integer_model_io_type())
-  def test(self, is_post_train, in_tftype, out_tftype):
+  def test(self, is_post_train, quantization_type, in_tftype, out_tftype):
     """Modify the float input/output type of an integer quantized model."""
 
     def _run_tflite_inference(model, in_tftype, out_tftype):
@@ -353,7 +363,12 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
 
       return output_data
 
-    model = self.__class__.post_train_integer_model if is_post_train else None
+    if is_post_train and quantization_type == tf.int8:
+      model = self.__class__.post_train_int8_model
+    elif is_post_train and quantization_type == tf.int16:
+      model = self.__class__.post_train_int16_model
+    else:
+      model = None
     # Run model inference with float input output type
     output_data = _run_tflite_inference(model, tf.float32, tf.float32)
     # Run model inference with modified integer input output type
|
Loading…
x
Reference in New Issue
Block a user