diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index fb4046731cb..afc761570ea 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -635,6 +635,8 @@ def gen_zip_test( flags += " --ignore_converter_errors --run_with_flex" elif conversion_mode == "forward-compat": flags += " --make_forward_compat_test" + elif conversion_mode == "mlir-quant": + flags += " --mlir_quantizer" if test_name.startswith(merged_test_model_name() + "_"): flags += flags_for_merged_test_models(test_name, conversion_mode) diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py index bc87fde9467..ed08373b534 100644 --- a/tensorflow/lite/testing/generate_examples.py +++ b/tensorflow/lite/testing/generate_examples.py @@ -91,6 +91,10 @@ parser.add_argument( help=("Comma-separated list of test set names to generate. " "If not specified, a test set is selected by parsing the name of " "'zip_to_output' file.")) +parser.add_argument( + "--mlir_quantizer", + action="store_true", + help=("Whether the new MLIR quantizer is being used.")) # Toco binary path provided by the generate rule. @@ -116,6 +120,7 @@ def main(unused_args): options.tflite_convert_function = toco_convert.toco_convert options.no_tests_limit = FLAGS.no_tests_limit options.no_conversion_report = FLAGS.no_conversion_report + options.mlir_quantizer = FLAGS.mlir_quantizer if FLAGS.test_sets: test_sets = FLAGS.test_sets.split(",") diff --git a/tensorflow/lite/testing/generate_examples_lib.py b/tensorflow/lite/testing/generate_examples_lib.py index 009a407d70e..bb7e9dbde9b 100644 --- a/tensorflow/lite/testing/generate_examples_lib.py +++ b/tensorflow/lite/testing/generate_examples_lib.py @@ -238,6 +238,7 @@ class Options(object): # TODO(juhoha): Separate the state from the options. self.multi_gen_state = None self.use_experimental_converter = False + self.mlir_quantizer = False def _prepare_dir(options): @@ -273,7 +274,10 @@ def generate_examples(options): else: # Remove suffixes to extract the test name from the output name. test_name = re.sub( - r"(_(|toco-flex|forward-compat|edgetpu))?\.zip$", "", out, count=1) + r"(_(|toco-flex|forward-compat|edgetpu|mlir-quant))?\.zip$", + "", + out, + count=1) test_function_name = "make_%s_tests" % test_name test_function = get_test_function(test_function_name) @@ -313,7 +317,10 @@ def generate_multi_set_examples(options, test_sets): # Remove suffix and set test_name to run proper test generation function. multi_gen_state.test_name = re.sub( - r"(_(|toco-flex|forward-compat))?$", "", test_name, count=1) + r"(_(|toco-flex|forward-compat|mlir-quant))?$", + "", + test_name, + count=1) # Set label base path to write test data files with proper path. multi_gen_state.label_base_path = os.path.join( os.path.dirname(zip_path), test_name + ".zip") diff --git a/tensorflow/lite/testing/toco_convert.py b/tensorflow/lite/testing/toco_convert.py index 48c19c49686..5216c1febbe 100644 --- a/tensorflow/lite/testing/toco_convert.py +++ b/tensorflow/lite/testing/toco_convert.py @@ -115,6 +115,7 @@ def toco_convert(options, graph_def, input_tensors, output_tensors, **kwargs): graphdef_file.name, input_arrays, output_tensors, input_shapes) converter.experimental_new_converter = options.use_experimental_converter + converter._experimental_new_quantizer = options.mlir_quantizer # pylint: disable=protected-access converter.optimizations = [tf.lite.Optimize.DEFAULT] if fully_quantize: diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index 1b4460461b6..e7ade88e787 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -368,6 +368,11 @@ def make_zip_of_tests(options, "fully_quantize", False) or param_dict.get("quant_16x8", False)): continue + # Skips the new quantizer tests when `fully_quantize` is set to false + # or it is not set. + if options.mlir_quantizer and not param_dict.get("fully_quantize", False): + continue + def generate_inputs_outputs(tflite_model_binary, min_value=0, max_value=255):