Add BUILTIN_INT8 support to gate integer-only conversion.
PiperOrigin-RevId: 246077351
Commit 78993c47e3 (parent e0df2327cf)
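What the change does: a new `OpsSet.TFLITE_BUILTINS_INT8` marks a conversion as integer-only, makes a representative dataset mandatory for that target, and plumbs an `allow_float` flag down to the quantizer so that ops without a quantized implementation become a hard error instead of silently staying float. A minimal usage sketch, distilled from the `testCalibrateAndQuantizeBuiltinInt8` test added below (`sess`, `inp`, `output`, and `calibration_gen` are placeholders for a real session, tensors, and a representative-data generator):

```python
from tensorflow.lite.python import lite

# Target only quantized int8 builtin ops; conversion now fails for any op
# that has no quantized implementation instead of keeping it float.
converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
converter.target_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
# Calibration data is required for this target: omitting it raises a
# ValueError (see the _validate_representative_dataset change below).
converter.representative_dataset = calibration_gen
tflite_model = converter.convert()
```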
tensorflow/lite/python/lite.py:
@@ -74,6 +74,11 @@ class OpsSet(enum.Enum):
   # WARNING: Experimental interface, subject to change.
   SELECT_TF_OPS = "SELECT_TF_OPS"
 
+  # Convert model using only TensorFlow Lite quantized int8 operations.
+  # Specifying this will throw an error for operations that do not yet have
+  # quantized implementations.
+  TFLITE_BUILTINS_INT8 = "TFLITE_BUILTINS_INT8"
+
   def __str__(self):
     return self.value
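For the 2.x converter the same target flows in through `target_spec.supported_ops`, which `TFLiteConverterV2.convert()` copies into `self._target_ops` (see the hunk further down). A hedged sketch, assuming the `from_concrete_function` constructor of this era; `concrete_func` and `calibration_gen` are placeholders:

```python
# Sketch only: selecting the int8-only op set on the V2 converter.
converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = calibration_gen
tflite_model = converter.convert()
```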
@@ -80,28 +80,24 @@ class Optimize(enum.Enum):
   # Converter will do its best to improve size and latency based on the
   # information provided.
   # Enhanced optimizations can be gained by providing a representative_dataset.
-  # Currently this is recommended, and is equivalent to the modes below.
+  # This is recommended, and is currently equivalent to the modes below.
+  # Currently, weights will be quantized and if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   DEFAULT = "DEFAULT"
 
   # Optimize for size.
   #
   # Optimizations that reduce the size of the model.
   # The model size will be reduced.
-  # Current behavior:
-  # - If RepresentativeDataset is not provided, weights will be quantized and
-  #   activations will remain float.
-  # - If RepresentativeDataset is provided, weights and activations will be
-  #   quantized.
+  # Currently, weights will be quantized and if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
 
   # Optimize for latency.
   #
   # Optimizations that reduce the latency of the model.
-  # Current behavior:
-  # - If RepresentativeDataset is not provided, weights will be quantized and
-  #   activations will remain float.
-  # - If RepresentativeDataset is provided, weights and activations will be
-  #   quantized.
+  # Currently, weights will be quantized and if representative_dataset is
+  # provided, activations for quantizable operations will also be quantized.
   OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
 
   def __str__(self):
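The reworded comments collapse the three modes onto one current behavior: weights are always quantized, and activations are quantized too once a `representative_dataset` is supplied. A sketch of the two resulting flavors (`converter` and `gen` are placeholders):

```python
# Weight-only post-training quantization: no representative_dataset, so
# activations stay float.
converter.optimizations = [lite.Optimize.DEFAULT]
weight_quantized = converter.convert()

# Full post-training quantization: with calibration data, activations of
# quantizable ops are quantized as well.
converter.representative_dataset = gen
fully_quantized = converter.convert()
```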
@@ -154,9 +150,10 @@ class TFLiteConverterBase(object):
   def __init__(self):
     self.representative_dataset = None
     self.optimizations = []
+    self._target_ops = set([OpsSet.TFLITE_BUILTINS])
 
-  def _grappler_config(self, target_ops):
-    is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == target_ops
+  def _grappler_config(self):
+    is_only_flex_enabled = set([OpsSet.SELECT_TF_OPS]) == set(self._target_ops)
     if is_only_flex_enabled:
       # The layout optimizer turns NHCW to NCHW. This provides performance
       # optimizations when Flex mode is enabled. However, this is not compatible
@@ -172,13 +169,19 @@ class TFLiteConverterBase(object):
       if self.representative_dataset.input_gen is None:
         raise ValueError(
             "Provide an input generator for representative_dataset")
+    elif self._int8_target_required():
+      raise ValueError("representative_dataset is required when specifying "
+                       "TFLITE_BUILTINS_INT8 target.")
+
+  def _int8_target_required(self):
+    return set([OpsSet.TFLITE_BUILTINS_INT8]) == set(self._target_ops)
 
   def _is_post_training_optimize(self):
-    return bool(
+    return (self._int8_target_required() or bool(
         set(self.optimizations).intersection([
             Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE,
             Optimize.DEFAULT
-        ]))
+        ])))
 
   def _is_weight_only_quantize(self):
     return (self._is_post_training_optimize() and
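The effect of the new gate, as a sketch (`sess`, `inp`, `output` are placeholders): an int8-only target with no calibration data is rejected up front rather than producing a partially-float model.

```python
converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
converter.target_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
try:
  converter.convert()  # no representative_dataset set
except ValueError as e:
  # "representative_dataset is required when specifying TFLITE_BUILTINS_INT8
  # target."
  print(e)
```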
@@ -189,10 +192,11 @@ class TFLiteConverterBase(object):
 
   def _calibrate_quantize_model(self, result, inference_input_type,
                                 inference_output_type):
+    allow_float = not self._int8_target_required()
     calibrate_quantize = _calibrator.Calibrator(result)
     return calibrate_quantize.calibrate_and_quantize(
         self.representative_dataset.input_gen, inference_input_type,
-        inference_output_type)
+        inference_output_type, allow_float)
 
 
 @_tf_export("lite.TFLiteConverter", v1=[])
@@ -330,6 +334,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
       Invalid quantization parameters.
     """
     # TODO(b/130297984): Add support for converting multiple function.
+    self._target_ops = self.target_spec.supported_ops
     if len(self._funcs) != 1:
       raise ValueError("This converter can only convert a single "
                        "ConcreteFunction. Converting multiple functions is "
@@ -345,7 +350,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
 
     # Run a Grappler pass.
     graph_def = frozen_func.graph.as_graph_def()
-    config = self._grappler_config(self.target_spec.supported_ops)
+    config = self._grappler_config()
     if config:
       graph_def = _run_graph_optimizations(
           graph_def,
@@ -787,6 +792,7 @@ class TFLiteConverter(TFLiteConverterBase):
       Input shape is not specified.
       None value for dimension in input_tensor.
     """
+    self._target_ops = self.target_ops
    # Checks dimensions in input tensor.
     if self._has_valid_tensors():
       for tensor in self._input_tensors:
@@ -873,7 +879,7 @@ class TFLiteConverter(TFLiteConverterBase):
     optimized_graph = self._graph_def
     if self.inference_type != constants.QUANTIZED_UINT8:
       try:
-        config = self._grappler_config(self.target_ops)
+        config = self._grappler_config()
         if config:
           optimized_graph = _run_graph_optimizations(self._graph_def,
                                                      self._input_tensors,
tensorflow/lite/python/lite_test.py:
@@ -537,10 +537,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertTrue(len(quantized_tflite) < len(float_tflite))
 
-  def testPostTrainingCalibrateAndQuantize(self):
+  def _getCalibrationQuantizeModel(self):
     np.random.seed(0)
-    inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
-                                name='input')
+    inp = array_ops.placeholder(
+        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
     conv = nn_ops.conv2d(
         inp,
         filter=array_ops.ones([3, 3, 3, 16]),
@@ -552,6 +552,10 @@ class FromSessionTest(test_util.TensorFlowTestCase):
       for _ in range(5):
         yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
 
+    return (inp, output, calibration_gen)
+
+  def testPostTrainingCalibrateAndQuantize(self):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
     sess = session.Session()
 
     # Convert float model.
@@ -559,7 +563,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     float_tflite = float_converter.convert()
     self.assertTrue(float_tflite)
 
-    # Convert quantized weights model.
+    # Convert quantized model.
     quantized_converter = lite.TFLiteConverter.from_session(
         sess, [inp], [output])
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
@@ -580,21 +584,39 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
+  def testCalibrateAndQuantizeBuiltinInt8(self):
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
+    sess = session.Session()
+
+    # Convert float model.
+    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
+    float_tflite = float_converter.convert()
+    self.assertTrue(float_tflite)
+
+    # Convert model by specifying target spec (instead of optimizations), since
+    # when targeting an integer only backend, quantization is mandatory.
+    quantized_converter = lite.TFLiteConverter.from_session(
+        sess, [inp], [output])
+    quantized_converter.target_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
+    quantized_converter.representative_dataset = calibration_gen
+    quantized_tflite = quantized_converter.convert()
+    self.assertTrue(quantized_tflite)
+
+    # The default input and output types should be float.
+    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+
+    # Ensure that the quantized weights tflite model is smaller.
+    self.assertLess(len(quantized_tflite), len(float_tflite))
+
   def testPostTrainingCalibrateAndQuantizeInt8Inputs(self):
-    np.random.seed(0)
-    inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3),
-                                name='input')
-    conv = nn_ops.conv2d(
-        inp,
-        filter=array_ops.ones([3, 3, 3, 16]),
-        strides=[1, 1, 1, 1],
-        padding='SAME')
-    output = nn_ops.relu(conv, name='output')
-
-    def calibration_gen():
-      for _ in range(5):
-        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
-
+    inp, output, calibration_gen = self._getCalibrationQuantizeModel()
     sess = session.Session()
 
     # Convert float model.
tensorflow/lite/python/optimize/calibration_wrapper.cc:
@@ -186,7 +186,8 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) {
 }
 
 PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
-                                            int output_py_type) {
+                                            int output_py_type,
+                                            bool allow_float) {
   TfLiteType input_type = python_utils::TfLiteTypeFromPyType(input_py_type);
   TfLiteType output_type = python_utils::TfLiteTypeFromPyType(output_py_type);
   if (input_type == kTfLiteNoType || output_type == kTfLiteNoType) {
@@ -199,7 +200,7 @@ PyObject* CalibrationWrapper::QuantizeModel(int input_py_type,
   flatbuffers::FlatBufferBuilder builder;
   auto status = tflite::optimize::QuantizeModel(
       &builder, tflite_model.get(), TfLiteTypeToSchemaType(input_type),
-      TfLiteTypeToSchemaType(output_type), error_reporter_.get());
+      TfLiteTypeToSchemaType(output_type), allow_float, error_reporter_.get());
   if (status != kTfLiteOk) {
     error_reporter_->exception();
     return nullptr;
tensorflow/lite/python/optimize/calibration_wrapper.h:
@@ -59,7 +59,8 @@ class CalibrationWrapper {
 
   PyObject* FeedTensor(PyObject* input_value);
 
-  PyObject* QuantizeModel(int input_py_type, int output_py_type);
+  PyObject* QuantizeModel(int input_py_type, int output_py_type,
+                          bool allow_float);
 
  private:
  // CalibrationWrapper is not copyable or assignable. We avoid the use of
tensorflow/lite/python/optimize/calibrator.py:
@@ -54,7 +54,8 @@ class Calibrator(object):
     if not self._calibrator:
       raise ValueError("Failed to parse the model.")
 
-  def calibrate_and_quantize(self, dataset_gen, input_type, output_type):
+  def calibrate_and_quantize(self, dataset_gen, input_type, output_type,
+                             allow_float):
     """Calibrates the model with specified generator and then quantizes it.
 
     Returns:
@@ -64,10 +65,14 @@ class Calibrator(object):
       dataset_gen: A generator that generates calibration samples.
       input_type: A tf.dtype representing the desired real-value input type.
       output_type: A tf.dtype representing the desired real-value output type.
+      allow_float: A boolean. False if the resulting model cannot perform float
+                   computation, useful when targeting an integer-only backend.
+                   If False, an error will be thrown if an operation cannot be
+                   quantized, otherwise the model will fall back to float ops.
     """
     self._calibrator.Prepare()
     for calibration_sample in dataset_gen():
       self._calibrator.FeedTensor(calibration_sample)
     return self._calibrator.QuantizeModel(
         np.dtype(input_type.as_numpy_dtype()).num,
-        np.dtype(output_type.as_numpy_dtype()).num)
+        np.dtype(output_type.as_numpy_dtype()).num, allow_float)
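Driving the calibrator directly with the new flag, following the updated tests below; this exercises the `allow_float` argument that `CalibrationWrapper::QuantizeModel` now forwards to `tflite::optimize::QuantizeModel`. A sketch: `model_bytes` (a TFLite flatbuffer) is a placeholder and the `constants` import path is an assumption.

```python
import numpy as np
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.optimize import calibrator as _calibrator

quantizer = _calibrator.Calibrator(model_bytes)

def input_gen():
  for _ in range(10):
    yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]

# allow_float=False: every op must have a quantized kernel or an error is
# thrown; allow_float=True lets unsupported ops fall back to float.
quantized = quantizer.calibrate_and_quantize(
    input_gen, constants.FLOAT, constants.FLOAT, False)
```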
tensorflow/lite/python/optimize/calibrator_test.py:
@@ -39,8 +39,25 @@ class CalibratorTest(test_util.TensorFlowTestCase):
       for _ in range(10):
         yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
 
-    quantized_model = quantizer.calibrate_and_quantize(
-        input_gen, constants.FLOAT, constants.FLOAT)
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, False)
+    self.assertIsNotNone(quantized_model)
+
+  def test_calibration_with_quantization_allow_float(self):
+    model_path = resource_loader.get_path_to_datafile(
+        'test_data/mobilenet_like_model.bin')
+    float_model = open(model_path, 'rb').read()
+    quantizer = _calibrator.Calibrator(float_model)
+
+    # Input generator for the model.
+    def input_gen():
+      for _ in range(10):
+        yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)]
+
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, True)
     self.assertIsNotNone(quantized_model)
 
   def test_calibration_with_quantization_multiple_inputs(self):
@@ -56,8 +73,9 @@ class CalibratorTest(test_util.TensorFlowTestCase):
       for _ in range(10):
         yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)]
 
-    quantized_model = quantizer.calibrate_and_quantize(
-        input_gen, constants.FLOAT, constants.FLOAT)
+    quantized_model = quantizer.calibrate_and_quantize(input_gen,
+                                                       constants.FLOAT,
+                                                       constants.FLOAT, False)
     self.assertIsNotNone(quantized_model)
 
   def test_invalid_model_buffer(self):
@@ -78,7 +96,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaises(RuntimeError):
       quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
   def test_invalid_shape_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -93,7 +111,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaisesWithRegexpMatch(ValueError, 'Dimension mismatch'):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
   def test_invalid_type_calibrator_gen(self):
     model_path = resource_loader.get_path_to_datafile(
@@ -108,7 +126,7 @@ class CalibratorTest(test_util.TensorFlowTestCase):
 
     with self.assertRaises(ValueError):
       quantizer.calibrate_and_quantize(input_gen, constants.FLOAT,
-                                       constants.FLOAT)
+                                       constants.FLOAT, False)
 
 
 if __name__ == '__main__':
tensorflow/tools/api/golden/v1/tensorflow.lite.-ops-set.pbtxt:
@@ -9,4 +9,8 @@ tf_class {
     name: "TFLITE_BUILTINS"
     mtype: "<enum \'OpsSet\'>"
   }
+  member {
+    name: "TFLITE_BUILTINS_INT8"
+    mtype: "<enum \'OpsSet\'>"
+  }
 }
tensorflow/tools/api/golden/v2/tensorflow.lite.-ops-set.pbtxt:
@@ -9,4 +9,8 @@ tf_class {
     name: "TFLITE_BUILTINS"
     mtype: "<enum \'OpsSet\'>"
   }
+  member {
+    name: "TFLITE_BUILTINS_INT8"
+    mtype: "<enum \'OpsSet\'>"
+  }
 }
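The golden-file updates only record the new public enum member; its string form comes from `OpsSet.__str__`, which returns the enum's value:

```python
from tensorflow.lite.python import lite

print(str(lite.OpsSet.TFLITE_BUILTINS_INT8))  # -> TFLITE_BUILTINS_INT8
```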