Relax quantization checks for the SELECT TF OPS set
Also adds test cases to increase test coverage for Flex-op-enabled models. PiperOrigin-RevId: 338380353 Change-Id: I129d1931c9e2ec1f7faa1115bc4dfa1d55c37627
This commit is contained in:
parent
d579f1a7de
commit
eccaaff6d8
tensorflow/lite/python
@ -192,6 +192,7 @@ class QuantizationMode(object):
|
||||
"""Post training int8 quantize, disallow float fallback."""
|
||||
return (self._is_int8_target_required() and
|
||||
not self._is_int16x8_target_required() and
|
||||
not self._is_allow_float() and
|
||||
self._representative_dataset is not None)
|
||||
|
||||
def post_training_int8_allow_float(self):
|
||||
@ -351,20 +352,18 @@ class QuantizationMode(object):
|
||||
"TFLITE_BUILTINS_INT8 or INT8 supported types.")
|
||||
|
||||
def _is_int8_target_required(self):
  """Returns True if the target spec asks for int8 quantization.

  The check is satisfied when either:
    * `OpsSet.TFLITE_BUILTINS_INT8` appears among the supported ops (other
      op sets may be listed next to it), or
    * the supported types are exactly `{int8}`.
  """
  # Membership, not set equality: listing another op set alongside
  # TFLITE_BUILTINS_INT8 must not disable the int8 target.
  if OpsSet.TFLITE_BUILTINS_INT8 in set(self._target_spec.supported_ops):
    return True
  return set(self._target_spec.supported_types) == set([_dtypes.int8])
|
||||
|
||||
def _is_int16x8_target_required(self):
  """Returns True if the 16x8 (int16 activations, int8 weights) set is listed.

  Only the presence of
  `OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8` in the
  supported ops matters; any other op sets listed with it are ignored here.
  """
  supported_ops = set(self._target_spec.supported_ops)
  return (OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
          in supported_ops)
|
||||
|
||||
def _is_allow_float(self):
  """Returns True if float fallback is permitted by the supported ops.

  Float ops are allowed when either `OpsSet.TFLITE_BUILTINS` or
  `OpsSet.SELECT_TF_OPS` is among the supported ops.
  """
  supported_ops = set(self._target_spec.supported_ops)
  return (OpsSet.TFLITE_BUILTINS in supported_ops or
          OpsSet.SELECT_TF_OPS in supported_ops)
|
||||
|
||||
def _any_optimization_enabled(self):
|
||||
return bool(
|
||||
|
@ -488,35 +488,92 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
converter.convert()
|
||||
self._assertValidDebugInfo(converter._debug_info)
|
||||
|
||||
def _getIntegerQuantizationModelWithFlexOp(self):
  """Builds the model and calibration data used by the Flex-op tests.

  The model applies `tanh`, then `tf.nn.conv3d` with an all-ones filter,
  then `tanh` again on a float32 input of shape [3, 3, 3, 3, 3].
  NOTE(review): presumably `conv3d` is the op that has to be handled through
  SELECT_TF_OPS (Flex) -- confirm against the converter's builtin op list.

  Returns:
    A `(concrete_function, calibration_gen)` tuple, where
    `calibration_gen` is a generator function yielding five single-element
    lists of random float32 arrays drawn uniformly from [-1, 1).
  """
  # Fixed seed so the calibration samples are reproducible across runs.
  np.random.seed(0)

  root = tracking.AutoTrackable()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
  ])
  def func(inp):
    tanh = tf.math.tanh(inp)
    conv3d = tf.nn.conv3d(
        tanh,
        tf.ones([3, 3, 3, 3, 3]),
        strides=[1, 1, 1, 1, 1],
        padding='SAME')
    output = tf.math.tanh(conv3d)
    return output

  def calibration_gen():
    for _ in range(5):
      yield [
          np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
      ]

  root.f = func
  return (root.f.get_concrete_function(), calibration_gen)
|
||||
|
||||
@parameterized.named_parameters(
    ('_Default', False, False, dtypes.float32),
    ('_INT8InputOutput', False, False, dtypes.int8),
    ('_UINT8InputOutput', False, False, dtypes.uint8),
    ('_INT16Quantize', False, True, dtypes.float32),
    ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
    ('_IntOnly', True, False, dtypes.float32),
    ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
    ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
    ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
    ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
@test_util.run_v2_only
def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
                                      inference_input_output_type):
  """Converts a model with SELECT_TF_OPS enabled under integer quantization.

  Args:
    is_int_only: If True, `TFLITE_BUILTINS` is left out of the supported ops
      so only the quantized builtin set plus `SELECT_TF_OPS` is requested.
    is_int16_quantize: If True, the 16x8 experimental op set is requested
      instead of `TFLITE_BUILTINS_INT8`.
    inference_input_output_type: dtype set as both the inference input and
      output type, and expected on the converted model's single input and
      single output tensor.
  """
  func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()

  quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [func])
  quantized_converter.optimizations = [lite.Optimize.DEFAULT]
  quantized_converter.representative_dataset = calibration_gen

  # Build the op-set list in a fixed order:
  #   [quantized builtin set?, TFLITE_BUILTINS?, SELECT_TF_OPS]
  # The plain int8 set is only requested explicitly in int-only mode.
  supported_ops = []
  if is_int16_quantize:
    supported_ops.append(
        lite.OpsSet
        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)
  elif is_int_only:
    supported_ops.append(lite.OpsSet.TFLITE_BUILTINS_INT8)
  if not is_int_only:
    supported_ops.append(lite.OpsSet.TFLITE_BUILTINS)
  supported_ops.append(lite.OpsSet.SELECT_TF_OPS)
  quantized_converter.target_spec.supported_ops = supported_ops

  quantized_converter.inference_input_type = inference_input_output_type
  quantized_converter.inference_output_type = inference_input_output_type
  quantized_tflite_model = quantized_converter.convert()
  self.assertIsNotNone(quantized_tflite_model)

  # Load the converted model and check the I/O tensor dtypes it exposes.
  interpreter = Interpreter(model_content=quantized_tflite_model)
  interpreter.allocate_tensors()
  expected_dtype = inference_input_output_type.as_numpy_dtype

  input_details = interpreter.get_input_details()
  self.assertLen(input_details, 1)
  self.assertEqual(expected_dtype, input_details[0]['dtype'])

  output_details = interpreter.get_output_details()
  self.assertLen(output_details, 1)
  self.assertEqual(expected_dtype, output_details[0]['dtype'])
|
||||
|
||||
|
||||
class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
|
Loading…
Reference in New Issue
Block a user