Skip the dequantize op check when the tensor type is not float

PiperOrigin-RevId: 348588532
Change-Id: I9409b4b23b89c3f8f2d60d3d4c8f587bc4efda4c
Author: Feng Liu (2020-12-21 23:59:42 -08:00), committed by TensorFlower Gardener
Parent: 94155a3934
Commit: 3376c111e2
3 changed files with 36 additions and 13 deletions
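
The change targets models whose interface tensors are already non-float (int8, int16, or bool), where there is no QUANTIZE/DEQUANTIZE op at the boundary to rewrite; the helpers below now skip their check instead of raising. A minimal, hedged sketch of producing such a fully-integer model with the public converter API (the saved-model path, input shape, and calibration data are placeholders, not part of this commit):

import tensorflow as tf

def representative_dataset():
  # Placeholder calibration data; shape and dtype must match the real model.
  for _ in range(100):
    yield [tf.random.uniform((1, 224, 224, 3))]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # request integer interface tensors
converter.inference_output_type = tf.int8
tflite_model = converter.convert()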


@@ -1337,7 +1337,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
     if calibrate_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
-    if self.experimental_new_converter:
+    if self.experimental_new_converter or self._experimental_new_quantizer:
       flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
           self.inference_input_type, self.inference_output_type)
       if flags_modify_model_io_type:
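
The hunk above makes the requested inference_input_type/inference_output_type be applied when only the new quantizer is enabled, not only when the new converter is. A hedged continuation of the sketch above using the documented public `experimental_new_quantizer` attribute that backs the internal flag (setup is still placeholder):

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset  # placeholder from the sketch above
converter.experimental_new_quantizer = True  # enable the MLIR-based quantizer
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()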


@@ -652,7 +652,12 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
     if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
       quant_opcode_idxs.append(idx)
   if operators and not quant_opcode_idxs:
-    raise ValueError("Model input is not quantized.")
+    for input_idx in subgraph.inputs:
+      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
+      if input_type == dtypes.float32:
+        raise ValueError("Model input is not quantized.")
+    # None of the inputs are float32, so they must be int16, int8, or bool.
+    return
 
   # Validate that the model input is quantized
   input_quant_ops = []
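
The new branch returns early when none of the subgraph inputs is float32, i.e. the input interface is already integer (or bool) and there is nothing to rewrite. A hedged way to see which case a given model falls into, using only the public Interpreter API (not part of this change; `tflite_model` is the placeholder from the sketch above):

interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_input_details()[0]["dtype"])   # e.g. numpy.int8 for an already-integer interface
print(interpreter.get_output_details()[0]["dtype"])  # numpy.float32 would mean a (de)quantize op is expected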
@@ -663,10 +668,13 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
       # If found, validate that the operator's input type is float
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model input type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_input_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model input type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator output is quantized and compatible
       # with the final model input type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@@ -737,7 +745,12 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
     if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
       dequant_opcode_idxs.append(idx)
   if operators and not dequant_opcode_idxs:
-    raise ValueError("Model output is not dequantized.")
+    for output in subgraph.outputs:
+      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
+      if output_type == dtypes.float32:
+        raise ValueError("Model output is not dequantized.")
+    # None of the outputs are float32, so they must be int16, int8, or bool.
+    return
 
   # Validate that the model output is dequantized
   output_dequant_ops = []
@@ -749,10 +762,13 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
       quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model output type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_output_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model output type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator input is quantized and compatible
       # with the final model output type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
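
With the output-side checks relaxed in the same way, a fully-integer model's outputs are read directly as integers and no boundary DEQUANTIZE op is involved. A hedged inference sketch continuing the placeholder `tflite_model` above, with the quantization parameters read from the interpreter rather than assumed:

import numpy as np

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Quantize a float sample into the integer input domain using the model's own scale/zero point.
scale, zero_point = input_details["quantization"]
sample = np.random.rand(*input_details["shape"]).astype(np.float32)  # placeholder input
interpreter.set_tensor(input_details["index"],
                       (sample / scale + zero_point).astype(input_details["dtype"]))
interpreter.invoke()
integer_output = interpreter.get_tensor(output_details["index"])  # already integer; no dequantize step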


@@ -371,11 +371,18 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
     model = None
     # Run model inference with float input output type
     output_data = _run_tflite_inference(model, tf.float32, tf.float32)
-    # Run model inference with modified integer input output type
+    # Modify the model io types to the target input/output types.
     model_io = util.modify_model_io_type(model, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
     output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
-
     # Validate that both the outputs are the same
     self.assertAllClose(output_data, output_io_data, atol=1.0)
+
+    # Modifying the model again with the target input/output types should be a no-op.
+    model_io = util.modify_model_io_type(model_io, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
+    output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
+    # Validate that both the outputs are the same
+    self.assertAllClose(output_data, output_io_data, atol=1.0)
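
The second pass added to the test documents the new behavior: re-applying the modification to an already-modified model is a no-op rather than an error. A hedged sketch of the same property through the internal helper the test uses (internal API, subject to change; `quantized_model` is assumed to be an integer-quantized .tflite flatbuffer, as bytes, whose interface is still float):

import tensorflow as tf
from tensorflow.lite.python import util  # internal module exercised by the test above

model_io = util.modify_model_io_type(
    quantized_model, inference_input_type=tf.int8, inference_output_type=tf.int8)
# Before this change a second pass could raise (e.g. "Model input is not quantized.");
# now it returns an equivalent model unchanged.
model_io_again = util.modify_model_io_type(
    model_io, inference_input_type=tf.int8, inference_output_type=tf.int8)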