From d08e1a80eb9e1cda95806449682fddf963c03a87 Mon Sep 17 00:00:00 2001 From: Meghna Natraj Date: Wed, 2 Sep 2020 16:35:43 -0700 Subject: [PATCH] Reduce redundancy by replacing TFLite data types with TF data types PiperOrigin-RevId: 329812474 Change-Id: Ia2f02aea1e8b33f800c3c2241979f43b3be2ca10 --- .../lite/g3doc/r1/convert/cmdline_examples.md | 8 +-- .../g3doc/r1/convert/cmdline_reference.md | 26 ++++----- tensorflow/lite/python/convert.py | 3 +- tensorflow/lite/python/convert_test.py | 9 ++- tensorflow/lite/python/lite.py | 55 +++++++++---------- tensorflow/lite/python/lite_test.py | 45 ++++++++------- tensorflow/lite/python/lite_v2_test.py | 45 +++++++-------- tensorflow/lite/python/optimize/BUILD | 3 +- tensorflow/lite/python/optimize/calibrator.py | 5 +- .../lite/python/optimize/calibrator_test.py | 45 ++++++++------- tensorflow/lite/python/tflite_convert.py | 38 +++++++------ tensorflow/lite/python/util.py | 20 +++---- tensorflow/lite/python/util_test.py | 7 +-- .../model_coverage/model_coverage_lib.py | 4 +- tensorflow/lite/tools/optimize/python/BUILD | 2 +- .../modify_model_interface_constants.py | 8 +-- 16 files changed, 162 insertions(+), 161 deletions(-) diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md index e38c2b3e215..a2377412c8d 100644 --- a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md +++ b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md @@ -104,8 +104,8 @@ tflite_convert \ --std_dev_values=127.7 ``` -*If you're setting `--inference_type=QUANTIZED_UINT8` then update -`--mean_values=128` and `--std_dev_values=127`* +*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and +`--std_dev_values=127`* #### Convert a model with \"dummy-quantization\" into a quantized TensorFlow Lite model @@ -134,8 +134,8 @@ tflite_convert \ --default_ranges_max=6 ``` -*If you're setting `--inference_type=QUANTIZED_UINT8` then update -`--mean_values=128` and `--std_dev_values=127`* +*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and +`--std_dev_values=127`* #### Convert a model with select TensorFlow operators. diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md index 826bb7afdbb..386d9063f9f 100644 --- a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md +++ b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md @@ -63,8 +63,7 @@ based on index. has a shape of [2, 3] and "bar" has a shape of [4, 5, 6]. * `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats. These specify the (de-)quantization parameters of the input array, when it - is quantized. This is only needed if `inference_input_type` is `INT8` or - `QUANTIZED_UINT8`. + is quantized. Only needed if `inference_input_type` is `INT8` or `UINT8`. * The meaning of `mean_values` and `std_dev_values` is as follows: each quantized value in the quantized input array will be interpreted as a mathematical real number (i.e. as an input activation value) according @@ -75,12 +74,12 @@ based on index. the inference code according to the above formula, before proceeding with float inference. * When performing quantized inference (`inference_type` - is`INT8`or`QUANTIZED_UINT8`), no dequantization is performed by the - inference code. 
However, the quantization parameters of all arrays, - including those of the input arrays as specified - by`mean_value`and`std_dev_value`, determine the fixed-point multipliers - used in the quantized inference code.`mean_value` must be an integer - when performing quantized inference. + is`INT8`or`UINT8`), no dequantization is performed by the inference + code. However, the quantization parameters of all arrays, including + those of the input arrays as specified by`mean_value`and`std_dev_value`, + determine the fixed-point multipliers used in the quantized inference + code.`mean_value` must be an integer when performing quantized + inference. ## Transformation flags @@ -90,7 +89,7 @@ have. * `--inference_type`. Type: string. Default: `FLOAT`. Data type of all real-number arrays in the output file except for input arrays (defined by - `--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`. + `--inference_input_type`). Must be `{FLOAT, INT8, UINT8}`. This flag only impacts real-number arrays including float and quantized arrays. This excludes all other data types including plain integer arrays @@ -102,16 +101,15 @@ have. * If `INT8`, then real-numbers arrays will be quantized as int8 in the output file. If they were float in the input file, then they get quantized. - * If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as - uint8 in the output file. If they were float in the input file, then - they get quantized. + * If `UINT8`, then real-numbers arrays will be quantized as uint8 in the + output file. If they were float in the input file, then they get + quantized. * `--inference_input_type`. Type: string. Data type of a real-number input array in the output file. By default the `--inference_type` is used as type of all of the input arrays. Flag is primarily intended for generating a float-point graph with a quantized input array. A Dequantized operator is - added immediately after the input array. Must be `{FLOAT, INT8, - QUANTIZED_UINT8}`. + added immediately after the input array. Must be `{FLOAT, INT8, UINT8}`. The flag is typically used for vision models taking a bitmap as input but requiring floating-point inference. 
For such image models, the uint8 input diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 7562337d0b4..6f990873228 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation @@ -290,7 +291,7 @@ Alternative, use virtualenv.""") def build_toco_convert_protos(input_tensors, output_tensors, - inference_type=lite_constants.FLOAT, + inference_type=dtypes.float32, inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, input_shapes=None, diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index e3654217b3a..3cccce38669 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.lite.python import convert -from tensorflow.lite.python import lite_constants from tensorflow.lite.python import op_hint from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python.client import session @@ -59,7 +58,7 @@ class ConvertTest(test_util.TensorFlowTestCase): tflite_model = convert.toco_convert( sess.graph_def, [in_tensor], [out_tensor], - inference_type=lite_constants.QUANTIZED_UINT8, + inference_type=dtypes.uint8, quantized_input_stats=[(0., 1.)]) self.assertTrue(tflite_model) @@ -73,7 +72,7 @@ class ConvertTest(test_util.TensorFlowTestCase): tflite_model = convert.toco_convert_graph_def( sess.graph_def, [("input", [1, 16, 16, 3])], ["add"], enable_mlir_converter=False, - inference_type=lite_constants.FLOAT) + inference_type=dtypes.float32) self.assertTrue(tflite_model) # Check values from converted model. @@ -111,7 +110,7 @@ class ConvertTest(test_util.TensorFlowTestCase): input_arrays_map, output_arrays, enable_mlir_converter=False, - inference_type=lite_constants.QUANTIZED_UINT8, + inference_type=dtypes.uint8, quantized_input_stats=[(0., 1.), (0., 1.)]) self.assertTrue(tflite_model) @@ -158,7 +157,7 @@ class ConvertTest(test_util.TensorFlowTestCase): input_arrays_map, output_arrays, enable_mlir_converter=False, - inference_type=lite_constants.QUANTIZED_UINT8) + inference_type=dtypes.uint8) self.assertEqual( "std_dev and mean must be defined when inference_type or " "inference_input_type is QUANTIZED_UINT8 or INT8.", diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 4a0ae9d4c9e..272ce18c010 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -161,9 +161,8 @@ class TargetSpec(object): supported_ops: Experimental flag, subject to change. Set of OpsSet options supported by the device. (default set([OpsSet.TFLITE_BUILTINS])) supported_types: List of types for constant values on the target device. - Supported values are types exported by lite.constants. Frequently, an - optimization choice is driven by the most compact (i.e. smallest) type in - this list (default [constants.FLOAT]) + Frequently, an optimization choice is driven by the most compact + (i.e. 
smallest) type in this list (default [tf.float32]) """ def __init__(self, supported_ops=None, supported_types=None): @@ -198,7 +197,7 @@ class QuantizationMode(object): return (self._any_optimization_enabled() and not self._is_int16x8_target_required() and self._representative_dataset is not None and - self._smallest_supported_type() == constants.INT8) + self._smallest_supported_type() == _dtypes.int8) def is_post_training_integer_quantize_8(self): """Post training integer 8 quantization.""" @@ -239,12 +238,12 @@ class QuantizationMode(object): return (self._any_optimization_enabled() and self._representative_dataset is None and not self.contains_training_quant_op() and - self._smallest_supported_type() == constants.INT8) + self._smallest_supported_type() == _dtypes.int8) def post_training_fp16(self): """Post training fp16 quantize.""" return (self._any_optimization_enabled() and - self._smallest_supported_type() == constants.FLOAT16) + self._smallest_supported_type() == _dtypes.float16) def fp32_execution(self): """If none of the above are true.""" @@ -257,36 +256,36 @@ class QuantizationMode(object): self.post_training_fp16()) def activations_type(self): - return constants.INT16 if self._is_int16x8_target_required() \ - else constants.INT8 + return _dtypes.int16 if self._is_int16x8_target_required() \ + else _dtypes.int8 def converter_flags(self, inference_ty=None, inference_input_ty=None): """Flags to the converter.""" if self.is_post_training_integer_quantize(): # The inference_input_type is for the quantizer, then we need to keep the # converter inference_input_type to float. - inference_input_ty = constants.FLOAT + inference_input_ty = _dtypes.float32 if self.training_time_int8_allow_float(): return { "inference_type": inference_ty if inference_ty else \ self.activations_type(), "inference_input_type": - inference_input_ty if inference_input_ty else constants.FLOAT, + inference_input_ty if inference_input_ty else _dtypes.float32, "post_training_quantize": False, # disable dynamic range quantization "quantize_to_float16": False # disable float16 quantization } elif self.post_training_dynamic_range_int8(): return { - "inference_type": constants.FLOAT, - "inference_input_type": constants.FLOAT, + "inference_type": _dtypes.float32, + "inference_input_type": _dtypes.float32, "post_training_quantize": True, # enable dynamic range quantization "quantize_to_float16": False # disable float16 quantization } elif self.post_training_fp16(): return { - "inference_type": constants.FLOAT, - "inference_input_type": constants.FLOAT, + "inference_type": _dtypes.float32, + "inference_input_type": _dtypes.float32, "post_training_quantize": True, "quantize_to_float16": True # enable float16 quantization } @@ -294,7 +293,7 @@ class QuantizationMode(object): # Note this might still trigger (uint8) quantization to be compatible with # TOCO. 
return { - "inference_type": inference_ty if inference_ty else constants.FLOAT, + "inference_type": inference_ty if inference_ty else _dtypes.float32, "inference_input_type": inference_input_ty, "post_training_quantize": False, # enable dynamic range quantization "quantize_to_float16": False # disable float16 quantization @@ -303,8 +302,8 @@ class QuantizationMode(object): def quantizer_flags(self, input_ty=None, output_ty=None): """Default flags to the TFMOT quantizer.""" - inference_input_type = input_ty if input_ty else constants.FLOAT - inference_output_type = output_ty if output_ty else constants.FLOAT + inference_input_type = input_ty if input_ty else _dtypes.float32 + inference_output_type = output_ty if output_ty else _dtypes.float32 if self.post_training_int8_no_float() \ or self.post_training_int16x8_no_float(): @@ -326,7 +325,7 @@ class QuantizationMode(object): return False, None def flags_modify_model_io_type( - self, input_type=constants.FLOAT, output_type=constants.FLOAT): + self, input_type=_dtypes.float32, output_type=_dtypes.float32): """Flags for modifying the input and output type of a tflite model.""" is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0] is_training_time_only_quantize = self.training_time_int8_allow_float() and \ @@ -350,7 +349,7 @@ class QuantizationMode(object): return if self._target_spec.supported_types and (self._smallest_supported_type() != - constants.INT8): + _dtypes.int8): raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported " "type to be INT8.") @@ -369,7 +368,7 @@ class QuantizationMode(object): def _is_int8_target_required(self): return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set( self._target_spec.supported_ops) or - set(self._target_spec.supported_types) == set([constants.INT8])) + set(self._target_spec.supported_types) == set([_dtypes.int8])) def _is_int16x8_target_required(self): return bool( @@ -394,7 +393,7 @@ class QuantizationMode(object): return min(self._target_spec.supported_types, key=lambda x: x.size) else: # The default smallest supported type is INT8. - return constants.INT8 + return _dtypes.int8 def contains_training_quant_op(self): """Checks if the graph contains any training-time quantization ops.""" @@ -553,18 +552,18 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): def __init__(self): """Constructor for TFLiteConverter.""" super(TFLiteConverterBaseV2, self).__init__() - self.inference_input_type = constants.FLOAT - self.inference_output_type = constants.FLOAT + self.inference_input_type = _dtypes.float32 + self.inference_output_type = _dtypes.float32 def _validate_inference_input_output_types(self, quant_mode): """Validate inference_input_type and inference_output_type flags.""" - default_types = [constants.FLOAT] + default_types = [_dtypes.float32] # We support integer input/output for integer quantized models only. if quant_mode.training_time_int8_allow_float(): if quant_mode.is_post_training_integer_quantize_16x8(): - all_types = default_types + [constants.INT16] + all_types = default_types + [_dtypes.int16] else: - all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8] + all_types = default_types + [_dtypes.int8, _dtypes.uint8] if self.inference_input_type not in all_types or \ self.inference_output_type not in all_types: all_types_names = ["tf." + t.name for t in all_types] @@ -1103,7 +1102,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): graph debug info for a set of nodes from the `graph_def`. 
""" super(TFLiteConverterBaseV1, self).__init__() - self.inference_type = constants.FLOAT + self.inference_type = _dtypes.float32 self.inference_input_type = None self.inference_output_type = None self.output_format = constants.TFLITE @@ -1150,7 +1149,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): def _validate_quantized_input_stats(self, converter_kwargs, calibrate): """Ensure the `quantized_input_stats` flag is provided if required.""" - quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8}) + quantized_types = frozenset({_dtypes.int8, _dtypes.uint8}) requires_quantized_input_stats = ( (converter_kwargs["inference_type"] in quantized_types or diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index d17fc94cd20..65e7b572a85 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -155,8 +155,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_input_type = lite_constants.QUANTIZED_UINT8 - converter.inference_type = lite_constants.FLOAT + converter.inference_input_type = dtypes.uint8 + converter.inference_type = dtypes.float32 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev tflite_model = converter.convert() self.assertIsNotNone(tflite_model) @@ -788,8 +788,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): quantized_converter = lite.TFLiteConverter.from_session( sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter - quantized_converter.inference_input_type = lite_constants.INT8 - quantized_converter.inference_output_type = lite_constants.INT8 + quantized_converter.inference_input_type = dtypes.int8 + quantized_converter.inference_output_type = dtypes.int8 quantized_converter.optimizations = [lite.Optimize.DEFAULT] quantized_converter.representative_dataset = calibration_gen quantized_tflite_model = quantized_converter.convert() @@ -832,7 +832,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] # Restricting to int8 type only - quantized_converter.target_spec.supported_types = [lite.constants.INT8] + quantized_converter.target_spec.supported_types = [dtypes.int8] # A representative dataset is required for full fixed point quantization. with self.assertRaises(ValueError) as error: quantized_converter.convert() @@ -857,7 +857,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1, in_tensor_2], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = { 'inputA': (0., 1.), 'inputB': (0., 1.) @@ -898,7 +898,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. 
converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev converter.default_ranges_stats = (0, 6) # min, max tflite_model = converter.convert() @@ -954,16 +954,15 @@ class FromSessionTest(TestModels, parameterized.TestCase): interpreter.allocate_tensors() self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name) self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.FLOAT) + dtypes.float32) # Convert model to quantized version quantized_converter = lite.TFLiteConverter.from_session( sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16] + quantized_converter.target_spec.supported_types = [dtypes.float16] if include_int8: - quantized_converter.target_spec.supported_types.append( - lite.constants.INT8) + quantized_converter.target_spec.supported_types.append(dtypes.int8) if use_rep_data: quantized_converter.representative_dataset = calibration_gen @@ -984,11 +983,11 @@ class FromSessionTest(TestModels, parameterized.TestCase): if is_float16_quantized: # Verify that bias constant is float16 type. self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.FLOAT16) + dtypes.float16) elif is_post_training_quantized: # Verify that bias constants is int32 type. self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.INT32) + dtypes.int32) else: raise ValueError('Invalid test options.') @@ -1005,7 +1004,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16] + quantized_converter.target_spec.supported_types = [dtypes.float16] # Specify only int8 builtin ops quantized_converter.target_spec.supported_ops = [ lite.OpsSet.TFLITE_BUILTINS_INT8 @@ -1017,8 +1016,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): str(error.exception)) @parameterized.named_parameters( - ('InferenceType_INT8', lite_constants.INT8), - ('InferenceType_UINT8', lite_constants.QUANTIZED_UINT8)) + ('InferenceType_INT8', dtypes.int8), + ('InferenceType_UINT8', dtypes.uint8)) def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type): with ops.Graph().as_default(): in_tensor = array_ops.placeholder( @@ -1039,7 +1038,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): 'flag is set to tf.uint8 or tf.int8.', str(error.exception)) with self.assertRaises(ValueError) as error: - quantized_converter.inference_type = lite_constants.FLOAT + quantized_converter.inference_type = dtypes.float32 quantized_converter.inference_input_type = quantized_type quantized_converter.convert() self.assertEqual( @@ -1070,7 +1069,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1, in_tensor_2], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev with self.assertRaises(ValueError) as error: converter.convert() @@ -1091,9 
+1090,9 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) # extra flags to trigger training time quantization conversion - converter.inference_type = lite_constants.INT8 - converter.inference_input_type = lite_constants.FLOAT - converter.inference_output_type = lite_constants.FLOAT + converter.inference_type = dtypes.int8 + converter.inference_input_type = dtypes.float32 + converter.inference_output_type = dtypes.float32 input_arrays = converter.get_input_arrays() converter.quantized_input_stats = { input_arrays[0]: (0., 1.) @@ -1255,7 +1254,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev tflite_model = converter.convert() self.assertIsNotNone(tflite_model) @@ -2334,7 +2333,7 @@ class DefaultConverterAttrsTest(LiteTest): self.assertEqual(converter.output_format, lite_constants.TFLITE) # Assert the default inference type is float. - self.assertEqual(converter.inference_type, lite_constants.FLOAT) + self.assertEqual(converter.inference_type, dtypes.float32) # Assert the default inference type overrides are None. self.assertIsNone(converter.inference_input_type) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 170db3f6bce..c8ba4d57298 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -32,6 +32,7 @@ from tensorflow.lite.python import lite_v2_test_util from tensorflow.lite.python.convert import mlir_quantize from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.layers import recurrent @@ -74,9 +75,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertEqual(expected_value.numpy(), actual_value) @parameterized.named_parameters( - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8), - ('_INT16InputOutput', lite.constants.INT16)) + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8), + ('_INT16InputOutput', dtypes.int16)) @test_util.run_v2_only def testInvalidFloat(self, inference_input_output_type): root = self._getSimpleVariableModel() @@ -194,9 +195,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) @parameterized.named_parameters( - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8), - ('_INT16InputOutput', lite.constants.INT16)) + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8), + ('_INT16InputOutput', dtypes.int16)) @test_util.run_v2_only def testInvalidPostTrainingDynamicRangeQuantization( self, inference_input_output_type): @@ -219,18 +220,18 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): 'must be tf.float32.', str(error.exception)) @parameterized.named_parameters( - ('_Default', False, False, lite.constants.FLOAT), - ('_INT8InputOutput', False, False, lite.constants.INT8), - ('_UINT8InputOutput', False, False, 
lite.constants.QUANTIZED_UINT8), - ('_INT16Quantize', False, True, lite.constants.FLOAT), - ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16), - ('_IntOnly', True, False, lite.constants.FLOAT), - ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8), + ('_Default', False, False, dtypes.float32), + ('_INT8InputOutput', False, False, dtypes.int8), + ('_UINT8InputOutput', False, False, dtypes.uint8), + ('_INT16Quantize', False, True, dtypes.float32), + ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16), + ('_IntOnly', True, False, dtypes.float32), + ('_IntOnly_INT8InputOutput', True, False, dtypes.int8), ('_IntOnly_UINT8InputOutput', True, False, - lite.constants.QUANTIZED_UINT8), - ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT), + dtypes.uint8), + ('_IntOnly_INT16Quantize', True, True, dtypes.float32), ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, - lite.constants.INT16)) + dtypes.int16)) def testIntegerQuantization(self, is_int_only, is_int16_quantize, inference_input_output_type): func, calibration_gen = self._getIntegerQuantizeModel() @@ -281,7 +282,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertLess(len(quantized_tflite_model), len(tflite_model)) @parameterized.named_parameters( - ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8)) + ('_INT16Quantize_INT8InputOutput', True, dtypes.int8)) def testInvalidIntegerQuantization(self, is_int16_quantize, inference_input_output_type): func, calibration_gen = self._getIntegerQuantizeModel() @@ -297,8 +298,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): lite.OpsSet.TFLITE_BUILTINS ] with self.assertRaises(ValueError) as error: - quantized_converter.inference_input_type = lite.constants.INT8 - quantized_converter.inference_output_type = lite.constants.INT8 + quantized_converter.inference_input_type = dtypes.int8 + quantized_converter.inference_output_type = dtypes.int8 quantized_converter.convert() self.assertEqual( "The inference_input_type and inference_output_type " @@ -377,9 +378,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): return tf.keras.Sequential(QLinear(3, input_shape=(2,))) @parameterized.named_parameters( - ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT), - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8)) + ('_DefaultFLOAT32InputOutput', dtypes.float32), + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8)) @test_util.run_v2_only def testTrainingTimeQuantization(self, inference_input_output_type): model = self._getTrainingTimeQuantizedModel() diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index 1a0d3db3b73..b921fc45cde 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -50,6 +50,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":_pywrap_tensorflow_lite_calibration_wrapper", # buildcleaner: keep + "//tensorflow/python:dtypes", "//tensorflow/python:util", "//third_party/py/numpy", ], @@ -67,8 +68,8 @@ py_test( tags = ["no_oss"], deps = [ ":calibrator", - "//tensorflow/lite/python:lite_constants", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//third_party/py/numpy", diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index 2b08ec690ff..dfef8b9cb79 100644 --- 
a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -18,8 +18,9 @@ from __future__ import division from __future__ import print_function import numpy as np + +from tensorflow.python.framework import dtypes from tensorflow.python.util.lazy_loader import LazyLoader -from tensorflow.lite.python import lite_constants # Lazy load since some of the performance benchmark skylark rules # break dependencies. Must use double quotes to match code internal rewrite @@ -60,7 +61,7 @@ class Calibrator(object): input_type, output_type, allow_float, - activations_type=lite_constants.INT8, + activations_type=dtypes.int8, resize_input=True): """Calibrates the model with specified generator and then quantizes it. diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index 371b3514ca3..a9ab12c6095 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -23,8 +23,8 @@ from absl.testing import parameterized import numpy as np from six.moves import range -from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.optimize import calibrator as _calibrator +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -34,9 +34,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8', constants.INT8), + ('UseActivationTypeInt8', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16', constants.INT16)) + ('UseActivationTypeInt16', dtypes.int16)) def test_calibration_with_quantization(self, activations_type): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') @@ -49,16 +49,17 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, False, + dtypes.float32, + dtypes.float32, + False, activations_type) self.assertIsNotNone(quantized_model) @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8', constants.INT8), + ('UseActivationTypeInt8', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16', constants.INT16)) + ('UseActivationTypeInt16', dtypes.int16)) def test_calibration_with_quantization_allow_float(self, activations_type): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') @@ -71,8 +72,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, True, + dtypes.float32, + dtypes.float32, + True, activations_type) self.assertIsNotNone(quantized_model) @@ -88,7 +90,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize_single( - input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd') + input_gen, dtypes.float32, dtypes.float32, True, 'conv2d_8/BiasAdd') self.assertIsNotNone(quantized_model) def test_calibration_with_string_input(self): @@ -103,14 +105,14 @@ 
class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.array(u'Test' + str(i))] quantized_model = quantizer.calibrate_and_quantize_single( - input_gen, constants.FLOAT, constants.FLOAT, True, 'Identity') + input_gen, dtypes.float32, dtypes.float32, True, 'Identity') self.assertIsNotNone(quantized_model) @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8), + ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16)) + ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16)) def test_calibration_with_quantization_multiple_inputs( self, activations_type): # Load multi add model from test data. @@ -126,8 +128,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, False, + dtypes.float32, + dtypes.float32, + False, activations_type) self.assertIsNotNone(quantized_model) @@ -148,8 +151,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield i with self.assertRaises(RuntimeError): - quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT, - constants.FLOAT, False) + quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32, + dtypes.float32, False) def test_invalid_shape_calibrator_gen(self): model_path = resource_loader.get_path_to_datafile( @@ -163,8 +166,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)] with self.assertRaisesRegex(ValueError, 'Size mismatch'): - quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, constants.INT8, + quantizer.calibrate_and_quantize(input_gen, dtypes.float32, + dtypes.float32, False, dtypes.int8, False) def test_invalid_type_calibrator_gen(self): @@ -179,8 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)] with self.assertRaises(ValueError): - quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, constants.INT8) + quantizer.calibrate_and_quantize(input_gen, dtypes.float32, + dtypes.float32, False, dtypes.int8) def test_calibration(self): model_path = resource_loader.get_path_to_datafile( diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index b5eb66eedc4..3e456a4fc9f 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -28,11 +28,11 @@ import six from six.moves import zip from tensorflow.lite.python import lite -from tensorflow.lite.python import lite_constants from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco.logging import gen_html from tensorflow.python import keras from tensorflow.python import tf2 +from tensorflow.python.framework import dtypes from tensorflow.python.platform import app @@ -62,13 +62,13 @@ def _parse_inference_type(value, flag): ValueError: Unsupported value. """ if value == "FLOAT": - return lite_constants.FLOAT - if value == "QUANTIZED_UINT8": - return lite_constants.QUANTIZED_UINT8 + return dtypes.float32 if value == "INT8": - return lite_constants.INT8 - raise ValueError("Unsupported value for --{0}. 
Only FLOAT and " - "QUANTIZED_UINT8 are supported.".format(flag)) + return dtypes.int8 + if value == "UINT8" or value == "QUANTIZED_UINT8": + return dtypes.uint8 + raise ValueError("Unsupported value for --{0}. Value must be in " + "(FLOAT, INT8, UINT8)".format(flag)) def _get_tflite_converter(flags): @@ -146,10 +146,10 @@ def _convert_tf1_model(flags): # In quantized inference, mean_value has to be integer so that the real # value 0.0 is exactly representable. - if converter.inference_type == lite_constants.QUANTIZED_UINT8: - mean_values = _parse_array(flags.mean_values, type_fn=int) - else: + if converter.inference_type == dtypes.float32: mean_values = _parse_array(flags.mean_values, type_fn=float) + else: + mean_values = _parse_array(flags.mean_values, type_fn=int) quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): @@ -189,13 +189,13 @@ def _convert_tf1_model(flags): if flags.post_training_quantize: converter.optimizations = [lite.Optimize.DEFAULT] - if converter.inference_type == lite_constants.QUANTIZED_UINT8: + if converter.inference_type != dtypes.float32: print("--post_training_quantize quantizes a graph of inference_type " - "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.") - converter.inference_type = lite_constants.FLOAT + "FLOAT. Overriding inference type to FLOAT.") + converter.inference_type = dtypes.float32 if flags.quantize_to_float16: - converter.target_spec.supported_types = [lite.constants.FLOAT16] + converter.target_spec.supported_types = [dtypes.float16] if not flags.post_training_quantize: print("--quantize_to_float16 will only take effect with the " "--post_training_quantize flag enabled.") @@ -354,14 +354,16 @@ def _get_tf1_flags(parser): parser.add_argument( "--inference_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], - help="Target data type of real-number arrays in the output file.") + default="FLOAT", + help=("Target data type of real-number arrays in the output file. " + "Must be either FLOAT, INT8 or UINT8.")) parser.add_argument( "--inference_input_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], + default="FLOAT", help=("Target data type of real-number input arrays. Allows for a " - "different type for input arrays in the case of quantization.")) + "different type for input arrays in the case of quantization. " + "Must be either FLOAT, INT8 or UINT8.")) # Input and output arrays flags. 
parser.add_argument( diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index 79d2775d1dc..c99d0ec159c 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -31,7 +31,6 @@ import flatbuffers from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import graph_debug_info_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 -from tensorflow.lite.python import lite_constants as _lite_constants from tensorflow.lite.python import schema_py_generated as schema_fb from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes @@ -77,8 +76,7 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = { _TFLITE_FILE_IDENTIFIER = b"TFL3" -_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (_lite_constants.FLOAT, _lite_constants.INT8, - _lite_constants.QUANTIZED_UINT8) +_TFLITE_MODEL_INPUT_OUTPUT_TYPES = (dtypes.float32, dtypes.int8, dtypes.uint8) def convert_dtype_to_tflite_type(tf_dtype): @@ -684,8 +682,8 @@ def _validate_and_find_int8_quantized_inputs_outputs(model): def modify_integer_quantized_model_io_type( - model, inference_input_type=_lite_constants.FLOAT, - inference_output_type=_lite_constants.FLOAT): + model, inference_input_type=dtypes.float32, + inference_output_type=dtypes.float32): """Modify the float input/output type of an integer quantized model. Args: @@ -705,8 +703,8 @@ def modify_integer_quantized_model_io_type( """ # Return if input and output types default to float - if inference_input_type == _lite_constants.FLOAT and \ - inference_output_type == _lite_constants.FLOAT: + if inference_input_type == dtypes.float32 and \ + inference_output_type == dtypes.float32: return model # Validate input and output types @@ -738,7 +736,7 @@ def modify_integer_quantized_model_io_type( remove_tensors_idxs = set() # Modify model input type - if inference_input_type == _lite_constants.QUANTIZED_UINT8: + if inference_input_type == dtypes.uint8: # Change quant op (float to int8) to quant op (uint8 to int8) for op in input_quant_ops: int8_quantization = tensors[op.outputs[0]].quantization @@ -747,7 +745,7 @@ def modify_integer_quantized_model_io_type( uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] tensors[op.inputs[0]].quantization = uint8_quantization tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8 - elif inference_input_type == _lite_constants.INT8: + elif inference_input_type == dtypes.int8: # Remove the inputs and the quant operator for op in input_quant_ops: subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0] @@ -755,7 +753,7 @@ def modify_integer_quantized_model_io_type( operators.remove(op) # Modify model output type - if inference_output_type == _lite_constants.QUANTIZED_UINT8: + if inference_output_type == dtypes.uint8: # Change dequant op (int8 to float) to quant op (int8 to uint8) for op in output_dequant_ops: op.opcodeIndex = input_quant_ops[0].opcodeIndex @@ -765,7 +763,7 @@ def modify_integer_quantized_model_io_type( uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128] tensors[op.outputs[0]].quantization = uint8_quantization tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8 - elif inference_output_type == _lite_constants.INT8: + elif inference_output_type == dtypes.int8: # Remove the outputs and the dequant operator for op in output_dequant_ops: subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0] diff --git a/tensorflow/lite/python/util_test.py 
b/tensorflow/lite/python/util_test.py index 820cda4c7d6..60057cddbb1 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -24,7 +24,6 @@ import numpy as np from six.moves import range import tensorflow as tf -from tensorflow.lite.python import lite_constants from tensorflow.lite.python import util from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.client import session @@ -292,9 +291,9 @@ def _test_param_modify_integer_model_io_type(): # "DuringTraining": False, } map_types = { - "": lite_constants.FLOAT, - "INT8": lite_constants.INT8, - "UINT8": lite_constants.QUANTIZED_UINT8 + "": dtypes.float32, + "INT8": dtypes.int8, + "UINT8": dtypes.uint8, } for k1, v1 in map_model_type.items(): for k2, v2 in map_types.items(): diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index d9cd6883a8d..7825dfb560a 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -28,11 +28,11 @@ from google.protobuf.message import DecodeError from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.lite.python import convert_saved_model as _convert_saved_model from tensorflow.lite.python import lite as _lite -from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python import util as _util from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.keras.preprocessing import image @@ -97,7 +97,7 @@ def _convert(converter, **kwargs): if "post_training_quantize" in kwargs: converter.optimizations = [_lite.Optimize.DEFAULT] if kwargs.get("quantize_to_float16", False): - converter.target_spec.supported_types = [constants.FLOAT16] + converter.target_spec.supported_types = [dtypes.float16] if kwargs.get("post_training_quantize_16x8", False): input_size = kwargs.get("model_input_size") diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD index 050c0008924..34f57cbecaf 100644 --- a/tensorflow/lite/tools/optimize/python/BUILD +++ b/tensorflow/lite/tools/optimize/python/BUILD @@ -52,7 +52,7 @@ py_library( name = "modify_model_interface_constants", srcs = ["modify_model_interface_constants.py"], srcs_version = "PY3", - deps = ["//tensorflow/lite/python:lite_constants"], + deps = ["//tensorflow/python:dtypes"], ) pybind_extension( diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py index cbe1aa92022..f7c7cc60d5c 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py @@ -19,12 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.lite.python import lite_constants +from tensorflow.python.framework import dtypes STR_TO_TFLITE_TYPES = { - 'INT8': lite_constants.INT8, - 'INT16': lite_constants.INT16, - 'UINT8': lite_constants.QUANTIZED_UINT8 + 'INT8': dtypes.int8, + 'UINT8': dtypes.uint8, + 'INT16': 
dtypes.int16, } TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}
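
For callers of the Python converter API, the visible effect of this change is that plain TensorFlow dtypes (`tf.float32`, `tf.int8`, `tf.uint8`, `tf.int16`, `tf.float16`) are accepted everywhere `lite.constants` values were previously passed. The sketch below illustrates a post-training integer quantization flow using the new dtype spelling; the saved-model path, input shape, and calibration generator are placeholders for illustration and are not part of this patch.

```python
import tensorflow as tf


def representative_dataset():
  # Placeholder calibration data; substitute real samples for your model.
  for _ in range(10):
    yield [tf.random.uniform(shape=(1, 224, 224, 3), dtype=tf.float32)]


# Hypothetical saved-model path, used only to make the sketch concrete.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# Plain TF dtypes replace the old lite.constants values,
# e.g. tf.uint8 instead of lite.constants.QUANTIZED_UINT8.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model = converter.convert()
```

On the command line, `tflite_convert` maps both `UINT8` and the legacy `QUANTIZED_UINT8` spelling of `--inference_type`/`--inference_input_type` to `tf.uint8`, so existing invocations continue to work.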