Add TFLITE_BUILTINS_INT8 support to gate integer-only conversion.

PiperOrigin-RevId: 246077351
Suharsh Sivakumar authored on 2019-04-30 21:30:15 -07:00; committed by TensorFlower Gardener
parent e0df2327cf
commit 78993c47e3
9 changed files with 114 additions and 48 deletions


@ -74,6 +74,11 @@ class OpsSet(enum.Enum):
# WARNING: Experimental interface, subject to change.
SELECT_TF_OPS = "SELECT_TF_OPS"
# Convert model using only TensorFlow Lite quantized int8 operations.
# Specifying this will throw an error for operations that do not yet have
# quantized implementations.
TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
def __str__(self):
return self.value
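For context, a minimal sketch of how this new target can be selected on the TF 1.x converter, mirroring the test added further down in this change (the graph, calibration generator, and the target_ops attribute are taken from that test; treat it as an illustration rather than the canonical API):

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
      inp = tf.placeholder(dtype=tf.float32, shape=(1, 5, 5, 3), name='input')
      conv = tf.nn.conv2d(inp, filter=tf.ones([3, 3, 3, 16]),
                          strides=[1, 1, 1, 1], padding='SAME')
      output = tf.nn.relu(conv, name='output')

      def calibration_gen():
        for _ in range(5):
          yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

      converter = tf.lite.TFLiteConverter.from_session(sess, [inp], [output])
      # Integer-only target: only quantized int8 builtin kernels may be
      # emitted, so a representative dataset for calibration is mandatory.
      converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
      converter.representative_dataset = calibration_gen
      quantized_tflite = converter.convert()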


@ -80,28 +80,24 @@ class Optimize(enum.Enum):
# Converter will do its best to improve size and latency based on the
# information provided.
# Enhanced optimizations can be gained by providing a representative_dataset.
# Currently this is recommended, and is equivalent to the modes below.
# This is recommended, and is currently equivalent to the modes below.
# Currently, weights will be quantized and if representative_dataset is
# provided, activations for quantizable operations will also be quantized.
DEFAULT = "DEFAULT"
# Optimize for size.
#
# Optimizations that reduce the size of the model.
# The model size will be reduced.
# Current behavior:
# - If RepresentativeDataset is not provided, weights will be quantized and
# activations will remain float.
# - If RepresentativeDataset is provided, weights and activations will be
# quantized.
# Currently, weights will be quantized and if representative_dataset is
# provided, activations for quantizable operations will also be quantized.
OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
# Optimize for latency.
#
# Optimizations that reduce the latency of the model.
# Current behavior:
# - If RepresentativeDataset is not provided, weights will be quantized and
# activations will remain float.
# - If RepresentativeDataset is provided, weights and activations will be
# quantized.
# Currently, weights will be quantized and if representative_dataset is
# provided, activations for quantizable operations will also be quantized.
OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
def __str__(self):
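Concretely, the behavior described in these docstrings amounts to the following two setups on the TF 1.x converter (reusing sess, inp, output, and calibration_gen from the sketch above; illustrative only):

    # Weight-only quantization: without a representative_dataset, only weights
    # are quantized and activations remain float.
    converter = tf.lite.TFLiteConverter.from_session(sess, [inp], [output])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    weights_quantized_tflite = converter.convert()

    # Weights and activations: a representative_dataset drives calibration, so
    # activations of quantizable operations are quantized as well.
    converter = tf.lite.TFLiteConverter.from_session(sess, [inp], [output])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_gen
    fully_quantized_tflite = converter.convert()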
@ -154,9 +150,10 @@ class TFLiteConverterBase(object):
def __init__(self):
self.representative_dataset = None
self.optimizations = []
self._target_ops = set([OpsSet.TFLITE_BUILTINS])
def _grappler_config(self, target_ops):
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == target_ops
def _grappler_config(self):
is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(self._target_ops)
if is_only_flex_enabled:
# The layout optimizer turns NHWC to NCHW. This provides performance
# optimizations when Flex mode is enabled. However, this is not compatible
@ -172,13 +169,19 @@ class TFLiteConverterBase(object):
if self.representative_dataset.input_gen is None:
raise ValueError(
"Provide an input generator for representative_dataset")
elif self._int8_target_required():
raise ValueError("representative_dataset is required when specifying "
"TFLITE_BUILTINs_INT8 target.")
def _int8_target_required(self):
return set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops)
def _is_post_training_optimize(self):
return bool(
return (self._int8_target_required() or bool(
set(self.optimizations).intersection([
Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
Optimize.DEFAULT
]))
])))
def _is_weight_only_quantize(self):
return (self._is_post_training_optimize() and
@ -189,10 +192,11 @@ class TFLiteConverterBase(object):
def _calibrate_quantize_model(self, result, inference_input_type,
inference_output_type):
allow_float = not self._int8_target_required()
calibrate_quantize = _calibrator.Calibrator(result)
return calibrate_quantize.calibrate_and_quantize(
self.representative_dataset.input_gen, inference_input_type,
inference_output_type)
inference_output_type, allow_float)
@_tf_export("lite.TFLiteConverter", v1=[])
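Pulling these hunks together, the converter base class now derives the calibration mode from the target op set; a rough paraphrase (not verbatim code) of the new flow:

    # Inside TFLiteConverterBase (paraphrased from the changes above).
    int8_only = set(self._target_ops) == set([OpsSet.TFLITE_BUILTINS_INT8])
    if int8_only and self.representative_dataset is None:
      raise ValueError("representative_dataset is required when specifying "
                       "TFLITE_BUILTINS_INT8 target.")
    # Float fallback is only permitted when int8 is not the sole target.
    allow_float = not int8_only
    quantized = calibrate_quantize.calibrate_and_quantize(
        self.representative_dataset.input_gen, inference_input_type,
        inference_output_type, allow_float)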
@ -330,6 +334,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
Invalid quantization parameters.
"""
# TODO(b/130297984): Add support for converting multiple functions.
self._target_ops = self.target_spec.supported_ops
if len(self._funcs) != 1:
raise ValueError("This converter can only convert a single "
"ConcreteFunction. Converting multiple functions is "
@ -345,7 +350,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
# Run a Grappler pass.
graph_def = frozen_func.graph.as_graph_def()
config = self._grappler_config(self.target_spec.supported_ops)
config = self._grappler_config()
if config:
graph_def = _run_graph_optimizations(
graph_def,
@ -787,6 +792,7 @@ class TFLiteConverter(TFLiteConverterBase):
Input shape is not specified.
None value for dimension in input_tensor.
"""
self._target_ops = self.target_ops
# Checks dimensions in input tensor.
if self._has_valid_tensors():
for tensor in self._input_tensors:
@ -873,7 +879,7 @@ class TFLiteConverter(TFLiteConverterBase):
optimized_graph = self._graph_def
if self.inference_type != constants.QUANTIZED_UINT8:
try:
config = self._grappler_config(self.target_ops)
config = self._grappler_config()
if config:
optimized_graph = _run_graph_optimizations(self._graph_def,
self._input_tensors,


@ -537,10 +537,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
def testPostTrainingCalibrateAndQuantize(self):
def _getCalibrationQuantizeModel(self):
np.random.seed(0)
inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
name='input')
inp = array_ops.placeholder(
dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
conv = nn_ops.conv2d(
inp,
filter=array_ops.ones([3, 3, 3, 16]),
@ -552,6 +552,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
for _ in range(5):
yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
return (inp, output, calibration_gen)
def testPostTrainingCalibrateAndQuantize(self):
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
# Convert float model.
@ -559,7 +563,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert quantized weights model.
# Convert quantized model.
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
@ -580,21 +584,39 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def testCalibrateAndQuantizeBuiltinInt8(self):
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
# Convert float model.
float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert model by specifying target spec (instead of optimizations), since
# when targeting an integer-only backend, quantization is mandatory.
quantized_converter = lite.TFLiteConverter.from_session(
sess, [inp], [output])
quantized_converter.target_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
quantized_converter.representative_dataset = calibration_gen
quantized_tflite = quantized_converter.convert()
self.assertTrue(quantized_tflite)
# The default input and output types should be float.
interpreter = Interpreter(model_content=quantized_tflite)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual(np.float32, input_details[0]['dtype'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual(np.float32, output_details[0]['dtype'])
# Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite))
def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
np.random.seed(0)
inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
name='input')
conv = nn_ops.conv2d(
inp,
filter=array_ops.ones([3, 3, 3, 16]),
strides=[1, 1, 1, 1],
padding='SAME')
output = nn_ops.relu(conv, name='output')
def calibration_gen():
for _ in range(5):
yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
inp, output, calibration_gen = self._getCalibrationQuantizeModel()
sess = session.Session()
# Convert float model.


@ -186,7 +186,8 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
}
PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
int output_py_type) {
int output_py_type,
bool allow_float) {
TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
@ -199,7 +200,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
flatbuffers::FlatBufferBuilder builder;
auto status = tflite::optimize::QuantizeModel(
&builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
TfLiteTypeToSchemaType(output_type), error_reporter_.get());
TfLiteTypeToSchemaType(output_type), allow_float, error_reporter_.get());
if (status != kTfLiteOk) {
error_reporter_->exception();
return nullptr;


@ -59,7 +59,8 @@ class CalibrationWrapper {
PyObject* FeedTensor(PyObject* input_value);
PyObject* QuantizeModel(int input_py_type, int output_py_type);
PyObject* QuantizeModel(int input_py_type, int output_py_type,
bool allow_float);
private:
// CalibrationWrapper is not copyable or assignable. We avoid the use of


@ -54,7 +54,8 @@ class Calibrator(object):
if not self._calibrator:
raise ValueError("Failed to parse the model.")
def calibrate_and_quantize(self, dataset_gen, input_type, output_type):
def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
allow_float):
"""Calibrates the model with specified generator and then quantizes it.
Returns:
@ -64,10 +65,14 @@ class Calibrator(object):
dataset_gen: A generator that generates calibration samples.
input_type: A tf.dtype representing the desired real-value input type.
output_type: A tf.dtype representing the desired real-value output type.
allow_float: A boolean. False if the resulting model is not allowed to
perform float computation, which is useful when targeting an integer-only
backend. If False, an error is raised for any operation that cannot be
quantized; otherwise such operations fall back to float ops.
"""
self._calibrator.Prepare()
for calibration_sample in dataset_gen():
self._calibrator.FeedTensor(calibration_sample)
return self._calibrator.QuantizeModel(
np.dtype(input_type.as_numpy_dtype()).num,
np.dtype(output_type.as_numpy_dtype()).num)
np.dtype(output_type.as_numpy_dtype()).num, allow_float)
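For direct use of the calibration API with the new flag, a minimal sketch (float_model is assumed to be a TFLite flatbuffer read from disk and input_gen a generator of calibration batches, as in the tests below; import paths follow the modules used in this change):

    from tensorflow.lite.python import lite_constants as constants
    from tensorflow.lite.python.optimize import calibrator as _calibrator

    # float_model and input_gen are placeholders for a real model buffer and
    # calibration data generator.
    quantizer = _calibrator.Calibrator(float_model)
    # allow_float=False: every operation must have a quantized implementation,
    # otherwise quantization fails. allow_float=True lets unsupported
    # operations stay in float.
    quantized_model = quantizer.calibrate_and_quantize(
        input_gen, constants.FLOAT, constants.FLOAT, False)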


@ -39,8 +39,25 @@ class CalibratorTest(test_util.TensorFlowTestCase):
for _ in range(10):
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
quantized_model = quantizer.calibrate_and_quantize(
input_gen, constants.FLOAT, constants.FLOAT)
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, False)
self.assertIsNotNone(quantized_model)
def test_calibration_with_quantization_allow_float(self):
model_path = resource_loader.get_path_to_datafile(
'test_data/mobilenet_like_model.bin')
float_model = open(model_path, 'rb').read()
quantizer = _calibrator.Calibrator(float_model)
# Input generator for the model.
def input_gen():
for _ in range(10):
yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, True)
self.assertIsNotNone(quantized_model)
def test_calibration_with_quantization_multiple_inputs(self):
@ -56,8 +73,9 @@ class CalibratorTest(test_util.TensorFlowTestCase):
for _ in range(10):
yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]
quantized_model = quantizer.calibrate_and_quantize(
input_gen, constants.FLOAT, constants.FLOAT)
quantized_model = quantizer.calibrate_and_quantize(input_gen,
constants.FLOAT,
constants.FLOAT, False)
self.assertIsNotNone(quantized_model)
def test_invalid_model_buffer(self):
@ -78,7 +96,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
with self.assertRaises(RuntimeError):
quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
constants.FLOAT)
constants.FLOAT, False)
def test_invalid_shape_calibrator_gen(self):
model_path = resource_loader.get_path_to_datafile(
@ -93,7 +111,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'):
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
constants.FLOAT)
constants.FLOAT, False)
def test_invalid_type_calibrator_gen(self):
model_path = resource_loader.get_path_to_datafile(
@ -108,7 +126,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
constants.FLOAT)
constants.FLOAT, False)
if __name__ == '__main__':


@ -9,4 +9,8 @@ tf_class {
name: "TFLITE_BUILTINS"
mtype: "<enum \'OpsSet\'>"
}
member {
name: "TFLITE_BUILTINS_INT8"
mtype: "<enum \'OpsSet\'>"
}
}


@ -9,4 +9,8 @@ tf_class {
name: "TFLITE_BUILTINS"
mtype: "<enum \'OpsSet\'>"
}
member {
name: "TFLITE_BUILTINS_INT8"
mtype: "<enum \'OpsSet\'>"
}
}