Refactor and Fix lint errors in util.py and lite*.py files
PiperOrigin-RevId: 324727472 Change-Id: I3766b0724564f91216bffcc8b55f70744fd94334
This commit is contained in:
parent
e4592dad25
commit
1277f67514
@ -125,7 +125,7 @@ class Optimize(enum.Enum):
|
||||
OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
return str(self.value)
|
||||
|
||||
|
||||
@_tf_export("lite.RepresentativeDataset")
|
||||
@ -230,7 +230,7 @@ class QuantizationMode(object):
|
||||
|
||||
def post_training_int16x8_allow_float(self):
|
||||
"""Post training int16x8 quantize, allow float fallback."""
|
||||
return (self._is_int16x8_target_required() and self._is_allow_float())
|
||||
return self._is_int16x8_target_required() and self._is_allow_float()
|
||||
|
||||
def post_training_dynamic_range_int8(self):
|
||||
"""Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
|
||||
@ -907,7 +907,7 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
|
||||
"""
|
||||
# TODO(b/130297984): Add support for converting multiple function.
|
||||
|
||||
if len(self._funcs) == 0:
|
||||
if len(self._funcs) == 0: # pylint: disable=g-explicit-length-test
|
||||
raise ValueError("No ConcreteFunction is specified.")
|
||||
|
||||
if len(self._funcs) > 1:
|
||||
@ -1127,7 +1127,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
|
||||
parameter is ignored. (default tf.float32)
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays. If an integer type is provided and
|
||||
`optimizations` are not used, `quantized_inputs_stats` must be provided.
|
||||
`optimizations` are not used, `quantized_input_stats` must be provided.
|
||||
If `inference_type` is tf.uint8, signaling conversion to a fully quantized
|
||||
model from a quantization-aware trained input model, then
|
||||
`inference_input_type` defaults to tf.uint8. In all other cases,
|
||||
@ -1681,7 +1681,7 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
||||
inference_input_type: Target data type of real-number input arrays. Allows
|
||||
for a different type for input arrays.
|
||||
If an integer type is provided and `optimizations` are not used,
|
||||
`quantized_inputs_stats` must be provided.
|
||||
`quantized_input_stats` must be provided.
|
||||
If `inference_type` is tf.uint8, signaling conversion to a fully quantized
|
||||
model from a quantization-aware trained input model, then
|
||||
`inference_input_type` defaults to tf.uint8.
|
||||
@ -2012,6 +2012,7 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
|
||||
"""
|
||||
return super(TFLiteConverter, self).convert()
|
||||
|
||||
|
||||
@_tf_export(v1=["lite.TocoConverter"])
|
||||
class TocoConverter(object):
|
||||
"""Convert a TensorFlow model into `output_format` using TOCO.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -58,14 +58,14 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
('EnableMlirConverter', True), # enable mlir
|
||||
('DisableMlirConverter', False)) # disable mlir
|
||||
@test_util.run_v2_only
|
||||
def testFloat(self, enable_mlir):
|
||||
def testFloat(self, enable_mlir_converter):
|
||||
root = self._getSimpleVariableModel()
|
||||
input_data = tf.constant(1., shape=[1])
|
||||
concrete_func = root.f.get_concrete_function(input_data)
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_new_converter = enable_mlir
|
||||
converter.experimental_new_converter = enable_mlir_converter
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
@ -142,7 +142,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
self.assertIn('can only convert a single ConcreteFunction',
|
||||
str(error.exception))
|
||||
|
||||
def _getCalibrationQuantizeModel(self):
|
||||
def _getIntegerQuantizeModel(self):
|
||||
np.random.seed(0)
|
||||
|
||||
root = tracking.AutoTrackable()
|
||||
@ -167,23 +167,23 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
('EnableMlirQuantizer', True), # enable mlir quantizer
|
||||
('DisableMlirQuantizer', False)) # disable mlir quantizer
|
||||
def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
float_tflite = float_converter.convert()
|
||||
self.assertTrue(float_tflite)
|
||||
float_tflite_model = float_converter.convert()
|
||||
self.assertIsNotNone(float_tflite_model)
|
||||
|
||||
# Convert quantized model.
|
||||
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
|
||||
quantized_converter.representative_dataset = calibration_gen
|
||||
quantized_converter._experimental_new_quantizer = mlir_quantizer
|
||||
quantized_tflite = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite)
|
||||
quantized_tflite_model = quantized_converter.convert()
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
# The default input and output types should be float.
|
||||
interpreter = Interpreter(model_content=quantized_tflite)
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
@ -193,7 +193,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_INT8InputOutput', lite.constants.INT8),
|
||||
@ -202,7 +202,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
@test_util.run_v2_only
|
||||
def testInvalidPostTrainingDynamicRangeQuantization(
|
||||
self, inference_input_output_type):
|
||||
func, _ = self._getCalibrationQuantizeModel()
|
||||
func, _ = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
@ -228,7 +228,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
|
||||
def testPostTrainingIntegerAllowFloatQuantization(
|
||||
self, inference_input_output_type):
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
@ -242,7 +242,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
quantized_converter.inference_input_type = inference_input_output_type
|
||||
quantized_converter.inference_output_type = inference_input_output_type
|
||||
quantized_tflite_model = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite_model)
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
@ -259,7 +259,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
self.assertLess(len(quantized_tflite_model), len(tflite_model))
|
||||
|
||||
def testPostTrainingIntegerAllowFloatQuantizationINT16InputOutput(self):
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
@ -279,7 +279,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
quantized_converter.inference_input_type = inference_input_output_type
|
||||
quantized_converter.inference_output_type = inference_input_output_type
|
||||
quantized_tflite_model = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite_model)
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
@ -299,7 +299,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
# In this test we check that when we do 16x8 post-training
|
||||
# quantization and set inference_input(output)_type to
|
||||
# constants.INT8, we have an error.
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert quantized model.
|
||||
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
@ -330,7 +330,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
inference_input_output_type,
|
||||
use_target_ops_flag,
|
||||
quantization_16x8):
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
@ -357,7 +357,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
quantized_converter.inference_input_type = inference_input_output_type
|
||||
quantized_converter.inference_output_type = inference_input_output_type
|
||||
quantized_tflite_model = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite_model)
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
@ -374,12 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
self.assertLess(len(quantized_tflite_model), len(tflite_model))
|
||||
|
||||
def testCalibrateAndQuantizeBuiltinInt16(self):
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
# Convert float model.
|
||||
float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
float_tflite = float_converter.convert()
|
||||
self.assertTrue(float_tflite)
|
||||
float_tflite_model = float_converter.convert()
|
||||
self.assertIsNotNone(float_tflite_model)
|
||||
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
# TODO(b/156309549): We should add INT16 to the builtin types.
|
||||
@ -389,13 +389,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
converter.representative_dataset = calibration_gen
|
||||
converter._experimental_calibrate_only = True
|
||||
calibrated_tflite = converter.convert()
|
||||
quantized_tflite = mlir_quantize(calibrated_tflite,
|
||||
inference_type=_types_pb2.QUANTIZED_INT16)
|
||||
quantized_tflite_model = mlir_quantize(
|
||||
calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)
|
||||
|
||||
self.assertTrue(quantized_tflite)
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
# The default input and output types should be float.
|
||||
interpreter = Interpreter(model_content=quantized_tflite)
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
@ -405,7 +405,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
|
||||
def _getTrainingTimeQuantizedModel(self):
|
||||
|
||||
@ -454,17 +454,17 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
model = self._getTrainingTimeQuantizedModel()
|
||||
|
||||
float_converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||
float_tflite = float_converter.convert()
|
||||
self.assertTrue(float_tflite)
|
||||
float_tflite_model = float_converter.convert()
|
||||
self.assertIsNotNone(float_tflite_model)
|
||||
|
||||
quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
|
||||
quantized_converter.optimizations = [lite.Optimize.DEFAULT]
|
||||
quantized_converter.inference_input_type = inference_input_output_type
|
||||
quantized_converter.inference_output_type = inference_input_output_type
|
||||
quantized_tflite = quantized_converter.convert()
|
||||
self.assertTrue(quantized_tflite)
|
||||
quantized_tflite_model = quantized_converter.convert()
|
||||
self.assertIsNotNone(quantized_tflite_model)
|
||||
|
||||
interpreter = Interpreter(model_content=quantized_tflite)
|
||||
interpreter = Interpreter(model_content=quantized_tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
@ -476,12 +476,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
output_details[0]['dtype'])
|
||||
|
||||
# Ensure that the quantized tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite), len(float_tflite))
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testNewQuantizer(self):
|
||||
"""Test the model quantized by the new converter."""
|
||||
func, calibration_gen = self._getCalibrationQuantizeModel()
|
||||
func, calibration_gen = self._getIntegerQuantizeModel()
|
||||
|
||||
quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
|
||||
quantized_converter.target_spec.supported_ops = [
|
||||
@ -502,13 +502,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
|
||||
old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
|
||||
new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
|
||||
np.testing.assert_almost_equal(old_value, new_value, 1)
|
||||
self.assertAllClose(old_value, new_value, atol=1e-01)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('EnableMlirConverter', True), # enable mlir
|
||||
('DisableMlirConverter', False)) # disable mlir
|
||||
@test_util.run_v2_only
|
||||
def testEmbeddings(self, enable_mlir):
|
||||
def testEmbeddings(self, enable_mlir_converter):
|
||||
"""Test model with embeddings."""
|
||||
input_data = tf.constant(
|
||||
np.array(np.random.random_sample((20)), dtype=np.int32))
|
||||
@ -534,13 +534,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Convert model.
|
||||
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
|
||||
converter.experimental_new_converter = enable_mlir
|
||||
converter.experimental_new_converter = enable_mlir_converter
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.func(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0], 5)
|
||||
self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGraphDebugInfo(self):
|
||||
@ -594,7 +594,7 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
self.assertLen(input_details, 2)
|
||||
self.assertStartsWith(input_details[0]['name'], 'inputA')
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||
self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
|
||||
self.assertEqual((0., 0.), input_details[0]['quantization'])
|
||||
|
||||
self.assertStartsWith(
|
||||
@ -602,14 +602,14 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
'inputB',
|
||||
)
|
||||
self.assertEqual(np.float32, input_details[1]['dtype'])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
|
||||
self.assertTrue([1, 16, 16, 3], input_details[1]['shape'])
|
||||
self.assertEqual((0., 0.), input_details[1]['quantization'])
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertLen(output_details, 1)
|
||||
self.assertStartsWith(output_details[0]['name'], 'add')
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||
self.assertTrue([1, 16, 16, 3], output_details[0]['shape'])
|
||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||
|
||||
@test_util.run_v2_only
|
||||
@ -715,7 +715,6 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
@test_util.run_v2_only
|
||||
def testNoConcreteFunctionModel(self):
|
||||
root = self._getMultiFunctionModel()
|
||||
input_data = tf.constant(1., shape=[1])
|
||||
|
||||
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
|
||||
save(root, save_dir)
|
||||
@ -836,7 +835,7 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest):
|
||||
expected_value = model.predict(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
|
||||
for tf_result, tflite_result in zip(expected_value, actual_value):
|
||||
np.testing.assert_almost_equal(tf_result, tflite_result, 5)
|
||||
self.assertAllClose(tf_result, tflite_result, atol=1e-05)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGraphDebugInfo(self):
|
||||
@ -919,7 +918,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
expected_value = concrete_func(**input_data)
|
||||
actual_value = self._evaluateTFLiteModel(
|
||||
tflite_model, [input_data['x'], input_data['b']])[0]
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value)
|
||||
self.assertAllClose(expected_value, actual_value)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testStaticRnn(self):
|
||||
@ -945,7 +944,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
expected_value = concrete_func(input_data)[0]
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
for expected, actual in zip(expected_value, actual_value):
|
||||
np.testing.assert_almost_equal(expected.numpy(), actual)
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testWhileLoop(self):
|
||||
@ -973,7 +972,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(input_data)[0]
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value)
|
||||
self.assertAllClose(expected_value, actual_value)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testDynamicRnn(self):
|
||||
@ -997,11 +996,9 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
expected_value = concrete_func(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
for expected, actual in zip(expected_value, actual_value):
|
||||
if isinstance(expected, ops.EagerTensor):
|
||||
expected = expected.numpy()
|
||||
else:
|
||||
expected = expected.c.numpy()
|
||||
np.testing.assert_almost_equal(expected, actual)
|
||||
if not isinstance(expected, ops.EagerTensor):
|
||||
expected = expected.c
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
|
||||
('SimpleRNN', recurrent.SimpleRNN),
|
||||
@ -1025,7 +1022,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = model.predict(input_data)
|
||||
np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
|
||||
self.assertAllClose(expected_value, actual_value, atol=1e-05)
|
||||
|
||||
@parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
|
||||
('SimpleRNN', recurrent.SimpleRNN),
|
||||
@ -1046,7 +1043,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = model.predict(input_data)
|
||||
np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
|
||||
self.assertAllClose(expected_value, actual_value, atol=1e-05)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testKerasBidirectionalRNN(self):
|
||||
@ -1069,7 +1066,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = model.predict(input_data)
|
||||
np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
|
||||
self.assertAllClose(expected_value, actual_value, atol=1e-05)
|
||||
|
||||
|
||||
class GrapplerTest(lite_v2_test_util.ModelTest):
|
||||
@ -1096,14 +1093,14 @@ class GrapplerTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
# Check values from converted model.
|
||||
expected_value = root.f(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
|
||||
self.assertAllClose(expected_value, actual_value)
|
||||
|
||||
# Enable hybrid quantization, same result
|
||||
converter.optimizations = [lite.Optimize.DEFAULT]
|
||||
hybrid_tflite_model = converter.convert()
|
||||
actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
|
||||
np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
|
||||
tflite_model = converter.convert()
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
|
||||
self.assertAllClose(expected_value, actual_value)
|
||||
|
||||
|
||||
class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
@ -1128,15 +1125,16 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
# Check values from converted model.
|
||||
expected_value = concrete_func(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(
|
||||
tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
|
||||
np.testing.assert_almost_equal(
|
||||
expected_value.numpy(), actual_value[0], decimal=6)
|
||||
tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
|
||||
self.assertAllClose(expected_value, actual_value, atol=1e-06)
|
||||
|
||||
def _getIntegerQuantizeModelWithUnknownShapes(self):
|
||||
np.random.seed(0)
|
||||
|
||||
def _getQuantizedModel(self):
|
||||
# Returns a model with tf.MatMul and unknown dimensions.
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
|
||||
def model(in_tensor):
|
||||
def model(input_tensor):
|
||||
"""Define a model with tf.MatMul and unknown shapes."""
|
||||
# We need the tensor to have more than 1024 elements for quantize_weights
|
||||
# to kick in. Thus, the [33, 33] shape.
|
||||
const_tensor = tf.constant(
|
||||
@ -1145,12 +1143,14 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
dtype=tf.float32,
|
||||
name='inputB')
|
||||
|
||||
shape = tf.shape(in_tensor)
|
||||
shape = tf.shape(input_tensor)
|
||||
fill = tf.transpose(tf.fill(shape, 1.))
|
||||
mult = tf.matmul(fill, in_tensor)
|
||||
mult = tf.matmul(fill, input_tensor)
|
||||
return tf.matmul(mult, const_tensor)
|
||||
|
||||
concrete_func = model.get_concrete_function()
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = model
|
||||
concrete_func = root.f.get_concrete_function()
|
||||
|
||||
def calibration_gen():
|
||||
for batch in range(5, 20, 5):
|
||||
@ -1161,7 +1161,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMatMulQuantize(self):
|
||||
concrete_func, _ = self._getQuantizedModel()
|
||||
concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
|
||||
float_converter = lite.TFLiteConverterV2.from_concrete_functions(
|
||||
[concrete_func])
|
||||
float_tflite_model = float_converter.convert()
|
||||
@ -1177,14 +1177,15 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
input_details = quantized_interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
|
||||
self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])
|
||||
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMatMulCalibrateAndQuantize(self):
|
||||
concrete_func, calibration_gen = self._getQuantizedModel()
|
||||
concrete_func, calibration_gen = \
|
||||
self._getIntegerQuantizeModelWithUnknownShapes()
|
||||
float_converter = lite.TFLiteConverterV2.from_concrete_functions(
|
||||
[concrete_func])
|
||||
float_tflite_model = float_converter.convert()
|
||||
@ -1201,7 +1202,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
input_details = quantized_interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
self.assertEqual(np.float32, input_details[0]['dtype'])
|
||||
self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
|
||||
self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])
|
||||
|
||||
# Ensure that the quantized weights tflite model is smaller.
|
||||
self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
|
||||
@ -1228,9 +1229,8 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
|
||||
expected_value = concrete_func(input_data_1, input_data_2)
|
||||
actual_value = self._evaluateTFLiteModel(
|
||||
tflite_model, [input_data_1, input_data_2],
|
||||
input_shapes=[([-1, 256, 256], [1, 256, 256])])
|
||||
np.testing.assert_almost_equal(
|
||||
expected_value.numpy(), actual_value[0], decimal=4)
|
||||
input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
|
||||
self.assertAllClose(expected_value, actual_value, atol=4)
|
||||
|
||||
def testSizeInvalid(self):
|
||||
|
||||
|
@ -77,6 +77,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
def _getMultiFunctionModel(self):
|
||||
|
||||
class BasicModel(tracking.AutoTrackable):
|
||||
"""Basic model with multiple functions."""
|
||||
|
||||
def __init__(self):
|
||||
self.y = None
|
||||
|
@ -48,16 +48,16 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g
|
||||
_MAP_TF_TO_TFLITE_TYPES = {
|
||||
dtypes.float32: _types_pb2.FLOAT,
|
||||
dtypes.float16: _types_pb2.FLOAT16,
|
||||
dtypes.float64: _types_pb2.FLOAT64,
|
||||
dtypes.int32: _types_pb2.INT32,
|
||||
dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
|
||||
dtypes.int64: _types_pb2.INT64,
|
||||
dtypes.string: _types_pb2.STRING,
|
||||
dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
|
||||
dtypes.int8: _types_pb2.INT8,
|
||||
dtypes.bool: _types_pb2.BOOL,
|
||||
dtypes.int16: _types_pb2.QUANTIZED_INT16,
|
||||
dtypes.complex64: _types_pb2.COMPLEX64,
|
||||
dtypes.int8: _types_pb2.INT8,
|
||||
dtypes.float64: _types_pb2.FLOAT64,
|
||||
dtypes.complex128: _types_pb2.COMPLEX128,
|
||||
dtypes.bool: _types_pb2.BOOL,
|
||||
}
|
||||
|
||||
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
|
||||
@ -72,6 +72,7 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
|
||||
8: dtypes.complex64,
|
||||
9: dtypes.int8,
|
||||
10: dtypes.float64,
|
||||
11: dtypes.complex128,
|
||||
}
|
||||
|
||||
_TFLITE_FILE_IDENTIFIER = b"TFL3"
|
||||
@ -113,7 +114,7 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
|
||||
tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
|
||||
if tf_type is None:
|
||||
raise ValueError(
|
||||
"Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
|
||||
"Unsupported enum {}. The valid map of enum to tf types is : {}"
|
||||
.format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
|
||||
return tf_type
|
||||
|
||||
|
@ -42,27 +42,34 @@ from tensorflow.python.platform import test
|
||||
class UtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testConvertDtype(self):
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(lite_constants.FLOAT),
|
||||
_types_pb2.FLOAT)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.float32), _types_pb2.FLOAT)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.float16), _types_pb2.FLOAT16)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.int32), _types_pb2.INT32)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.uint8),
|
||||
_types_pb2.QUANTIZED_UINT8)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.int64), _types_pb2.INT64)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.string), _types_pb2.STRING)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.uint8),
|
||||
_types_pb2.QUANTIZED_UINT8)
|
||||
util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.int16),
|
||||
_types_pb2.QUANTIZED_INT16)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.complex64),
|
||||
_types_pb2.COMPLEX64)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.half), _types_pb2.FLOAT16)
|
||||
util.convert_dtype_to_tflite_type(dtypes.int8), _types_pb2.INT8)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
|
||||
util.convert_dtype_to_tflite_type(dtypes.float64), _types_pb2.FLOAT64)
|
||||
self.assertEqual(
|
||||
util.convert_dtype_to_tflite_type(dtypes.complex128),
|
||||
_types_pb2.COMPLEX128)
|
||||
|
||||
def testConvertEnumToDtype(self):
|
||||
self.assertEqual(
|
||||
@ -81,17 +88,19 @@ class UtilTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
|
||||
self.assertEqual(
|
||||
util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
util._convert_tflite_enum_type_to_tf_type(11)
|
||||
self.assertEqual(
|
||||
"Unsupported enum 11. The valid map of enum to tf.dtypes is : "
|
||||
util._convert_tflite_enum_type_to_tf_type(11), dtypes.complex128)
|
||||
with self.assertRaises(ValueError) as error:
|
||||
util._convert_tflite_enum_type_to_tf_type(20)
|
||||
self.assertEqual(
|
||||
"Unsupported enum 20. The valid map of enum to tf types is : "
|
||||
"{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
|
||||
"5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
|
||||
"10: tf.float64}", str(error.exception))
|
||||
"10: tf.float64, 11: tf.complex128}", str(error.exception))
|
||||
|
||||
def testTensorName(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
|
||||
in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[4])
|
||||
out_tensors = array_ops.split(
|
||||
value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
|
||||
|
||||
@ -103,7 +112,7 @@ class UtilTest(test_util.TensorFlowTestCase):
|
||||
@test_util.enable_control_flow_v2
|
||||
def testRemoveLowerUsingSwitchMerge(self):
|
||||
with ops.Graph().as_default():
|
||||
i = array_ops.placeholder(shape=(), dtype=dtypes.int32)
|
||||
i = array_ops.placeholder(dtype=dtypes.int32, shape=())
|
||||
c = lambda i: math_ops.less(i, 10)
|
||||
b = lambda i: math_ops.add(i, 1)
|
||||
control_flow_ops.while_loop(c, b, [i])
|
||||
@ -116,7 +125,7 @@ class UtilTest(test_util.TensorFlowTestCase):
|
||||
if node.op == "While" or node.op == "StatelessWhile":
|
||||
if not node.attr["_lower_using_switch_merge"].b:
|
||||
lower_using_switch_merge_is_removed = True
|
||||
self.assertEqual(lower_using_switch_merge_is_removed, True)
|
||||
self.assertTrue(lower_using_switch_merge_is_removed)
|
||||
|
||||
def testConvertBytes(self):
|
||||
source, header = util.convert_bytes_to_c_source(
|
||||
@ -154,7 +163,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
||||
def testGetTensorsValid(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
dtype=dtypes.float32, shape=[1, 16, 16, 3])
|
||||
_ = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
@ -164,7 +173,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
||||
def testGetTensorsInvalid(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
dtype=dtypes.float32, shape=[1, 16, 16, 3])
|
||||
_ = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
@ -175,52 +184,51 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testSetTensorShapeValid(self):
|
||||
with ops.Graph().as_default():
|
||||
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
util.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
|
||||
self.assertEqual([5, 3, 5], tensor.shape.as_list())
|
||||
self.assertAllEqual([5, 3, 5], tensor.shape)
|
||||
|
||||
def testSetTensorShapeNoneValid(self):
|
||||
with ops.Graph().as_default():
|
||||
tensor = array_ops.placeholder(dtype=dtypes.float32)
|
||||
self.assertEqual(None, tensor.shape)
|
||||
|
||||
util.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
|
||||
self.assertEqual([1, 3, 5], tensor.shape.as_list())
|
||||
self.assertAllEqual([1, 3, 5], tensor.shape)
|
||||
|
||||
def testSetTensorShapeArrayInvalid(self):
|
||||
# Tests set_tensor_shape where the tensor name passed in doesn't exist.
|
||||
with ops.Graph().as_default():
|
||||
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
util.set_tensor_shapes([tensor], {"invalid-input": [5, 3, 5]})
|
||||
self.assertEqual(
|
||||
"Invalid tensor 'invalid-input' found in tensor shapes map.",
|
||||
str(error.exception))
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
def testSetTensorShapeDimensionInvalid(self):
|
||||
# Tests set_tensor_shape where the shape passed in is incompatible.
|
||||
with ops.Graph().as_default():
|
||||
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
with self.assertRaises(ValueError) as error:
|
||||
util.set_tensor_shapes([tensor], {"Placeholder": [1, 5, 5]})
|
||||
self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
|
||||
str(error.exception))
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
def testSetTensorShapeEmpty(self):
|
||||
with ops.Graph().as_default():
|
||||
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
util.set_tensor_shapes([tensor], {})
|
||||
self.assertEqual([None, 3, 5], tensor.shape.as_list())
|
||||
self.assertAllEqual([None, 3, 5], tensor.shape)
|
||||
|
||||
|
||||
def _generate_integer_tflite_model():
|
||||
@ -355,7 +363,7 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
|
||||
output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)
|
||||
|
||||
# Validate that both the outputs are the same
|
||||
self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))
|
||||
self.assertAllClose(output_data, output_io_data, atol=1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user