From eccaaff6d81ecb2b663f9b00fbf57a4c7dccd6b4 Mon Sep 17 00:00:00 2001
From: Jaesung Chung <jaesung@google.com>
Date: Wed, 21 Oct 2020 18:21:38 -0700
Subject: [PATCH] Relax quantization checks for the SELECT TF OPS set

Also added the test cases to increase test coverage on flex op enabled
models.

PiperOrigin-RevId: 338380353
Change-Id: I129d1931c9e2ec1f7faa1115bc4dfa1d55c37627
---
 tensorflow/lite/python/lite.py         |  19 +++--
 tensorflow/lite/python/lite_v2_test.py | 107 +++++++++++++++++++------
 2 files changed, 91 insertions(+), 35 deletions(-)

diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 77bfff3199c..362145435a9 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -192,6 +192,7 @@ class QuantizationMode(object):
     """Post training int8 quantize, disallow float fallback."""
     return (self._is_int8_target_required() and
             not self._is_int16x8_target_required() and
+            not self._is_allow_float() and
             self._representative_dataset is not None)
 
   def post_training_int8_allow_float(self):
@@ -351,20 +352,18 @@ class QuantizationMode(object):
                        "TFLITE_BUILTINS_INT8 or INT8 supported types.")
 
   def _is_int8_target_required(self):
-    return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set(
-        self._target_spec.supported_ops) or
-            set(self._target_spec.supported_types) == set([_dtypes.int8]))
+    return (OpsSet.TFLITE_BUILTINS_INT8 in set(
+        self._target_spec.supported_ops)) or (set(
+            self._target_spec.supported_types) == set([_dtypes.int8]))
 
   def _is_int16x8_target_required(self):
-    return bool(
-        set(self._target_spec.supported_ops).intersection([
-            OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
-        ]))
+    return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+            in set(self._target_spec.supported_ops))
 
   def _is_allow_float(self):
-    return bool(
-        set(self._target_spec.supported_ops).intersection(
-            [OpsSet.TFLITE_BUILTINS]))
+    return (OpsSet.TFLITE_BUILTINS in set(
+        self._target_spec.supported_ops)) or (OpsSet.SELECT_TF_OPS in set(
+            self._target_spec.supported_ops))
 
   def _any_optimization_enabled(self):
     return bool(
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 4851912226a..db38287d9c2 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -488,35 +488,92 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     converter.convert()
     self._assertValidDebugInfo(converter._debug_info)
 
+  def _getIntegerQuantizationModelWithFlexOp(self):
+    np.random.seed(0)
+
+    root = tracking.AutoTrackable()
+
+    @tf.function(input_signature=[
+        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
+    ])
+    def func(inp):
+      tanh = tf.math.tanh(inp)
+      conv3d = tf.nn.conv3d(
+          tanh,
+          tf.ones([3, 3, 3, 3, 3]),
+          strides=[1, 1, 1, 1, 1],
+          padding='SAME')
+      output = tf.math.tanh(conv3d)
+      return output
+
+    def calibration_gen():
+      for _ in range(5):
+        yield [
+            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
+        ]
+
+    root.f = func
+    return (root.f.get_concrete_function(), calibration_gen)
+
+  @parameterized.named_parameters(
+      ('_Default', False, False, dtypes.float32),
+      ('_INT8InputOutput', False, False, dtypes.int8),
+      ('_UINT8InputOutput', False, False, dtypes.uint8),
+      ('_INT16Quantize', False, True, dtypes.float32),
+      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
+      ('_IntOnly', True, False, dtypes.float32),
+      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
+      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
+      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
+      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
   @test_util.run_v2_only
-  def testFlexOpWithInt8OpSet(self):
-    model = tf.keras.Sequential()
-    input_shape = (1, 4, 4, 4, 1)
-    model.add(
-        tf.keras.layers.Conv3D(
-            4,
-            kernel_size=(1, 1, 1),
-            activation='relu',
-            input_shape=input_shape[1:]))
-    model.add(tf.keras.layers.Flatten())
-    model.add(tf.keras.layers.Dense(2, activation='relu'))
+  def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
+                                        inference_input_output_type):
+    func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()
 
-    @tf.function(
-        input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)])
-    def _call_fn(inputs):
-      return model(inputs, training=False)
+    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [func])
+    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
+    quantized_converter.representative_dataset = calibration_gen
+    if is_int_only:
+      if is_int16_quantize:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.\
+            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+            lite.OpsSet.SELECT_TF_OPS
+        ]
+      else:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
+        ]
+    else:
+      if is_int16_quantize:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.\
+            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
+            lite.OpsSet.TFLITE_BUILTINS,
+            lite.OpsSet.SELECT_TF_OPS
+        ]
+      else:
+        quantized_converter.target_spec.supported_ops = [
+            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
+        ]
 
-    concrete_func = _call_fn.get_concrete_function(
-        tf.TensorSpec(input_shape, dtype=tf.float32))
+    quantized_converter.inference_input_type = inference_input_output_type
+    quantized_converter.inference_output_type = inference_input_output_type
+    quantized_tflite_model = quantized_converter.convert()
+    self.assertIsNotNone(quantized_tflite_model)
 
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-    converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.target_spec.supported_ops = [
-        tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
-        tf.lite.OpsSet.SELECT_TF_OPS,
-    ]
-    tflite_model = converter.convert()
-    self.assertTrue(tflite_model)
+    interpreter = Interpreter(model_content=quantized_tflite_model)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual(inference_input_output_type.as_numpy_dtype,
+                     output_details[0]['dtype'])
 
 
 class FromSavedModelTest(lite_v2_test_util.ModelTest):
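
Reviewer note (not part of the patch): taken together, the three changes to
QuantizationMode mean that OpsSet.SELECT_TF_OPS now counts as "allow float"
(_is_allow_float), that TFLITE_BUILTINS_INT8 only needs to be present in
supported_ops rather than be the sole entry (_is_int8_target_required), and
that the strict post_training_int8_no_float path is skipped whenever float
fallback is allowed. A minimal sketch of the user-facing configuration this
relaxes is below; the toy tanh model is an assumption for illustration only
(the new test above uses tf.nn.conv3d to actually force a flex op), and the
exact mode selected should be confirmed against QuantizationMode above.

  # Illustrative sketch, not part of this patch: post-training int8
  # quantization with SELECT_TF_OPS enabled for flex fallback.
  import numpy as np
  import tensorflow as tf

  @tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
  def toy(x):
    return tf.math.tanh(x)

  def representative_dataset():
    for _ in range(5):
      yield [np.random.uniform(-1, 1, size=(1, 4)).astype(np.float32)]

  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [toy.get_concrete_function()])
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.representative_dataset = representative_dataset
  # With SELECT_TF_OPS listed, _is_allow_float() is now True, so the converter
  # takes the int8-with-float-fallback path instead of the no-float path.
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  tflite_model = converter.convert()

The parameterized test added in lite_v2_test.py exercises the same
supported_ops combinations against a model that genuinely needs a flex op,
including the int-only, int16x8, and non-default inference input/output type
variants.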