From eccaaff6d81ecb2b663f9b00fbf57a4c7dccd6b4 Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Wed, 21 Oct 2020 18:21:38 -0700
Subject: [PATCH] Relax quantization checks for the SELECT TF OPS set

Also added the test cases to increase test coverage on flex op enabled models.

PiperOrigin-RevId: 338380353
Change-Id: I129d1931c9e2ec1f7faa1115bc4dfa1d55c37627
---
 tensorflow/lite/python/lite.py         |  19 +++--
 tensorflow/lite/python/lite_v2_test.py | 107 +++++++++++++++++++------
 2 files changed, 91 insertions(+), 35 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 77bfff3199c..362145435a9 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -192,6 +192,7 @@ class QuantizationMode(object):
     """Post training int8 quantize, disallow float fallback."""
     return (self._is_int8_target_required() and
             not self._is_int16x8_target_required() and
+            not self._is_allow_float() and
             self._representative_dataset is not None)
 
   def post_training_int8_allow_float(self):
@@ -351,20 +352,18 @@ class QuantizationMode(object):
                        "TFLITE_BUILTINS_INT8 or INT8 supported types.")
 
   def _is_int8_target_required(self):
-    return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(
-        self._target_spec.supported_ops) or
-            set(self._target_spec.supported_types) == set([_dtypes.int8]))
+    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
+        self._target_spec.supported_ops)) or (set(
+            self._target_spec.supported_types) == set([_dtypes.int8]))
 
   def _is_int16x8_target_required(self):
-    return bool(
-        set(self._target_spec.supported_ops).intersection([
-            OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
-        ]))
+    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+            in set(self._target_spec.supported_ops))
 
   def _is_allow_float(self):
-    return bool(
-        set(self._target_spec.supported_ops).intersection(
-            [OpsSet.TFLITE_BUILTINS]))
+    return (OpsSet.TFLITE_BUILTINS in set(
+        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
+            self._target_spec.supported_ops))
 
   def _any_optimization_enabled(self):
     return bool(
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 4851912226a..db38287d9c2 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -488,35 +488,92 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     converter.convert()
     self._assertValidDebugInfo(converter._debug_info)
 
+  def _getIntegerQuantizationModelWithFlexOp(self):
+    np.random.seed(0)
+
+    root = tracking.AutoTrackable()
+
+    @tf.function(input_signature=[
+        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
+    ])
+    def func(inp):
+      tanh = tf.math.tanh(inp)
+      conv3d = tf.nn.conv3d(
+          tanh,
+          tf.ones([3, 3, 3, 3, 3]),
+          strides=[1, 1, 1, 1, 1],
+          padding='SAME')
+      output = tf.math.tanh(conv3d)
+      return output
+
+    def calibration_gen():
+      for _ in range(5):
+        yield [
+            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
+        ]
+
+    root.f = func
+    return (root.f.get_concrete_function(), calibration_gen)
+
+  @parameterized.named_parameters(
+      ('_Default', False, False, dtypes.float32),
+      ('_INT8InputOutput', False, False, dtypes.int8),
+      ('_UINT8InputOutput', False, False, dtypes.uint8),
+      ('_INT16Quantize', False, True, dtypes.float32),
+      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
+      ('_IntOnly', True, False, dtypes.float32),
+      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
+      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
+      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
+      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
   @test_util.run_v2_only
-  def testFlexOpWithInt8OpSet(self):
-    model = tf.keras.Sequential()
-    input_shape = (1, 4, 4, 4, 1)
-    model.add(
-        tf.keras.layers.Conv3D(
-            4,
-            kernel_size=(1, 1, 1),
-            activation='relu',
-            input_shape=input_shape[1:]))
-    model.add(tf.keras.layers.Flatten())
-    model.add(tf.keras.layers.Dense(2, activation='relu'))
+  def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
+                                        inference_input_output_type):
+    func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()
 
-    @tf.function(
-        input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
-    def _call_fn(inputs):
-      return model(inputs, training=False)
+    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [func])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.representative_dataset = calibration_gen
+    if is_int_only:
+      if is_int16_quantize:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.\
+            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+            lite.OpsSet.SELECT_TF_OPS
+        ]
+      else:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
+        ]
+    else:
+      if is_int16_quantize:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.\
+            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+            lite.OpsSet.TFLITE_BUILTINS,
+            lite.OpsSet.SELECT_TF_OPS
+        ]
+      else:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
+        ]
 
-    concrete_func = _call_fn.get_concrete_function(
-        tf.TensorSpec(input_shape, dtype=tf.float32))
+    quantized_converter.inference_input_type = inference_input_output_type
+    quantized_converter.inference_output_type = inference_input_output_type
+    quantized_tflite_model = quantized_converter.convert()
+    self.assertIsNotNone(quantized_tflite_model)
 
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-    converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.target_spec.supported_ops = [
-        tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
-        tf.lite.OpsSet.SELECT_TF_OPS,
-    ]
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
+    interpreter = Interpreter(model_content=quantized_tflite_model)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     output_details[0]['dtype'])
 
 
 class FromSavedModelTest(lite_v2_test_util.ModelTest):