Reduce redundancy by replacing TFLite data types with TF data types
PiperOrigin-RevId: 337004382 Change-Id: I48b5734781527138705bdda5e6ea80120e3a33a1
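Note: the renames in the hunks below assume a one-to-one correspondence between the removed `lite_constants` symbols and existing TensorFlow dtypes. A minimal sketch of that mapping, inferred from this diff (the dictionary name is ours, not an exported table):

```python
# Mapping assumed throughout this change (inferred from the hunks below).
# The lite_constants symbols were thin aliases of these tf.DType objects,
# so the substitution is intended as a pure rename, not a behavior change.
from tensorflow.python.framework import dtypes

_LITE_CONSTANT_TO_DTYPE = {
    "FLOAT": dtypes.float32,
    "FLOAT16": dtypes.float16,
    "INT8": dtypes.int8,
    "INT16": dtypes.int16,
    "INT32": dtypes.int32,
    "QUANTIZED_UINT8": dtypes.uint8,
}
```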
parent e485bb64c5
commit 2396803aa1
@@ -104,8 +104,8 @@ tflite_convert \
  --std_dev_values=127.7
```

*If you're setting `--inference_type=QUANTIZED_UINT8` then update
`--mean_values=128` and `--std_dev_values=127`*
*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
`--std_dev_values=127`*

#### Convert a model with "dummy-quantization" into a quantized TensorFlow Lite model

@@ -134,8 +134,8 @@ tflite_convert \
  --default_ranges_max=6
```

*If you're setting `--inference_type=QUANTIZED_UINT8` then update
`--mean_values=128` and `--std_dev_values=127`*
*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and
`--std_dev_values=127`*

#### Convert a model with select TensorFlow operators.

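Side note on the example values above: with `--mean_values=128` and `--std_dev_values=127`, a quantized uint8 input is mapped into roughly the [-1, 1] real range via real_value = (quantized_value - mean_value) / std_dev_value. A small illustrative sketch (the helper name is ours, not part of the converter):

```python
import numpy as np

def dequantize(quantized_value, mean_value=128.0, std_dev_value=127.0):
  # real_value = (quantized_value - mean_value) / std_dev_value
  return (np.float32(quantized_value) - mean_value) / std_dev_value

print(dequantize(0))    # ~ -1.008: smallest representable real value
print(dequantize(255))  # 1.0: largest representable real value
```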
@@ -63,8 +63,7 @@ based on index.
        has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].

*   `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
    These specify the (de-)quantization parameters of the input array, when it
    is quantized. This is only needed if `inference_input_type` is `INT8` or
    `QUANTIZED_UINT8`.
    is quantized. Only needed if `inference_input_type` is `INT8` or `UINT8`.
    *   The meaning of `mean_values` and `std_dev_values` is as follows: each
        quantized value in the quantized input array will be interpreted as a
        mathematical real number (i.e. as an input activation value) according

@@ -75,12 +74,12 @@ based on index.
        the inference code according to the above formula, before proceeding
        with float inference.
    *   When performing quantized inference (`inference_type`
        is`INT8`or`QUANTIZED_UINT8`), no dequantization is performed by the
        inference code. However, the quantization parameters of all arrays,
        including those of the input arrays as specified
        by`mean_value`and`std_dev_value`, determine the fixed-point multipliers
        used in the quantized inference code.`mean_value` must be an integer
        when performing quantized inference.
        is`INT8`or`UINT8`), no dequantization is performed by the inference
        code. However, the quantization parameters of all arrays, including
        those of the input arrays as specified by`mean_value`and`std_dev_value`,
        determine the fixed-point multipliers used in the quantized inference
        code.`mean_value` must be an integer when performing quantized
        inference.

## Transformation flags

@@ -90,7 +89,7 @@ have.

*   `--inference_type`. Type: string. Default: `FLOAT`. Data type of all
    real-number arrays in the output file except for input arrays (defined by
    `--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`.
    `--inference_input_type`). Must be `{FLOAT, INT8, UINT8}`.

    This flag only impacts real-number arrays including float and quantized
    arrays. This excludes all other data types including plain integer arrays

@@ -102,16 +101,15 @@ have.
    *   If `INT8`, then real-numbers arrays will be quantized as int8 in the
        output file. If they were float in the input file, then they get
        quantized.
    *   If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as
        uint8 in the output file. If they were float in the input file, then
        they get quantized.
    *   If `UINT8`, then real-numbers arrays will be quantized as uint8 in the
        output file. If they were float in the input file, then they get
        quantized.

*   `--inference_input_type`. Type: string. Data type of a real-number input
    array in the output file. By default the `--inference_type` is used as type
    of all of the input arrays. Flag is primarily intended for generating a
    float-point graph with a quantized input array. A Dequantized operator is
    added immediately after the input array. Must be `{FLOAT, INT8,
    QUANTIZED_UINT8}`.
    added immediately after the input array. Must be `{FLOAT, INT8, UINT8}`.

    The flag is typically used for vision models taking a bitmap as input but
    requiring floating-point inference. For such image models, the uint8 input

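For readers wiring up these flags: given the real-valued range an input array should cover, a `(mean_values, std_dev_values)` pair for a uint8 input can be derived as below. This is an illustrative helper built from the formula quoted above, not converter API; the documentation's 128/127 example rounds the exact values to integers, since `mean_value` must be an integer for quantized inference.

```python
def uint8_quantization_params(real_min, real_max):
  """Returns (mean_value, std_dev_value) mapping [real_min, real_max] onto [0, 255]."""
  # From real_value = (quantized_value - mean_value) / std_dev_value, i.e.
  # quantized_value = real_value * std_dev_value + mean_value.
  std_dev_value = 255.0 / (real_max - real_min)
  mean_value = -real_min * std_dev_value
  return mean_value, std_dev_value

print(uint8_quantization_params(0.0, 1.0))   # (0.0, 255.0)
print(uint8_quantization_params(-1.0, 1.0))  # (127.5, 127.5), documented as 128/127
```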
@@ -99,7 +99,7 @@ py_test(
    ],
    python_version = "PY3",
    # Increased thread count for reducing timeout failures.
    shard_count = 4,
    shard_count = 10,
    srcs_version = "PY2AND3",
    tags = [
        "no_oss",

@@ -109,9 +109,26 @@ py_test(
        "notsan",  # b/160824139
    ],
    deps = [
        ":convert",
        ":tflite_convert",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:tf2",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/keras",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/training:training_util",
        "//tensorflow/python/training/tracking",
        "//third_party/py/numpy",
    ],
)

@@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco
from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation

@@ -301,7 +302,7 @@ Alternative, use virtualenv.""")
        pass


def build_toco_flags(inference_type=lite_constants.FLOAT,
def build_toco_flags(inference_type=dtypes.float32,
                     inference_input_type=None,
                     input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                     output_format=lite_constants.TFLITE,

@@ -352,7 +353,7 @@ def build_toco_flags(inference_type=lite_constants.FLOAT,

def build_toco_convert_protos(input_tensors,
                              output_tensors,
                              inference_type=lite_constants.FLOAT,
                              inference_type=dtypes.float32,
                              inference_input_type=None,
                              input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                              input_shapes=None,

@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np

from tensorflow.lite.python import convert
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import op_hint
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session

@@ -59,7 +58,7 @@ class ConvertTest(test_util.TensorFlowTestCase):

    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor],
        inference_type=lite_constants.QUANTIZED_UINT8,
        inference_type=dtypes.uint8,
        quantized_input_stats=[(0., 1.)])
    self.assertTrue(tflite_model)

@@ -73,7 +72,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
    tflite_model = convert.toco_convert_graph_def(
        sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
        enable_mlir_converter=False,
        inference_type=lite_constants.FLOAT)
        inference_type=dtypes.float32)
    self.assertTrue(tflite_model)

    # Check values from converted model.

@@ -111,7 +110,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
        input_arrays_map,
        output_arrays,
        enable_mlir_converter=False,
        inference_type=lite_constants.QUANTIZED_UINT8,
        inference_type=dtypes.uint8,
        quantized_input_stats=[(0., 1.), (0., 1.)])
    self.assertTrue(tflite_model)

@@ -158,7 +157,7 @@ class ConvertTest(test_util.TensorFlowTestCase):
          input_arrays_map,
          output_arrays,
          enable_mlir_converter=False,
          inference_type=lite_constants.QUANTIZED_UINT8)
          inference_type=dtypes.uint8)
    self.assertEqual(
        "std_dev and mean must be defined when inference_type or "
        "inference_input_type is QUANTIZED_UINT8 or INT8.",

@@ -163,9 +163,8 @@ class TargetSpec(object):
    supported_ops: Experimental flag, subject to change. Set of OpsSet options
      supported by the device. (default set([OpsSet.TFLITE_BUILTINS]))
    supported_types: List of types for constant values on the target device.
      Supported values are types exported by lite.constants. Frequently, an
      optimization choice is driven by the most compact (i.e. smallest) type in
      this list (default [constants.FLOAT])
      Frequently, an optimization choice is driven by the most compact
      (i.e. smallest) type in this list (default [tf.float32])
  """

  def __init__(self, supported_ops=None, supported_types=None):

@@ -200,7 +199,7 @@ class QuantizationMode(object):
    return (self._any_optimization_enabled() and
            not self._is_int16x8_target_required() and
            self._representative_dataset is not None and
            self._smallest_supported_type() == constants.INT8)
            self._smallest_supported_type() == _dtypes.int8)

  def is_post_training_integer_quantize_8(self):
    """Post training integer 8 quantization."""

@@ -241,12 +240,12 @@ class QuantizationMode(object):
    return (self._any_optimization_enabled() and
            self._representative_dataset is None and
            not self.contains_training_quant_op() and
            self._smallest_supported_type() == constants.INT8)
            self._smallest_supported_type() == _dtypes.int8)

  def post_training_fp16(self):
    """Post training fp16 quantize."""
    return (self._any_optimization_enabled() and
            self._smallest_supported_type() == constants.FLOAT16)
            self._smallest_supported_type() == _dtypes.float16)

  def fp32_execution(self):
    """If none of the above are true."""

@@ -259,36 +258,36 @@ class QuantizationMode(object):
                self.post_training_fp16())

  def activations_type(self):
    return constants.INT16 if self._is_int16x8_target_required() \
      else constants.INT8
    return _dtypes.int16 if self._is_int16x8_target_required() \
      else _dtypes.int8

  def converter_flags(self, inference_ty=None, inference_input_ty=None):
    """Flags to the converter."""
    if self.is_post_training_integer_quantize():
      # The inference_input_type is for the quantizer, then we need to keep the
      # converter inference_input_type to float.
      inference_input_ty = constants.FLOAT
      inference_input_ty = _dtypes.float32

    if self.training_time_int8_allow_float():
      return {
          "inference_type": inference_ty if inference_ty else \
            self.activations_type(),
          "inference_input_type":
              inference_input_ty if inference_input_ty else constants.FLOAT,
              inference_input_ty if inference_input_ty else _dtypes.float32,
          "post_training_quantize": False,  # disable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_dynamic_range_int8():
      return {
          "inference_type": constants.FLOAT,
          "inference_input_type": constants.FLOAT,
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,  # enable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization
      }
    elif self.post_training_fp16():
      return {
          "inference_type": constants.FLOAT,
          "inference_input_type": constants.FLOAT,
          "inference_type": _dtypes.float32,
          "inference_input_type": _dtypes.float32,
          "post_training_quantize": True,
          "quantize_to_float16": True  # enable float16 quantization
      }

@@ -296,7 +295,7 @@ class QuantizationMode(object):
      # Note this might still trigger (uint8) quantization to be compatible with
      # TOCO.
      return {
          "inference_type": inference_ty if inference_ty else constants.FLOAT,
          "inference_type": inference_ty if inference_ty else _dtypes.float32,
          "inference_input_type": inference_input_ty,
          "post_training_quantize": False,  # enable dynamic range quantization
          "quantize_to_float16": False  # disable float16 quantization

@@ -305,8 +304,8 @@ class QuantizationMode(object):
  def quantizer_flags(self, input_ty=None, output_ty=None):
    """Default flags to the TFMOT quantizer."""

    inference_input_type = input_ty if input_ty else constants.FLOAT
    inference_output_type = output_ty if output_ty else constants.FLOAT
    inference_input_type = input_ty if input_ty else _dtypes.float32
    inference_output_type = output_ty if output_ty else _dtypes.float32

    if self.post_training_int8_no_float() \
      or self.post_training_int16x8_no_float():

@@ -327,9 +326,8 @@ class QuantizationMode(object):
    else:
      return False, None

  def flags_modify_model_io_type(self,
                                 input_type=constants.FLOAT,
                                 output_type=constants.FLOAT):
  def flags_modify_model_io_type(
      self, input_type=_dtypes.float32, output_type=_dtypes.float32):
    """Flags for modifying the input and output type of a tflite model."""
    is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0]
    is_training_time_only_quantize = self.training_time_int8_allow_float() and \

@@ -353,7 +351,7 @@ class QuantizationMode(object):
      return

    if self._target_spec.supported_types and (self._smallest_supported_type() !=
                                              constants.INT8):
                                              _dtypes.int8):
      raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported "
                       "type to be INT8.")

@@ -372,7 +370,7 @@ class QuantizationMode(object):
  def _is_int8_target_required(self):
    return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(
        self._target_spec.supported_ops) or
            set(self._target_spec.supported_types) == set([constants.INT8]))
            set(self._target_spec.supported_types) == set([_dtypes.int8]))

  def _is_int16x8_target_required(self):
    return bool(

@@ -397,7 +395,7 @@ class QuantizationMode(object):
      return min(self._target_spec.supported_types, key=lambda x: x.size)
    else:
      # The default smallest supported type is INT8.
      return constants.INT8
      return _dtypes.int8

  def contains_training_quant_op(self):
    """Checks if the graph contains any training-time quantization ops."""

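The `QuantizationMode` flags above are what the public converter attributes eventually resolve to. A hedged end-to-end sketch of the post-training integer path with the new dtype spellings (the model path and input shape are placeholders, not values from this commit):

```python
import tensorflow as tf

def representative_dataset():
  for _ in range(100):
    # Placeholder calibration data; shape assumed purely for illustration.
    yield [tf.random.uniform([1, 224, 224, 3])]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.inference_input_type = tf.int8    # was lite.constants.INT8
converter.inference_output_type = tf.int8   # was lite.constants.INT8
tflite_quant_model = converter.convert()
```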
@@ -556,18 +554,18 @@ class TFLiteConverterBaseV2(TFLiteConverterBase):
  def __init__(self):
    """Constructor for TFLiteConverter."""
    super(TFLiteConverterBaseV2, self).__init__()
    self.inference_input_type = constants.FLOAT
    self.inference_output_type = constants.FLOAT
    self.inference_input_type = _dtypes.float32
    self.inference_output_type = _dtypes.float32

  def _validate_inference_input_output_types(self, quant_mode):
    """Validate inference_input_type and inference_output_type flags."""
    default_types = [constants.FLOAT]
    default_types = [_dtypes.float32]
    # We support integer input/output for integer quantized models only.
    if quant_mode.training_time_int8_allow_float():
      if quant_mode.is_post_training_integer_quantize_16x8():
        all_types = default_types + [constants.INT16]
        all_types = default_types + [_dtypes.int16]
      else:
        all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8]
        all_types = default_types + [_dtypes.int8, _dtypes.uint8]
      if self.inference_input_type not in all_types or \
          self.inference_output_type not in all_types:
        all_types_names = ["tf." + t.name for t in all_types]

@@ -1148,7 +1146,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
        graph debug info for a set of nodes from the `graph_def`.
    """
    super(TFLiteConverterBaseV1, self).__init__()
    self.inference_type = constants.FLOAT
    self.inference_type = _dtypes.float32
    self.inference_input_type = None
    self.inference_output_type = None
    self.output_format = constants.TFLITE

@@ -1195,7 +1193,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
  def _validate_quantized_input_stats(self, converter_kwargs, calibrate):
    """Ensure the `quantized_input_stats` flag is provided if required."""

    quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8})
    quantized_types = frozenset({_dtypes.int8, _dtypes.uint8})

    requires_quantized_input_stats = (
        (converter_kwargs["inference_type"] in quantized_types or

@@ -1645,8 +1643,8 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
    quantized_input_stats: Dict of strings representing input tensor names
      mapped to tuple of floats representing the mean and standard deviation
      of the training data (e.g., {"foo" : (0., 1.)}). Only need if
        `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
        (quantized_input_value - mean_value) / std_dev_value. (default {})
      `inference_input_type` is `QUANTIZED_UINT8`. real_input_value =
      (quantized_input_value - mean_value) / std_dev_value. (default {})
    default_ranges_stats: Tuple of integers representing (min, max) range values
      for all arrays without a specified range. Intended for experimenting with
      quantization via "dummy quantization". (default None)

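The docstring above covers both `quantized_input_stats` and the "dummy quantization" fallback. A minimal sketch of that TF1-style path using the new dtype spellings, mirroring the tests below (the frozen-graph path and tensor names are hypothetical):

```python
import tensorflow as tf

# Hypothetical frozen graph and tensor names, patterned after the tests below.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="/tmp/model.pb",
    input_arrays=["Placeholder"],
    output_arrays=["output"])
converter.inference_type = tf.uint8                          # was QUANTIZED_UINT8
converter.quantized_input_stats = {"Placeholder": (0., 1.)}  # (mean, std_dev)
converter.default_ranges_stats = (0, 6)  # (min, max) for arrays without ranges
tflite_model = converter.convert()
```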
@@ -155,8 +155,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = lite_constants.QUANTIZED_UINT8
    converter.inference_type = lite_constants.FLOAT
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

@@ -788,8 +788,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.inference_input_type = lite_constants.INT8
    quantized_converter.inference_output_type = lite_constants.INT8
    quantized_converter.inference_input_type = dtypes.int8
    quantized_converter.inference_output_type = dtypes.int8
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.representative_dataset = calibration_gen
    quantized_tflite_model = quantized_converter.convert()

@@ -832,7 +832,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    # Restricting to int8 type only
    quantized_converter.target_spec.supported_types = [lite.constants.INT8]
    quantized_converter.target_spec.supported_types = [dtypes.int8]
    # A representative dataset is required for full fixed point quantization.
    with self.assertRaises(ValueError) as error:
      quantized_converter.convert()

@@ -857,7 +857,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {
        'inputA': (0., 1.),
        'inputB': (0., 1.)

@@ -898,7 +898,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    converter.default_ranges_stats = (0, 6)  # min, max
    tflite_model = converter.convert()

@@ -954,16 +954,15 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    interpreter.allocate_tensors()
    self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name)
    self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
                     lite.constants.FLOAT)
                     dtypes.float32)
    # Convert model to quantized version
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    if include_int8:
      quantized_converter.target_spec.supported_types.append(
          lite.constants.INT8)
      quantized_converter.target_spec.supported_types.append(dtypes.int8)
    if use_rep_data:
      quantized_converter.representative_dataset = calibration_gen

@@ -984,11 +983,11 @@ class FromSessionTest(TestModels, parameterized.TestCase):
      if is_float16_quantized:
        # Verify that bias constant is float16 type.
        self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
                         lite.constants.FLOAT16)
                         dtypes.float16)
      elif is_post_training_quantized:
        # Verify that bias constants is int32 type.
        self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'],
                         lite.constants.INT32)
                         dtypes.int32)
      else:
        raise ValueError('Invalid test options.')

@@ -1005,7 +1004,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
        sess, [inp], [output])
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16]
    quantized_converter.target_spec.supported_types = [dtypes.float16]
    # Specify only int8 builtin ops
    quantized_converter.target_spec.supported_ops = [
        lite.OpsSet.TFLITE_BUILTINS_INT8

@@ -1017,8 +1016,8 @@ class FromSessionTest(TestModels, parameterized.TestCase):
        str(error.exception))

  @parameterized.named_parameters(
      ('InferenceType_INT8', lite_constants.INT8),
      ('InferenceType_UINT8', lite_constants.QUANTIZED_UINT8))
      ('InferenceType_INT8', dtypes.int8),
      ('InferenceType_UINT8', dtypes.uint8))
  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(

@@ -1039,7 +1038,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
        'flag is set to tf.uint8 or tf.int8.', str(error.exception))

    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_type = lite_constants.FLOAT
      quantized_converter.inference_type = dtypes.float32
      quantized_converter.inference_input_type = quantized_type
      quantized_converter.convert()
    self.assertEqual(

@@ -1070,7 +1069,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
    with self.assertRaises(ValueError) as error:
      converter.convert()

@@ -1091,9 +1090,9 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])

    # extra flags to trigger training time quantization conversion
    converter.inference_type = lite_constants.INT8
    converter.inference_input_type = lite_constants.FLOAT
    converter.inference_output_type = lite_constants.FLOAT
    converter.inference_type = dtypes.int8
    converter.inference_input_type = dtypes.float32
    converter.inference_output_type = dtypes.float32
    input_arrays = converter.get_input_arrays()
    converter.quantized_input_stats = {
        input_arrays[0]: (0., 1.)

@@ -1255,7 +1254,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.inference_type = dtypes.uint8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

@@ -2334,7 +2333,7 @@ class DefaultConverterAttrsTest(LiteTest):
    self.assertEqual(converter.output_format, lite_constants.TFLITE)

    # Assert the default inference type is float.
    self.assertEqual(converter.inference_type, lite_constants.FLOAT)
    self.assertEqual(converter.inference_type, dtypes.float32)

    # Assert the default inference type overrides are None.
    self.assertIsNone(converter.inference_input_type)

@@ -32,6 +32,7 @@ from tensorflow.lite.python import lite_v2_test_util
from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import recurrent

@@ -74,9 +75,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    self.assertEqual(expected_value.numpy(), actual_value)

  @parameterized.named_parameters(
      ('_INT8InputOutput', lite.constants.INT8),
      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
      ('_INT16InputOutput', lite.constants.INT16))
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8),
      ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidFloat(self, inference_input_output_type):
    root = self._getSimpleVariableModel()

@@ -194,9 +195,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('_INT8InputOutput', lite.constants.INT8),
      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8),
      ('_INT16InputOutput', lite.constants.INT16))
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8),
      ('_INT16InputOutput', dtypes.int16))
  @test_util.run_v2_only
  def testInvalidPostTrainingDynamicRangeQuantization(
      self, inference_input_output_type):

@@ -219,18 +220,18 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
        'must be tf.float32.', str(error.exception))

  @parameterized.named_parameters(
      ('_Default', False, False, lite.constants.FLOAT),
      ('_INT8InputOutput', False, False, lite.constants.INT8),
      ('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8),
      ('_INT16Quantize', False, True, lite.constants.FLOAT),
      ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16),
      ('_IntOnly', True, False, lite.constants.FLOAT),
      ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8),
      ('_Default', False, False, dtypes.float32),
      ('_INT8InputOutput', False, False, dtypes.int8),
      ('_UINT8InputOutput', False, False, dtypes.uint8),
      ('_INT16Quantize', False, True, dtypes.float32),
      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
      ('_IntOnly', True, False, dtypes.float32),
      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
      ('_IntOnly_UINT8InputOutput', True, False,
       lite.constants.QUANTIZED_UINT8),
      ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT),
       dtypes.uint8),
      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True,
       lite.constants.INT16))
       dtypes.int16))
  def testIntegerQuantization(self, is_int_only, is_int16_quantize,
                              inference_input_output_type):
    func, calibration_gen = self._getIntegerQuantizeModel()

@@ -281,7 +282,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    self.assertLess(len(quantized_tflite_model), len(tflite_model))

  @parameterized.named_parameters(
      ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8))
      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
  def testInvalidIntegerQuantization(self, is_int16_quantize,
                                     inference_input_output_type):
    func, calibration_gen = self._getIntegerQuantizeModel()

@@ -297,8 +298,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
          lite.OpsSet.TFLITE_BUILTINS
      ]
    with self.assertRaises(ValueError) as error:
      quantized_converter.inference_input_type = lite.constants.INT8
      quantized_converter.inference_output_type = lite.constants.INT8
      quantized_converter.inference_input_type = dtypes.int8
      quantized_converter.inference_output_type = dtypes.int8
      quantized_converter.convert()
    self.assertEqual(
        "The inference_input_type and inference_output_type "

@@ -377,9 +378,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))

  @parameterized.named_parameters(
      ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT),
      ('_INT8InputOutput', lite.constants.INT8),
      ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
      ('_DefaultFLOAT32InputOutput', dtypes.float32),
      ('_INT8InputOutput', dtypes.int8),
      ('_UINT8InputOutput', dtypes.uint8))
  @test_util.run_v2_only
  def testTrainingTimeQuantization(self, inference_input_output_type):
    model = self._getTrainingTimeQuantizedModel()

@@ -50,6 +50,7 @@ py_library(
    visibility = ["//visibility:public"],
    deps = [
        ":_pywrap_tensorflow_lite_calibration_wrapper",  # buildcleaner: keep
        "//tensorflow/python:dtypes",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],

@@ -67,8 +68,8 @@ py_test(
    tags = ["no_oss"],
    deps = [
        ":calibrator",
        "//tensorflow/lite/python:lite_constants",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//third_party/py/numpy",

@@ -18,8 +18,9 @@ from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.lite.python import lite_constants

# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite

@@ -60,7 +61,7 @@ class Calibrator(object):
                             input_type,
                             output_type,
                             allow_float,
                             activations_type=lite_constants.INT8,
                             activations_type=dtypes.int8,
                             resize_input=True):
    """Calibrates the model with specified generator and then quantizes it.

@@ -23,8 +23,8 @@ from absl.testing import parameterized
import numpy as np
from six.moves import range

from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test

@@ -34,9 +34,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8', constants.INT8),
      ('UseActivationTypeInt8', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16', constants.INT16))
      ('UseActivationTypeInt16', dtypes.int16))
  def test_calibration_with_quantization(self, activations_type):
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')

@@ -49,16 +49,17 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

    quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                       constants.FLOAT,
                                                       constants.FLOAT, False,
                                                       dtypes.float32,
                                                       dtypes.float32,
                                                       False,
                                                       activations_type)
    self.assertIsNotNone(quantized_model)

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8', constants.INT8),
      ('UseActivationTypeInt8', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16', constants.INT16))
      ('UseActivationTypeInt16', dtypes.int16))
  def test_calibration_with_quantization_allow_float(self, activations_type):
    model_path = resource_loader.get_path_to_datafile(
        'test_data/mobilenet_like_model.bin')

@@ -71,8 +72,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

    quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                       constants.FLOAT,
                                                       constants.FLOAT, True,
                                                       dtypes.float32,
                                                       dtypes.float32,
                                                       True,
                                                       activations_type)
    self.assertIsNotNone(quantized_model)

@@ -88,7 +90,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

    quantized_model = quantizer.calibrate_and_quantize_single(
        input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd')
        input_gen, dtypes.float32, dtypes.float32, True, 'conv2d_8/BiasAdd')
    self.assertIsNotNone(quantized_model)

  def test_calibration_with_string_input(self):

@@ -103,14 +105,14 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.array(u'Test' + str(i))]

    quantized_model = quantizer.calibrate_and_quantize_single(
        input_gen, constants.FLOAT, constants.FLOAT, True, 'Identity')
        input_gen, dtypes.float32, dtypes.float32, True, 'Identity')
    self.assertIsNotNone(quantized_model)

  @parameterized.named_parameters(
      # Activation type Int8
      ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8),
      ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8),
      # Activation type Int16
      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16))
      ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16))
  def test_calibration_with_quantization_multiple_inputs(
      self, activations_type):
    # Load multi add model from test data.

@@ -126,8 +128,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]

    quantized_model = quantizer.calibrate_and_quantize(input_gen,
                                                       constants.FLOAT,
                                                       constants.FLOAT, False,
                                                       dtypes.float32,
                                                       dtypes.float32,
                                                       False,
                                                       activations_type)
    self.assertIsNotNone(quantized_model)

@@ -148,8 +151,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield i

    with self.assertRaises(RuntimeError):
      quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
                                       constants.FLOAT, False)
      quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32,
                                       dtypes.float32, False)

  def test_invalid_shape_calibrator_gen(self):
    model_path = resource_loader.get_path_to_datafile(

@@ -163,8 +166,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)]

    with self.assertRaisesRegex(ValueError, 'Size mismatch'):
      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
                                       constants.FLOAT, False, constants.INT8,
      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
                                       dtypes.float32, False, dtypes.int8,
                                       False)

  def test_invalid_type_calibrator_gen(self):

@@ -179,8 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)]

    with self.assertRaises(ValueError):
      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
                                       constants.FLOAT, False, constants.INT8)
      quantizer.calibrate_and_quantize(input_gen, dtypes.float32,
                                       dtypes.float32, False, dtypes.int8)

  def test_calibration(self):
    model_path = resource_loader.get_path_to_datafile(

@@ -28,12 +28,12 @@ import six
from six.moves import zip

from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app

@@ -63,13 +63,14 @@ def _parse_inference_type(value, flag):
    ValueError: Unsupported value.
  """
  if value == "FLOAT":
    return lite_constants.FLOAT
  if value == "QUANTIZED_UINT8":
    return lite_constants.QUANTIZED_UINT8
    return dtypes.float32
  if value == "INT8":
    return lite_constants.INT8
  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
                   "QUANTIZED_UINT8 are supported.".format(flag))
    return dtypes.int8
  if value == "UINT8" or value == "QUANTIZED_UINT8":
    return dtypes.uint8
  raise ValueError(
      "Unsupported value for `{}` flag. Expected FLOAT, INT8 or UINT8, instead "
      "got {}.".format(flag, value))


def _get_tflite_converter(flags):

@@ -151,10 +152,10 @@ def _convert_tf1_model(flags):

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    else:
    if converter.inference_type == dtypes.float32:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):

@@ -193,13 +194,13 @@ def _convert_tf1_model(flags):

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
    if converter.inference_type != dtypes.float32:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
      converter.inference_type = lite_constants.FLOAT
            "FLOAT. Overriding inference_type to FLOAT.")
      converter.inference_type = dtypes.float32

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [lite.constants.FLOAT16]
    converter.target_spec.supported_types = [dtypes.float16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

@@ -358,14 +359,15 @@ def _get_tf1_flags(parser):
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
      help="Target data type of real-number arrays in the output file.")
      default="FLOAT",
      help=("Target data type of real-number arrays in the output file. "
            "Must be either FLOAT, INT8 or UINT8."))
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization."))
            "different type for input arrays in the case of quantization. "
            "Must be either FLOAT, INT8 or UINT8."))

  # Input and output arrays flags.
  parser.add_argument(

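With the rewritten `_parse_inference_type` above, both the new `UINT8` spelling and the legacy `QUANTIZED_UINT8` spelling resolve to the same dtype. A standalone mirror of that logic, included here only to make the accepted values explicit (not an import of the private helper):

```python
from tensorflow.python.framework import dtypes

def parse_inference_type(value, flag):
  # Mirrors the rewritten _parse_inference_type hunk above.
  if value == "FLOAT":
    return dtypes.float32
  if value == "INT8":
    return dtypes.int8
  if value in ("UINT8", "QUANTIZED_UINT8"):
    return dtypes.uint8
  raise ValueError(
      "Unsupported value for `{}` flag. Expected FLOAT, INT8 or UINT8, instead "
      "got {}.".format(flag, value))

assert parse_inference_type("QUANTIZED_UINT8", "inference_type") == dtypes.uint8
```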
@@ -168,6 +168,37 @@ class TfLiteConvertV1Test(TestModels):
    self._run(flags_str, should_succeed=True)
    os.remove(graph_def_file)

  def testQATFrozenGraphDefUInt8(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Define converter flags
    flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
                 '--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2}'.format(
                     graph_def_file, 'inputA,inputB', 'output'))

    # Set inference_type UINT8 and (default) inference_input_type UINT8
    flags_str_1 = flags_str + ' --inference_type=UINT8'
    self._run(flags_str_1, should_succeed=True)

    # Set inference_type UINT8 and inference_input_type FLOAT
    flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT'
    self._run(flags_str_2, should_succeed=True)

    os.remove(graph_def_file)

  def testSavedModel(self):
    saved_model_dir = self._getFilepath('model')
    with ops.Graph().as_default():

@@ -24,7 +24,6 @@ import numpy as np
from six.moves import range
import tensorflow as tf

from tensorflow.lite.python import lite_constants
from tensorflow.lite.python import util
from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.client import session

@@ -292,9 +291,9 @@ def _test_param_modify_integer_model_io_type():
      # "DuringTraining": False,
  }
  map_types = {
      "": lite_constants.FLOAT,
      "INT8": lite_constants.INT8,
      "UINT8": lite_constants.QUANTIZED_UINT8
      "": dtypes.float32,
      "INT8": dtypes.int8,
      "UINT8": dtypes.uint8,
  }
  for k1, v1 in map_model_type.items():
    for k2, v2 in map_types.items():

@@ -28,11 +28,11 @@ from google.protobuf.message import DecodeError
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.lite.python import convert_saved_model as _convert_saved_model
from tensorflow.lite.python import lite as _lite
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python import util as _util
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.keras.preprocessing import image

@@ -97,7 +97,7 @@ def _convert(converter, **kwargs):
  if "post_training_quantize" in kwargs:
    converter.optimizations = [_lite.Optimize.DEFAULT]
  if kwargs.get("quantize_to_float16", False):
    converter.target_spec.supported_types = [constants.FLOAT16]
    converter.target_spec.supported_types = [dtypes.float16]
  if kwargs.get("post_training_quantize_16x8", False):
    input_size = kwargs.get("model_input_size")

@@ -52,7 +52,7 @@ py_library(
    name = "modify_model_interface_constants",
    srcs = ["modify_model_interface_constants.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/lite/python:lite_constants"],
    deps = ["//tensorflow/python:dtypes"],
)

pybind_extension(

@@ -19,12 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.lite.python import lite_constants
from tensorflow.python.framework import dtypes

STR_TO_TFLITE_TYPES = {
    'INT8': lite_constants.INT8,
    'INT16': lite_constants.INT16,
    'UINT8': lite_constants.QUANTIZED_UINT8
    'INT8': dtypes.int8,
    'UINT8': dtypes.uint8,
    'INT16': dtypes.int16,
}
TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()}