Fix quant/dequant op pair removal logic to work correctly.
PiperOrigin-RevId: 360364811
Change-Id: I60cbda1f908535ca6ba1a5a86700b140767e2524
commit a0825573fb (parent f09da909a8)
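The diff below touches _remove_redundant_quantize_ops, which drops a QUANTIZE op whose output feeds a DEQUANTIZE op producing a subgraph output, and rewires that output to the quantize op's float input. A minimal sketch of that intent on a hypothetical toy graph (plain Python classes and a NumPy index array, not the real TFLite flatbuffer objects):

import numpy as np


class Op:
  """Toy operator: an op name plus input/output tensor indices."""

  def __init__(self, name, inputs, outputs):
    self.name = name
    self.inputs = list(inputs)
    self.outputs = list(outputs)


def remove_output_quant_dequant_pairs(operators, outputs):
  """Drops QUANTIZE->DEQUANTIZE pairs feeding `outputs` (a NumPy index array)."""
  # Map each dequantize op's input (quantized) tensor to the op itself,
  # but only when its output is one of the subgraph outputs.
  output_dequants = {
      op.inputs[0]: op
      for op in operators
      if op.name == 'DEQUANTIZE' and op.outputs[0] in outputs
  }
  for op in list(operators):
    if op.name == 'QUANTIZE' and op.outputs[0] in output_dequants:
      dequant = output_dequants[op.outputs[0]]
      # Rewire the subgraph output from the dequantized tensor to the
      # quantize op's float input, then drop both ops.
      outputs[outputs == dequant.outputs[0]] = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant)
  return operators, outputs


# Tensor 0 -> QUANTIZE -> tensor 1 -> DEQUANTIZE -> tensor 2 (subgraph output).
ops = [Op('QUANTIZE', [0], [1]), Op('DEQUANTIZE', [1], [2])]
outs = np.array([2])
ops, outs = remove_output_quant_dequant_pairs(ops, outs)
print([o.name for o in ops], outs)  # prints: [] [0]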
@@ -839,9 +839,8 @@ def _remove_redundant_quantize_ops(model):
       # This is a requantize op, so write down its input tensor index.
       if input_type != dtypes.float32 and output_type != dtypes.float32:
         redundant_quant_tensors[op.inputs[0]] = op
-    elif (op.opcodeIndex in dequant_opcode_idxs and
-          op.outputs[0] in subgraph.outputs):
-      # Mark quant-dequant op pairs right before outputs to be removed.
+    if op.opcodeIndex in dequant_opcode_idxs and \
+        op.outputs[0] in subgraph.outputs:
       output_dequant_tensors[op.inputs[0]] = op
 
   # Remove all the quant ops which produce the redundant quant tensors.
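The first hunk turns the dequantize check into a standalone if evaluated for every operator in the loop, recording each DEQUANTIZE op that feeds a subgraph output under its input (quantized) tensor index. A toy illustration of that bookkeeping, using tuples and ints in place of the flatbuffer operator objects:

# Hypothetical opcode indices: 0 = QUANTIZE, 1 = DEQUANTIZE.
dequant_opcode_idxs = {1}
subgraph_outputs = [5]

# Each toy op is (opcode_index, inputs, outputs).
operators = [(0, [3], [4]),   # QUANTIZE: tensor 3 -> tensor 4
             (1, [4], [5])]   # DEQUANTIZE: tensor 4 -> tensor 5 (graph output)

output_dequant_tensors = {}
for opcode_index, inputs, outputs in operators:
  # Same shape as the standalone check in the new code: key the dict by the
  # dequantize op's input tensor so a later pass can match quantize outputs.
  if opcode_index in dequant_opcode_idxs and outputs[0] in subgraph_outputs:
    output_dequant_tensors[inputs[0]] = (opcode_index, inputs, outputs)

print(output_dequant_tensors)   # {4: (1, [4], [5])}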
@@ -853,13 +852,12 @@ def _remove_redundant_quantize_ops(model):
       requantize_op.inputs[0] = op.inputs[0]
       operators.remove(op)
 
-  # Remove all the quant/dequant op pairs right before the outputs.
+  # Remove all the quant ops which connect to the output dequant op.
   for op in all_quant_ops:
     output_tensor_idx = op.outputs[0]
     if output_tensor_idx in output_dequant_tensors:
       dequant_op = output_dequant_tensors[output_tensor_idx]
-      output_idx = subgraph.outputs.index(dequant_op.outputs[0])
-      subgraph.outputs[output_idx] = op.inputs[0]
+      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
       operators.remove(op)
       operators.remove(dequant_op)
 
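The second hunk replaces the list-style .index() lookup with a NumPy boolean-mask assignment when rewiring the subgraph output. A standalone illustration, assuming subgraph.outputs behaves like a 1-D NumPy array of tensor indices:

import numpy as np

subgraph_outputs = np.array([7, 9, 12])   # toy output tensor indices
dequant_output = 9                        # tensor produced by the dequant op
quant_input = 4                           # float tensor feeding the quant op

# subgraph_outputs.index(dequant_output) would raise AttributeError here,
# since NumPy arrays have no .index(); the mask form rewrites in place.
subgraph_outputs[subgraph_outputs == dequant_output] = quant_input
print(subgraph_outputs)   # [ 7  4 12]

The mask form also updates every matching entry, whereas an index lookup would only ever find the first occurrence.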
@@ -867,7 +865,7 @@ def modify_model_io_type(
 def modify_model_io_type(
     model, inference_input_type=dtypes.float32,
     inference_output_type=dtypes.float32):
-  """Modifies the input/output type of a tflite model.
+  """Modify the input/output type of a tflite model.
 
   Args:
     model: A tflite model.
@@ -879,7 +877,6 @@ def modify_model_io_type(
       (default tf.float32. If model output is int8 dequantized, it must be in
       {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
       it must be in {tf.float32, tf.int16}, else it must be tf.float32)
 
   Returns:
     A tflite model with modified input/output type.
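The remaining hunks only reword the modify_model_io_type docstring. At the public API level, the behavior it documents is normally requested through the converter's inference_input_type and inference_output_type options. The sketch below shows that public path; it is only assumed to exercise this helper, and the saved-model path, input shape, and calibration loop are placeholders, not part of this commit:

import tensorflow as tf


def representative_dataset():
  # Placeholder calibration data; real code yields samples shaped like the
  # model's actual inputs.
  for _ in range(100):
    yield [tf.random.uniform((1, 224, 224, 3))]


converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8  # docstring: {float32, int8, uint8}, or {float32, int16} for int16 models
tflite_model = converter.convert()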