Fix quant/dequant op pair removal logic to work correctly.

PiperOrigin-RevId: 360364811
Change-Id: I60cbda1f908535ca6ba1a5a86700b140767e2524
This commit is contained in:
Taehee Jeong 2021-03-01 23:55:43 -08:00 committed by TensorFlower Gardener
parent f09da909a8
commit a0825573fb

View File

@ -839,9 +839,8 @@ def _remove_redundant_quantize_ops(model):
# This is a requantize op, so write down its input tensor index.
if input_type != dtypes.float32 and output_type != dtypes.float32:
redundant_quant_tensors[op.inputs[0]] = op
elif (op.opcodeIndex in dequant_opcode_idxs and
op.outputs[0] in subgraph.outputs):
# Mark quant-dequant op pairs right before outputs to be removed.
if op.opcodeIndex in dequant_opcode_idxs and \
op.outputs[0] in subgraph.outputs:
output_dequant_tensors[op.inputs[0]] = op
# Remove all the quant ops which produce the redundant quant tensors.
@ -853,13 +852,12 @@ def _remove_redundant_quantize_ops(model):
requantize_op.inputs[0] = op.inputs[0]
operators.remove(op)
# Remove all the quant/dequant op pairs right before the outputs.
# Remove all the quant ops which connect to the output dequant op.
for op in all_quant_ops:
output_tensor_idx = op.outputs[0]
if output_tensor_idx in output_dequant_tensors:
dequant_op = output_dequant_tensors[output_tensor_idx]
output_idx = subgraph.outputs.index(dequant_op.outputs[0])
subgraph.outputs[output_idx] = op.inputs[0]
subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
operators.remove(op)
operators.remove(dequant_op)
@ -867,7 +865,7 @@ def _remove_redundant_quantize_ops(model):
def modify_model_io_type(
model, inference_input_type=dtypes.float32,
inference_output_type=dtypes.float32):
"""Modifies the input/output type of a tflite model.
"""Modify the input/output type of a tflite model.
Args:
model: A tflite model.
@ -879,7 +877,6 @@ def modify_model_io_type(
(default tf.float32. If model output is int8 dequantized, it must be in
{tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
it must be in {tf.float32, tf.int16}, else it must be tf.float32)
Returns:
A tflite model with modified input/output type.