Support integer input and output types for Quantize-Aware Trained models
PiperOrigin-RevId: 322658564 Change-Id: I388d625fe22df0099dc2ed5a5e87db30a4a9d647
This commit is contained in: parent e3be70aa9d, commit a0c12335d3
@@ -56,6 +56,8 @@
*   `tf.lite`:
    *   Better support for ops with high-dimensional broadcasting inputs by adding
        `BroadcastTo` ops when necessary.
    *   `TFLiteConverter`:
        *   Support optional flags `inference_input_type` and `inference_output_type`
            for full integer quantized models. This allows users to modify the model
            input and output type to integer types (tf.int8, tf.uint8) instead of
            defaulting to float type (tf.float32). See the usage sketch after these
            notes.
*   `tf.random`:
    *   <ADD RELEASE NOTES HERE>
*   Math and Linear Algebra:

@@ -68,7 +70,7 @@
    *   <ADD RELEASE NOTES HERE>
*   Other:
    *   We have replaced uses of "whitelist" and "blacklist" with "allowlist"
        and "denylist" where possible. Please see
        https://developers.google.com/style/word-list#blacklist for more context.
    *   <ADD RELEASE NOTES HERE>

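A minimal end-to-end sketch of the new flags, assuming `qat_model` is a quantize-aware trained Keras model (for example, one built with the TF Model Optimization toolkit); only the two `inference_*_type` assignments are specific to this change:

```python
import numpy as np
import tensorflow as tf

# `qat_model` is a placeholder for a quantize-aware trained Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8   # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_model = converter.convert()

# The exported model now takes and returns int8 tensors instead of float32.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
assert interpreter.get_input_details()[0]["dtype"] == np.int8
assert interpreter.get_output_details()[0]["dtype"] == np.int8
```

For post-training full integer quantization the same flags apply, with a `representative_dataset` additionally set on the converter.
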
@@ -212,8 +212,11 @@ py_library(
    deps = [
        ":lite_constants",
        ":op_hint",
        ":schema_py",
        "//tensorflow/python:tf_optimizer",
        "//tensorflow/python/eager:wrap_function",
        "@absl_py//absl/logging",
        "@flatbuffers//:runtime_py",
        "@six_archive//:six",
    ],
)

@@ -224,12 +227,24 @@ py_test(
    python_version = "PY3",
    srcs_version = "PY2AND3",
    tags = [
        "no_mac",
        "no_windows",
    ],
    deps = [
        ":lite_constants",
        ":util",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:convert_to_constants",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:session",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
    ],
)

@@ -61,6 +61,7 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import modify_integer_quantized_model_io_type as _modify_integer_quantized_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.python import keras as _keras
@@ -314,6 +315,23 @@ class QuantizationMode(object):
    else:
      return False, None

  def flags_modify_model_io_type(
      self, input_type=constants.FLOAT, output_type=constants.FLOAT):
    """Flags for modifying the input and output type of a tflite model."""
    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
    is_training_time_only_quantize = self.training_time_int8_allow_float() and \
        not is_post_training_quantize

    # TODO(b/153576658): Consolidate post/during training quantization workflows
    # to modify model input/output type after MLIR conversion.
    if is_training_time_only_quantize:
      return {
          "inference_input_type": input_type,
          "inference_output_type": output_type,
      }
    else:
      return None

  # Below are helpers for the above functions.

  def _validate_int8_required(self):
@@ -557,9 +575,8 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
  def _validate_inference_input_output_types(self, quant_mode):
    """Validate inference_input_type and inference_output_type flags."""
    default_types = [constants.FLOAT, None]
    # We only support integer types for post training integer quantization
    # as we have statistical information to quantize the input and output.
    if quant_mode.is_post_training_integer_quantize():
    # We support integer input/output for integer quantized models only.
    if quant_mode.training_time_int8_allow_float():
      all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
      if self.inference_input_type not in all_types or \
          self.inference_output_type not in all_types:
@@ -643,6 +660,12 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
    if calibrate_and_quantize:
      result = self._calibrate_quantize_model(result, **flags)

    flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
        self.inference_input_type, self.inference_output_type)
    if flags_modify_model_io_type:
      result = _modify_integer_quantized_model_io_type(
          result, **flags_modify_model_io_type)

    if self._experimental_sparsify_model:
      result = _mlir_sparsify(result)

@@ -374,8 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):

    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

  @parameterized.named_parameters(
      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
      ('_INT8InputOutput', lite.constants.INT8),
      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
  @test_util.run_v2_only
  def testTrainingTimeQuantization(self):
  def testTrainingTimeQuantization(self, inference_input_output_type):
    model = self._getTrainingTimeQuantizedModel()

    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
@@ -384,37 +388,24 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.inference_input_type = inference_input_output_type
    quantized_converter.inference_output_type = inference_input_output_type
    quantized_tflite = quantized_converter.convert()
    self.assertTrue(quantized_tflite)

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite), len(float_tflite))

    interpreter = Interpreter(model_content=quantized_tflite)
    self.assertEqual(np.float32, interpreter.get_input_details()[0]['dtype'])
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     input_details[0]['dtype'])
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual(inference_input_output_type.as_numpy_dtype,
                     output_details[0]['dtype'])

  @parameterized.named_parameters(
      ('_INT8InputOutput', lite.constants.INT8),
      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
  def testInvalidTrainingTimeQuantization(self, inference_input_output_type):
    # We currently don't support integer inference_input_type and
    # inference_output_type flags for training time quantization.

    model = self._getTrainingTimeQuantizedModel()

    converter = lite.TFLiteConverterV2.from_keras_model(model)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = inference_input_output_type
      quantized_converter.inference_output_type = inference_input_output_type
      quantized_converter.convert()
    self.assertEqual(
        'The inference_input_type and inference_output_type '
        'must be tf.float32.', str(error.exception))
    # Ensure that the quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite), len(float_tflite))

  @test_util.run_v2_only
  def testNewQuantizer(self):

@@ -19,15 +19,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import datetime
import sys

from absl import logging

import six
from six.moves import range

from flatbuffers.python import flatbuffers
from tensorflow.core.protobuf import config_pb2 as _config_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import lite_constants as _lite_constants
from tensorflow.lite.python import schema_py_generated as _schema_fb
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
@@ -55,6 +61,25 @@ _MAP_TF_TO_TFLITE_TYPES = {
    dtypes.bool: _types_pb2.BOOL,
}

_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
}

_TFLITE_FILE_IDENTIFIER = b"TFL3"

_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8,
                                    _lite_constants.QUANTIZED_UINT8)


def convert_dtype_to_tflite_type(tf_dtype):
  """Converts tf.dtype to TFLite proto type.
@@ -74,6 +99,31 @@ def convert_dtype_to_tflite_type(tf_dtype):
  return result


def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if tf_type is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
        .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
  return tf_type
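
# Illustrative usage (not part of the original change), based on the map above:
#   _convert_tflite_enum_type_to_tf_type(9)  -> tf.int8
#   _convert_tflite_enum_type_to_tf_type(3)  -> tf.uint8
#   _convert_tflite_enum_type_to_tf_type(11) -> ValueError (unsupported enum)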


def _get_dtype_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  return "tf." + tf_type.name


def get_tensor_name(tensor):
  """Returns name of the input tensor.

@@ -514,3 +564,218 @@ extern const int {array_name}_len;
      license_text=license_text)

  return source_text, header_text


def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  model_object = _schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  model_object = _schema_fb.ModelT.InitFromObj(model_object)
  model_object = copy.deepcopy(model_object)
  model_object.subgraphs[0].inputs[0] = model_object.subgraphs[0].inputs[0]
  return model_object


def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Initial size of the buffer, which will grow automatically if needed
  builder = flatbuffers.Builder(1024)
  model_offset = model_object.Pack(builder)
  builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())
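
# Illustrative round trip (not part of the original change): the two helpers
# above are inverses, so a model can be edited in object form and re-serialized:
#   model_obj = _convert_model_from_bytearray_to_object(tflite_model_bytes)
#   ... mutate model_obj.subgraphs[0] ...
#   tflite_model_bytes = _convert_model_from_object_to_bytearray(model_obj)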


def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
    # Update tensor indices referenced throughout the model
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    # Delete the tensors
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")
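
# Illustrative example (not part of the original change): with 7 tensors and
# remove_tensors_idxs = {2, 4}, the fast path does not apply (min is 2, not 5),
# so the remap is {0: 0, 1: 1, 3: 2, 5: 3, 6: 4}; references to the removed
# tensors are rewritten to -1.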


def _validate_and_find_int8_quantized_inputs_outputs(model):
  """Validate that model input is quantized and output is dequantized."""
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Ensure model has at least one quantize and dequantize operator
  quant_opcode_idx, dequant_opcode_idx = None, None
  for idx, opcode in enumerate(model.operatorCodes):
    if opcode.builtinCode == _schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idx = idx
    elif opcode.builtinCode == _schema_fb.BuiltinOperator.DEQUANTIZE:
      dequant_opcode_idx = idx
    if quant_opcode_idx is not None and dequant_opcode_idx is not None:
      break
  if quant_opcode_idx is None and dequant_opcode_idx is None:
    raise ValueError("Model is not integer quantized as it does not "
                     "contain quantize/dequantize operators.")

  # Ensure model inputs and outputs are integer quantized
  input_quant_ops, output_dequant_ops = [], []
  for op in operators:
    # Find input quantize operator
    if op.opcodeIndex == quant_opcode_idx and op.inputs[0] in subgraph.inputs:
      pos, float_tensor, int_tensor = \
          "input", tensors[op.inputs[0]], tensors[op.outputs[0]]
      input_quant_ops.append(op)
    # Find output dequantize operator
    elif op.opcodeIndex == dequant_opcode_idx and \
        op.outputs[0] in subgraph.outputs:
      pos, float_tensor, int_tensor = \
          "output", tensors[op.outputs[0]], tensors[op.inputs[0]]
      output_dequant_ops.append(op)
    # Otherwise, ignore
    else:
      continue
    # If found, validate the input/output tensor type
    if float_tensor.type != _schema_fb.TensorType.FLOAT32:
      raise ValueError(
          "Model {} type must be tf.float32. Expected type for tensor with "
          "name '{}' is tf.float32, instead type is tf.{}".format(
              pos, float_tensor.name,
              _convert_tflite_enum_type_to_tf_type(float_tensor.type).name))
    if int_tensor.type != _schema_fb.TensorType.INT8:
      raise ValueError(
          "Model is not integer quantized. Expected type for tensor with "
          "name '{}' is tf.int8, instead type is tf.{}".format(
              int_tensor.name,
              _convert_tflite_enum_type_to_tf_type(int_tensor.type).name))

  return input_quant_ops, output_dequant_ops


def modify_integer_quantized_model_io_type(
    model, inference_input_type=_lite_constants.FLOAT,
    inference_output_type=_lite_constants.FLOAT):
  """Modify the float input/output type of an integer quantized model.

  Args:
    model: An int8 quantized tflite model with float input and output.
    inference_input_type: tf.DType representing final input type.
      (default tf.float32)
    inference_output_type: tf.DType representing final output type.
      (default tf.float32)

  Returns:
    An int8 quantized tflite model with modified input and/or output type.

  Raises:
    ValueError: If the model is not int8 quantized or the inference_input_type
      and/or inference_output_type is unsupported.
    RuntimeError: If the modification was unsuccessful.

  """
  # Return if input and output types default to float
  if inference_input_type == _lite_constants.FLOAT and \
      inference_output_type == _lite_constants.FLOAT:
    return model

  # Validate input and output types
  if inference_input_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
    raise ValueError("The `inference_input_type` should be in {}".format(
        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))
  if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
    raise ValueError("The `inference_output_type` should be in {}".format(
        tuple(_get_dtype_name(t) for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES)))

  logging.debug(("Attempting to modify the model input from tf.float32 to %s "
                 "and output from tf.float32 to %s"),
                _get_dtype_name(inference_input_type),
                _get_dtype_name(inference_output_type))
  # Convert the model to an object
  model = _convert_model_from_bytearray_to_object(model)

  # Validate the integer quantized model
  input_quant_ops, output_dequant_ops = \
      _validate_and_find_int8_quantized_inputs_outputs(model)

  # Initialize references and variables
  if len(model.subgraphs) > 1:
    raise ValueError("Model must only have one subgraph. Instead, it has "
                     "{} subgraphs.".format(len(model.subgraphs)))
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators
  remove_tensors_idxs = set()

  # Modify model input type
  if inference_input_type == _lite_constants.QUANTIZED_UINT8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = _schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = _schema_fb.TensorType.UINT8
  elif inference_input_type == _lite_constants.INT8:
    # Remove the inputs and the quant operator
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)

  # Modify model output type
  if inference_output_type == _lite_constants.QUANTIZED_UINT8:
    # Change dequant op (int8 to float) to quant op (int8 to uint8)
    for op in output_dequant_ops:
      op.opcodeIndex = input_quant_ops[0].opcodeIndex
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = _schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = _schema_fb.TensorType.UINT8
  elif inference_output_type == _lite_constants.INT8:
    # Remove the outputs and the dequant operator
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)

  # Remove tensors marked for deletion.
  _remove_tensors_from_model(model, remove_tensors_idxs)

  # Convert the model to a bytearray
  model = _convert_model_from_object_to_bytearray(model)

  return model
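
The new utility can also be driven directly. A minimal sketch, assuming `tflite_model` holds the bytes of an int8 post-training quantized model with float32 input and output (such as the one produced by `_generate_integer_tflite_model` in the test below):

```python
import tensorflow as tf
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util

# `tflite_model` is assumed to be an int8 quantized model with float I/O.
int8_io_model = util.modify_integer_quantized_model_io_type(
    tflite_model,
    inference_input_type=lite_constants.INT8,
    inference_output_type=lite_constants.INT8)

# The rewritten model exposes int8 input and output tensors.
interpreter = tf.lite.Interpreter(model_content=int8_io_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]["dtype"])   # <class 'numpy.int8'>
print(interpreter.get_output_details()[0]["dtype"])  # <class 'numpy.int8'>
```
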
@@ -19,7 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
from six.moves import range
import tensorflow as tf

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
@@ -61,6 +64,31 @@ class UtilTest(test_util.TensorFlowTestCase):
    self.assertEqual(
        util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)

  def testConvertEnumToDtype(self):
    self.assertEqual(
        util._convert_tflite_enum_type_to_tf_type(0), dtypes.float32)
    self.assertEqual(
        util._convert_tflite_enum_type_to_tf_type(1), dtypes.float16)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(2), dtypes.int32)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(3), dtypes.uint8)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(4), dtypes.int64)
    self.assertEqual(
        util._convert_tflite_enum_type_to_tf_type(5), dtypes.string)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(6), dtypes.bool)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(7), dtypes.int16)
    self.assertEqual(
        util._convert_tflite_enum_type_to_tf_type(8), dtypes.complex64)
    self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
    self.assertEqual(
        util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
    with self.assertRaises(ValueError) as error:
      util._convert_tflite_enum_type_to_tf_type(11)
    self.assertEqual(
        "Unsupported enum 11. The valid map of enum to tf.dtypes is : "
        "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
        "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
        "10: tf.float64}", str(error.exception))

  def testTensorName(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
@@ -195,5 +223,140 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
    self.assertEqual([None, 3, 5], tensor.shape.as_list())


def _generate_integer_tflite_model():
  """Define an integer post-training quantized tflite model."""
  # Load MNIST dataset
  n = 10  # Number of samples
  (train_images, train_labels), (test_images, test_labels) = \
      tf.keras.datasets.mnist.load_data()
  train_images, train_labels, test_images, test_labels = \
      train_images[:n], train_labels[:n], test_images[:n], test_labels[:n]

  # Normalize the input image so that each pixel value is between 0 to 1.
  train_images = train_images / 255.0
  test_images = test_images / 255.0

  # Define TF model
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"),
      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(10)
  ])

  # Train
  model.compile(
      optimizer="adam",
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=["accuracy"])

  model.fit(
      train_images,
      train_labels,
      epochs=1,
      validation_split=0.1,
  )

  # Convert TF Model to an Integer Quantized TFLite Model
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = {tf.lite.Optimize.DEFAULT}
  def representative_dataset_gen():
    for _ in range(2):
      yield [
          np.random.uniform(low=0, high=1, size=(1, 28, 28)).astype(
              np.float32)
      ]
  converter.representative_dataset = representative_dataset_gen
  converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8}
  tflite_model = converter.convert()

  return tflite_model


def _test_param_modify_integer_model_io_type():
  """Function to generate parameterized inputs for testing."""
  params = []
  str_template = "_{}{}{}"
  map_model_type = {
      "PostTraining": True,
      # "DuringTraining": False,
  }
  map_types = {
      "": lite_constants.FLOAT,
      "INT8": lite_constants.INT8,
      "UINT8": lite_constants.QUANTIZED_UINT8
  }
  for k1, v1 in map_model_type.items():
    for k2, v2 in map_types.items():
      istr = "_Input{}".format(k2) if k2 else ""
      for k3, v3 in map_types.items():
        ostr = "_Output{}".format(k3) if k3 else "" if istr else "_NoUpdate"
        params.append((str_template.format(k1, istr, ostr), v1, v2, v3))
  return params


# TODO(b/161174063): Merge tests for integer input/output type
class UtilModifyIntegerQuantizedModelIOTypeTest(
    test_util.TensorFlowTestCase, parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super(UtilModifyIntegerQuantizedModelIOTypeTest, cls).setUpClass()
    cls.post_train_integer_model = _generate_integer_tflite_model()

  @parameterized.named_parameters(_test_param_modify_integer_model_io_type())
  def test(self, is_post_train, in_tftype, out_tftype):
    """Modify the float input/output type of an integer quantized model."""

    def _run_tflite_inference(model, in_tftype, out_tftype):
      """Run inference on a model with a specific input/output type."""
      # Load TFLite model and allocate tensors.
      interpreter = tf.lite.Interpreter(model_content=model)
      interpreter.allocate_tensors()
      input_details = interpreter.get_input_details()[0]
      output_details = interpreter.get_output_details()[0]

      # Validate TFLite model input and output types
      self.assertEqual(input_details["dtype"], in_tftype.as_numpy_dtype)
      self.assertEqual(output_details["dtype"], out_tftype.as_numpy_dtype)

      # Define Input
      np.random.seed(0)
      input_data = np.random.uniform(low=0, high=1, size=(1, 28, 28))
      input_data = input_data.astype(np.float32)
      if input_details["dtype"] != np.float32:
        # quantize float to int
        scale, zero_point = input_details["quantization"]
        input_data = input_data / scale + zero_point
        input_data = input_data.astype(input_details["dtype"])

      # Run Inference
      interpreter.set_tensor(input_details["index"], input_data)
      interpreter.invoke()

      # Get output
      output_data = interpreter.get_tensor(output_details["index"])[0]
      if output_details["dtype"] != np.float32:
        # dequantize int to float
        scale, zero_point = output_details["quantization"]
        output_data = output_data.astype(np.float32)
        output_data = (output_data - zero_point) * scale

      return output_data

    model = self.__class__.post_train_integer_model if is_post_train else None
    # Run model inference with float input output type
    output_data = _run_tflite_inference(model, tf.float32, tf.float32)
    # Run model inference with modified integer input output type
    model_io = util.modify_integer_quantized_model_io_type(
        model, in_tftype, out_tftype)
    output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)

    # Validate that both the outputs are the same
    self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))


if __name__ == "__main__":
  test.main()