Skip the dequantize op check when the tensor type is not float
PiperOrigin-RevId: 348588532 Change-Id: I9409b4b23b89c3f8f2d60d3d4c8f587bc4efda4c
parent 94155a3934
commit 3376c111e2
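For context, the checks relaxed by this change matter for fully integer-quantized models: when a model is converted with integer inference input/output types, its boundary tensors are already int8/int16, so there is no float32 tensor and no QUANTIZE/DEQUANTIZE wrapper op left for the I/O-rewriting helpers to find. A minimal sketch of producing such a model with the public converter API (the saved-model path and the representative dataset generator below are placeholders, not part of this commit):

import tensorflow as tf

def representative_dataset():
  # Hypothetical calibration data matching the model's input shape.
  for _ in range(100):
    yield [tf.random.uniform((1, 224, 224, 3))]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Integer I/O: the resulting model exposes int8 input/output tensors, with no
# float32 boundary tensors and hence no QUANTIZE/DEQUANTIZE ops to strip.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()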
tensorflow/lite/python/lite.py
@@ -1337,7 +1337,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
     if calibrate_quantize:
       result = self._calibrate_quantize_model(result, **flags)
 
-    if self.experimental_new_converter:
+    if self.experimental_new_converter or self._experimental_new_quantizer:
       flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
           self.inference_input_type, self.inference_output_type)
       if flags_modify_model_io_type:
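The hunk above widens the condition under which the V1 converter rewrites the model's I/O types: previously only the new (MLIR) converter path applied `inference_input_type`/`inference_output_type` rewriting, now the new-quantizer path does as well. A rough sketch of the flags involved; `_experimental_new_quantizer` is a private attribute and this fragment is illustrative, not a supported recipe (calibration/quantization settings are elided):

import tensorflow as tf

# Hypothetical TF1-style converter instance built from a frozen graph.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "/tmp/frozen.pb", input_arrays=["input"], output_arrays=["output"])
converter.experimental_new_converter = False    # stay on the old converter
converter._experimental_new_quantizer = True    # private flag: MLIR quantizer
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
# With this commit, the I/O rewrite guarded by the condition above also runs
# on this path, not only when experimental_new_converter is enabled.
tflite_model = converter.convert()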
tensorflow/lite/python/util.py
@@ -652,7 +652,12 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
     if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
       quant_opcode_idxs.append(idx)
   if operators and not quant_opcode_idxs:
-    raise ValueError("Model input is not quantized.")
+    for input_idx in subgraph.inputs:
+      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
+      if input_type == dtypes.float32:
+        raise ValueError("Model input is not quantized.")
+    # None of the inputs have float32, then they must be int16, int8, or bool
+    return
 
   # Validate that the model input is quantized
   input_quant_ops = []
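With the early return above, asking for integer I/O on a model whose inputs are already integer no longer raises. A minimal sketch of that path using the private helper from tensorflow.lite.python.util (internal API; `int8_io_model` is an assumed variable, not something defined by this commit):

import tensorflow as tf
from tensorflow.lite.python import util  # internal helper, not public API

# `int8_io_model`: assumed serialized .tflite bytes whose input/output tensors
# are already int8, so there is no leading QUANTIZE / trailing DEQUANTIZE op.
remodified = util.modify_model_io_type(
    int8_io_model, inference_input_type=tf.int8, inference_output_type=tf.int8)
# Before this commit these helpers raised "Model input is not quantized." /
# "Model output is not dequantized." here; now they return early, so the call
# is effectively a no-op.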
@@ -663,10 +668,13 @@ def _modify_model_input_type(model, inference_input_type=dtypes.float32):
       # If found, validate that the operator's input type is float
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model input type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_input_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model input type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator output is quantized and compatible
       # with the final model input type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
@@ -737,7 +745,12 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
     if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
       dequant_opcode_idxs.append(idx)
   if operators and not dequant_opcode_idxs:
-    raise ValueError("Model output is not dequantized.")
+    for output in subgraph.outputs:
+      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
+      if output_type == dtypes.float32:
+        raise ValueError("Model output is not dequantized.")
+    # None of the outputs have float32, then they must be int16, int8, or bool
+    return
 
   # Validate that the model output is dequantized
   output_dequant_ops = []
@@ -749,10 +762,13 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32):
       quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
       float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
       if float_type != dtypes.float32:
-        raise ValueError(
-            "Initial model output type must be tf.float32. Expected type for "
-            "tensor with name '{}' is tf.float32, instead type is {}".format(
-                float_tensor.name, _get_tf_type_name(float_type)))
+        if float_type == inference_output_type:
+          continue
+        else:
+          raise ValueError(
+              "Initial model output type must be tf.float32. Expected type for "
+              "tensor with name '{}' is tf.float32, instead type is {}".format(
+                  float_tensor.name, _get_tf_type_name(float_type)))
       # If found, validate that the operator input is quantized and compatible
       # with the final model output type
       quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
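The test hunk below drives models whose I/O has been rewritten this way. As background, running a model with integer input/output tensors means quantizing inputs and dequantizing outputs around the Interpreter calls. A rough sketch of that pattern (`tflite_model` and `float_input` are placeholders; this is not the test's actual `_run_tflite_inference` helper):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

data = float_input
if input_details['dtype'] == np.int8:
  # Quantize the float input using the input tensor's scale/zero point.
  scale, zero_point = input_details['quantization']
  data = np.round(float_input / scale + zero_point).astype(np.int8)

interpreter.set_tensor(input_details['index'], data)
interpreter.invoke()
result = interpreter.get_tensor(output_details['index'])

if output_details['dtype'] == np.int8:
  # Dequantize the integer output back to float for comparison.
  scale, zero_point = output_details['quantization']
  result = (result.astype(np.float32) - zero_point) * scale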
tensorflow/lite/python/util_test.py
@@ -371,11 +371,18 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
     model = None
     # Run model inference with float input output type
     output_data = _run_tflite_inference(model, tf.float32, tf.float32)
-    # Run model inference with modified integer input output type
+    # Modify the model io types to the target input/output types.
     model_io = util.modify_model_io_type(model, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
     output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
     # Validate that both the outputs are the same
     self.assertAllClose(output_data, output_io_data, atol=1.0)
 
+    # Modify the model with the target input/output types should be a no op.
+    model_io = util.modify_model_io_type(model_io, in_tftype, out_tftype)
+    # Run model inference with modified integer input output type
+    output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
+    # Validate that both the outputs are the same
+    self.assertAllClose(output_data, output_io_data, atol=1.0)
 
 
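The second block added to the test exercises the skip logic from the caller's side: once the model's I/O is already the requested integer type, rewriting it again should change nothing. A small complementary check, as a sketch only, that compares the Interpreter-reported dtypes rather than inference outputs (`model_io`, `in_tftype`, `out_tftype` are assumed to be as in the test above; `io_dtypes` is an illustrative helper):

import tensorflow as tf
from tensorflow.lite.python import util

def io_dtypes(model_bytes):
  # Report the input/output dtypes a serialized model exposes.
  interpreter = tf.lite.Interpreter(model_content=bytes(model_bytes))
  return (interpreter.get_input_details()[0]['dtype'],
          interpreter.get_output_details()[0]['dtype'])

once = util.modify_model_io_type(model_io, in_tftype, out_tftype)
twice = util.modify_model_io_type(once, in_tftype, out_tftype)
# Re-applying the same target types should leave the exposed dtypes unchanged.
assert io_dtypes(once) == io_dtypes(twice)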