diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py
index 04878cb2dd8..d06a5a662fa 100644
--- a/tensorflow/lite/python/convert.py
+++ b/tensorflow/lite/python/convert.py
@@ -74,6 +74,11 @@ class OpsSet(enum.Enum):
   # WARNING: Experimental interface, subject to change.
   SELECT_TF_OPS = "SELECT_TF_OPS"
 
+  # Convert model using only TensorFlow Lite quantized int8 operations.
+  # Specifying this will throw an error for operations that do not yet have
+  # quantized implementations.
+  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
+
   def __str__(self):
     return self.value
 
diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py
index 80efe361cb4..829169839cd 100644
--- a/tensorflow/lite/python/lite.py
+++ b/tensorflow/lite/python/lite.py
@@ -80,28 +80,24 @@ class Optimize(enum.Enum):
   # Converter will do its best to improve size and latency based on the
   # information provided.
   # Enhanced optimizations can be gained by providing a representative_dataset.
-  # Currently this is recommended, and is equivalent to the modes below.
+  # This is recommended, and is currently equivalent to the modes below.
+  # Currently, weights will be quantized and, if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   DEFAULT = "DEFAULT"
 
   # Optimize for size.
   #
   # Optimizations that reduce the size of the model.
   # The model size will be reduced.
-  # Current behavior:
-  # - If RepresentativeDataset is not provided, weights will be quantized and
-  #   activations will remain float.
-  # - If RepresentativeDataset is provided, weights and activations will be
-  #   quantized.
+  # Currently, weights will be quantized and, if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
 
   # Optimize for latency.
   #
   # Optimizations that reduce the latency of the model.
-  # Current behavior:
-  # - If RepresentativeDataset is not provided, weights will be quantized and
-  #   activations will remain float.
-  # - If RepresentativeDataset is provided, weights and activations will be
-  #   quantized.
+  # Currently, weights will be quantized and, if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
 
   def __str__(self):
@@ -154,9 +150,10 @@ class TFLiteConverterBase(object):
   def __init__(self):
     self.representative_dataset = None
     self.optimizations = []
+    self._target_ops = set([OpsSet.TFLITE_BUILTINS])
 
-  def _grappler_config(self, target_ops):
-    is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == target_ops
+  def _grappler_config(self):
+    is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(self._target_ops)
     if is_only_flex_enabled:
       # The layout optimizer turns NHCW to NCHW. This provides performance
       # optimizations when Flex mode is enabled. However, this is not compatible
@@ -172,13 +169,19 @@ class TFLiteConverterBase(object):
       if self.representative_dataset.input_gen is None:
         raise ValueError(
             "Provide an input generator for representative_dataset")
+    elif self._int8_target_required():
+      raise ValueError("representative_dataset is required when specifying "
+                       "TFLITE_BUILTINS_INT8 target.")
+
+  def _int8_target_required(self):
+    return set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops)
 
   def _is_post_training_optimize(self):
-    return bool(
+    return (self._int8_target_required() or bool(
         set(self.optimizations).intersection([
             Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
             Optimize.DEFAULT
-        ]))
+        ])))
 
   def _is_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
@@ -189,10 +192,11 @@ class TFLiteConverterBase(object):
 
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type):
+    allow_float = not self._int8_target_required()
     calibrate_quantize = _calibrator.Calibrator(result)
     return calibrate_quantize.calibrate_and_quantize(
         self.representative_dataset.input_gen, inference_input_type,
-        inference_output_type)
+        inference_output_type, allow_float)
 
 
 @_tf_export("lite.TFLiteConverter", v1=[])
@@ -330,6 +334,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
         Invalid quantization parameters.
     """
     # TODO(b/130297984): Add support for converting multiple function.
+    self._target_ops = self.target_spec.supported_ops
     if len(self._funcs) != 1:
       raise ValueError("This converter can only convert a single "
                        "ConcreteFunction. Converting multiple functions is "
@@ -345,7 +350,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
 
     # Run a Grappler pass.
     graph_def = frozen_func.graph.as_graph_def()
-    config = self._grappler_config(self.target_spec.supported_ops)
+    config = self._grappler_config()
     if config:
       graph_def = _run_graph_optimizations(
           graph_def,
@@ -787,6 +792,7 @@ class TFLiteConverter(TFLiteConverterBase):
         Input shape is not specified.
         None value for dimension in input_tensor.
     """
+    self._target_ops = self.target_ops
     # Checks dimensions in input tensor.
    if self._has_valid_tensors():
       for tensor in self._input_tensors:
@@ -873,7 +879,7 @@ class TFLiteConverter(TFLiteConverterBase):
       optimized_graph = self._graph_def
     if self.inference_type != constants.QUANTIZED_UINT8:
       try:
-        config = self._grappler_config(self.target_ops)
+        config = self._grappler_config()
         if config:
           optimized_graph = _run_graph_optimizations(self._graph_def,
                                                      self._input_tensors,
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 573111ad22a..eba6ceb17c9 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -537,10 +537,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertTrue(len(quantized_tflite) < len(float_tflite))
 
-  def testPostTrainingCalibrateAndQuantize(self):
+  def _getCalibrationQuantizeModel(self):
     np.random.seed(0)
-    inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
-                                name='input')
+    inp = array_ops.placeholder(
+        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
     conv = nn_ops.conv2d(
         inp,
         filter=array_ops.ones([3, 3, 3, 16]),
@@ -552,6 +552,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
       for _ in range(5):
         yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
 
+    return (inp, output, calibration_gen)
+
+  def testPostTrainingCalibrateAndQuantize(self):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
     sess = session.Session()
 
     # Convert float model.
@@ -559,7 +563,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     float_tflite = float_converter.convert()
     self.assertTrue(float_tflite)
 
-    # Convert quantized weights model.
+    # Convert quantized model.
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
@@ -580,21 +584,39 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
+  def testCalibrateAndQuantizeBuiltinInt8(self):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+    sess = session.Session()
+
+    # Convert float model.
+    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    # Convert model by specifying target spec (instead of optimizations), since
+    # when targeting an integer only backend, quantization is mandatory.
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [inp], [output])
+    quantized_converter.target_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
+    quantized_converter.representative_dataset = calibration_gen
+    quantized_tflite = quantized_converter.convert()
+    self.assertTrue(quantized_tflite)
+
+    # The default input and output types should be float.
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
   def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
-    np.random.seed(0)
-    inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
-                                name='input')
-    conv = nn_ops.conv2d(
-        inp,
-        filter=array_ops.ones([3, 3, 3, 16]),
-        strides=[1, 1, 1, 1],
-        padding='SAME')
-    output = nn_ops.relu(conv, name='output')
-
-    def calibration_gen():
-      for _ in range(5):
-        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
-
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
     sess = session.Session()
 
     # Convert float model.
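For context, a minimal end-to-end sketch of the new integer-only path, mirroring testCalibrateAndQuantizeBuiltinInt8 above. The toy graph and calibration generator are illustrative only, and a TF 1.x runtime is assumed:

import numpy as np
import tensorflow as tf

# Toy float graph standing in for a real model; any graph with quantizable
# ops (here a conv + relu) will do.
inp = tf.placeholder(dtype=tf.float32, shape=(1, 5, 5, 3), name='input')
conv = tf.nn.conv2d(inp, filter=tf.ones([3, 3, 3, 16]),
                    strides=[1, 1, 1, 1], padding='SAME')
out = tf.nn.relu(conv, name='output')

def calibration_gen():
  # Yield a handful of representative inputs for activation calibration.
  for _ in range(5):
    yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

sess = tf.Session()
converter = tf.lite.TFLiteConverter.from_session(sess, [inp], [out])
# Requesting the int8-only ops set makes a representative_dataset mandatory;
# omitting it now raises the new ValueError added in lite.py.
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = calibration_gen
tflite_model = converter.convert()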
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc
index 3ebc94e82d5..285935dc9df 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -186,7 +186,8 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
 }
 
 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
-                                            int output_py_type) {
+                                            int output_py_type,
+                                            bool allow_float) {
   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
@@ -199,7 +200,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   flatbuffers::FlatBufferBuilder builder;
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
-      TfLiteTypeToSchemaType(output_type), error_reporter_.get());
+      TfLiteTypeToSchemaType(output_type), allow_float, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.h b/tensorflow/lite/python/optimize/calibration_wrapper.h
index 801a100f462..3fe1629da58 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.h
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.h
@@ -59,7 +59,8 @@ class CalibrationWrapper {
 
   PyObject* FeedTensor(PyObject* input_value);
 
-  PyObject* QuantizeModel(int input_py_type, int output_py_type);
+  PyObject* QuantizeModel(int input_py_type, int output_py_type,
+                          bool allow_float);
 
  private:
   // CalibrationWrapper is not copyable or assignable. We avoid the use of
diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py
index 665d4a34f61..a9eb6792882 100644
--- a/tensorflow/lite/python/optimize/calibrator.py
+++ b/tensorflow/lite/python/optimize/calibrator.py
@@ -54,7 +54,8 @@ class Calibrator(object):
     if not self._calibrator:
       raise ValueError("Failed to parse the model.")
 
-  def calibrate_and_quantize(self, dataset_gen, input_type, output_type):
+  def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
+                             allow_float):
     """Calibrates the model with specified generator and then quantizes it.
 
     Returns:
@@ -64,10 +65,14 @@ class Calibrator(object):
       dataset_gen: A generator that generates calibration samples.
       input_type: A tf.dtype representing the desired real-value input type.
       output_type: A tf.dtype representing the desired real-value output type.
+      allow_float: A boolean. False if the resulting model cannot perform float
+                   computation, useful when targeting an integer-only backend.
+                   If False, an error will be thrown if an operation cannot be
+                   quantized; otherwise the model will fall back to float ops.
     """
     self._calibrator.Prepare()
     for calibration_sample in dataset_gen():
       self._calibrator.FeedTensor(calibration_sample)
     return self._calibrator.QuantizeModel(
         np.dtype(input_type.as_numpy_dtype()).num,
-        np.dtype(output_type.as_numpy_dtype()).num)
+        np.dtype(output_type.as_numpy_dtype()).num, allow_float)
diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py
index 1bb0175ed9c..ca4a86c8461 100644
--- a/tensorflow/lite/python/optimize/calibrator_test.py
+++ b/tensorflow/lite/python/optimize/calibrator_test.py
@@ -39,8 +39,25 @@ class CalibratorTest(test_util.TensorFlowTestCase):
       for _ in range(10):
        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
 
-    quantized_model = quantizer.calibrate_and_quantize(
-        input_gen, constants.FLOAT, constants.FLOAT)
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, False)
+    self.assertIsNotNone(quantized_model)
+
+  def test_calibration_with_quantization_allow_float(self):
+    model_path = resource_loader.get_path_to_datafile(
+        'test_data/mobilenet_like_model.bin')
+    float_model = open(model_path, 'rb').read()
+    quantizer = _calibrator.Calibrator(float_model)
+
+    # Input generator for the model.
+    def input_gen():
+      for _ in range(10):
+        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
+
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, True)
     self.assertIsNotNone(quantized_model)
 
   def test_calibration_with_quantization_multiple_inputs(self):
@@ -56,8 +73,9 @@ class CalibratorTest(test_util.TensorFlowTestCase):
       for _ in range(10):
         yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]
 
-    quantized_model = quantizer.calibrate_and_quantize(
-        input_gen, constants.FLOAT, constants.FLOAT)
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, False)
     self.assertIsNotNone(quantized_model)
 
   def test_invalid_model_buffer(self):
@@ -78,7 +96,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaises(RuntimeError):
       quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
   def test_invalid_shape_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -93,7 +111,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
   def test_invalid_type_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -108,7 +126,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaises(ValueError):
      quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-ops-set.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-ops-set.pbtxt
index 68c651a3c99..c3199b24d98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-ops-set.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-ops-set.pbtxt
@@ -9,4 +9,8 @@ tf_class {
     name: "TFLITE_BUILTINS"
     mtype: "<enum \'OpsSet\'>"
   }
+  member {
+    name: "TFLITE_BUILTINS_INT8"
+    mtype: "<enum \'OpsSet\'>"
+  }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-ops-set.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-ops-set.pbtxt
index 68c651a3c99..c3199b24d98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-ops-set.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-ops-set.pbtxt
@@ -9,4 +9,8 @@ tf_class {
     name: "TFLITE_BUILTINS"
     mtype: "<enum \'OpsSet\'>"
  }
+  member {
+    name: "TFLITE_BUILTINS_INT8"
+    mtype: "<enum \'OpsSet\'>"
+  }
 }
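Similarly, a sketch of driving the calibrator layer directly with the new allow_float flag, along the lines of calibrator_test.py above. The model path is a placeholder and the lite_constants import is an assumption; substitute any float TFLite flatbuffer whose input matches the generator:

import numpy as np
from tensorflow.lite.python import lite_constants as constants  # assumed import
from tensorflow.lite.python.optimize import calibrator as _calibrator

# Placeholder path; any float TFLite model can be calibrated this way.
float_model = open('/tmp/float_model.tflite', 'rb').read()
quantizer = _calibrator.Calibrator(float_model)

def input_gen():
  # The yielded arrays must match the model's input shape and dtype.
  for _ in range(10):
    yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

# allow_float=False demands a fully quantized model and raises an error if any
# op lacks a quantized implementation; allow_float=True lets such ops stay float.
quantized_model = quantizer.calibrate_and_quantize(
    input_gen, constants.FLOAT, constants.FLOAT, False)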