Add dynamic_range_quantize to generated op_test infra.
Will add to all op_tests to get complete coverage in subsequent CLs. PiperOrigin-RevId: 316177819 Change-Id: I3fe9d13e7116aa849111a27ab38b4c1815ee82e2
This commit is contained in:
parent
e665a737f9
commit
7270ba4e6d
|
@ -32,6 +32,7 @@ def make_abs_tests(options):
|
||||||
test_parameters = [{
|
test_parameters = [{
|
||||||
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
|
"input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
|
||||||
[3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
|
[3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
|
||||||
|
"dynamic_range_quantize": [False, True]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
def build_graph(parameters):
|
def build_graph(parameters):
|
||||||
|
|
|
@ -103,10 +103,9 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
|
||||||
input_arrays = [x[0] for x in input_tensors]
|
input_arrays = [x[0] for x in input_tensors]
|
||||||
data_types = [zip_test_utils.TF_TYPE_INFO[x[2]][1] for x in input_tensors]
|
data_types = [zip_test_utils.TF_TYPE_INFO[x[2]][1] for x in input_tensors]
|
||||||
|
|
||||||
if test_params.get("fully_quantize", False):
|
fully_quantize = test_params.get("fully_quantize", False)
|
||||||
# Read the input range for the representative dataset from parameters.
|
dynamic_range_quantize = test_params.get("dynamic_range_quantize", False)
|
||||||
min_value, max_value = test_params.get("input_range", (-1, 1))
|
if dynamic_range_quantize or fully_quantize:
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile() as graphdef_file:
|
with tempfile.NamedTemporaryFile() as graphdef_file:
|
||||||
graphdef_file.write(graph_def_str)
|
graphdef_file.write(graph_def_str)
|
||||||
graphdef_file.flush()
|
graphdef_file.flush()
|
||||||
|
@ -115,32 +114,38 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs):
|
||||||
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
|
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
|
||||||
graphdef_file.name, input_arrays, output_tensors, input_shapes)
|
graphdef_file.name, input_arrays, output_tensors, input_shapes)
|
||||||
|
|
||||||
def representative_dataset(input_tensors):
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
calibration_inputs = []
|
|
||||||
for _, shape, _ in input_tensors:
|
|
||||||
if shape:
|
|
||||||
dims = [dim.value for dim in shape.dims]
|
|
||||||
calibration_inputs.append(
|
|
||||||
np.random.uniform(min_value, max_value,
|
|
||||||
tuple(dims)).astype(np.float32))
|
|
||||||
return calibration_inputs
|
|
||||||
|
|
||||||
def representative_dataset_gen():
|
if fully_quantize:
|
||||||
for _ in range(100):
|
# Read the input range for the representative dataset from parameters.
|
||||||
yield representative_dataset(input_tensors)
|
min_value, max_value = test_params.get("input_range", (-1, 1))
|
||||||
|
|
||||||
converter.target_spec.supported_ops = [
|
def representative_dataset(input_tensors):
|
||||||
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
|
calibration_inputs = []
|
||||||
]
|
for _, shape, _ in input_tensors:
|
||||||
converter.representative_dataset = representative_dataset_gen
|
if shape:
|
||||||
if extra_toco_options.inference_input_type:
|
dims = [dim.value for dim in shape.dims]
|
||||||
converter.inference_input_type = (
|
calibration_inputs.append(
|
||||||
extra_toco_options.inference_input_type)
|
np.random.uniform(min_value, max_value,
|
||||||
if extra_toco_options.inference_output_type:
|
tuple(dims)).astype(np.float32))
|
||||||
converter.inference_output_type = (
|
return calibration_inputs
|
||||||
extra_toco_options.inference_output_type)
|
|
||||||
else:
|
def representative_dataset_gen():
|
||||||
converter.inference_output_type = tf.int8
|
for _ in range(100):
|
||||||
|
yield representative_dataset(input_tensors)
|
||||||
|
|
||||||
|
converter.target_spec.supported_ops = [
|
||||||
|
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
|
||||||
|
]
|
||||||
|
converter.representative_dataset = representative_dataset_gen
|
||||||
|
if extra_toco_options.inference_input_type:
|
||||||
|
converter.inference_input_type = (
|
||||||
|
extra_toco_options.inference_input_type)
|
||||||
|
if extra_toco_options.inference_output_type:
|
||||||
|
converter.inference_output_type = (
|
||||||
|
extra_toco_options.inference_output_type)
|
||||||
|
else:
|
||||||
|
converter.inference_output_type = tf.int8
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
|
Loading…
Reference in New Issue