Refactor and Fix lint errors in util.py and lite*.py files

PiperOrigin-RevId: 324727472
Change-Id: I3766b0724564f91216bffcc8b55f70744fd94334
Meghna Natraj 2020-08-03 18:37:33 -07:00 committed by TensorFlower Gardener
parent e4592dad25
commit 1277f67514
6 changed files with 607 additions and 665 deletions


@@ -125,7 +125,7 @@ class Optimize(enum.Enum):
   OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

   def __str__(self):
-    return self.value
+    return str(self.value)


 @_tf_export("lite.RepresentativeDataset")
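A minimal standalone sketch (not part of this commit) of what the new cast buys: a type checker infers the return of __str__ from the member's value, so wrapping it in str() guarantees the declared str return type regardless of the value's type.

import enum

class Optimize(enum.Enum):
  """Mirrors the shape of tf.lite.Optimize; the values here are illustrative."""
  DEFAULT = "DEFAULT"
  OPTIMIZE_FOR_SIZE = "OPTIMIZE_FOR_SIZE"
  OPTIMIZE_FOR_LATENCY = "OPTIMIZE_FOR_LATENCY"

  def __str__(self):
    return str(self.value)  # always a str, even if a value were non-string

print(str(Optimize.DEFAULT))  # -> DEFAULT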
@@ -230,7 +230,7 @@ class QuantizationMode(object):

   def post_training_int16x8_allow_float(self):
     """Post training int16x8 quantize, allow float fallback."""
-    return (self._is_int16x8_target_required() and self._is_allow_float())
+    return self._is_int16x8_target_required() and self._is_allow_float()

   def post_training_dynamic_range_int8(self):
     """Post training int8 const, on-the-fly int8 quantize of dynamic tensors."""
@@ -907,7 +907,7 @@ class TFLiteFrozenGraphConverterV2(TFLiteConverterBaseV2):
     """
     # TODO(b/130297984): Add support for converting multiple function.

-    if len(self._funcs) == 0:
+    if len(self._funcs) == 0:  # pylint: disable=g-explicit-length-test
       raise ValueError("No ConcreteFunction is specified.")

     if len(self._funcs) > 1:
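For context, a small sketch (mine, not TensorFlow's) of the g-explicit-length-test rule the new comment suppresses: the Google style checker prefers truthiness tests over explicit length comparisons, so keeping the explicit form requires an inline disable.

funcs = []

if not funcs:  # the spelling the lint check prefers
  print("no functions")

if len(funcs) == 0:  # pylint: disable=g-explicit-length-test
  print("no functions")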
@@ -1127,7 +1127,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase):
       parameter is ignored. (default tf.float32)
     inference_input_type: Target data type of real-number input arrays. Allows
       for a different type for input arrays. If an integer type is provided and
-      `optimizations` are not used, `quantized_inputs_stats` must be provided.
+      `optimizations` are not used, `quantized_input_stats` must be provided.
       If `inference_type` is tf.uint8, signaling conversion to a fully quantized
       model from a quantization-aware trained input model, then
       `inference_input_type` defaults to tf.uint8. In all other cases,
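A hedged sketch of the attribute the corrected docstring names, using the TF1 converter API; the file and array names are illustrative only.

import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="model.pb",   # hypothetical frozen graph
    input_arrays=["input"],
    output_arrays=["output"])
converter.inference_type = tf.uint8
converter.inference_input_type = tf.uint8
# quantized_input_stats maps each input array name to (mean, std_dev), which
# is required here because no `optimizations` are set.
converter.quantized_input_stats = {"input": (127.5, 127.5)}
tflite_model = converter.convert()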
@@ -1681,7 +1681,7 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
     inference_input_type: Target data type of real-number input arrays. Allows
       for a different type for input arrays.
       If an integer type is provided and `optimizations` are not used,
-      `quantized_inputs_stats` must be provided.
+      `quantized_input_stats` must be provided.
       If `inference_type` is tf.uint8, signaling conversion to a fully quantized
       model from a quantization-aware trained input model, then
       `inference_input_type` defaults to tf.uint8.
@@ -2012,6 +2012,7 @@ class TFLiteConverter(TFLiteFrozenGraphConverter):
     """
     return super(TFLiteConverter, self).convert()

+
 @_tf_export(v1=["lite.TocoConverter"])
 class TocoConverter(object):
   """Convert a TensorFlow model into `output_format` using TOCO.

(File diff suppressed because it is too large.)


@@ -58,14 +58,14 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
       ('EnableMlirConverter', True),  # enable mlir
       ('DisableMlirConverter', False))  # disable mlir
   @test_util.run_v2_only
-  def testFloat(self, enable_mlir):
+  def testFloat(self, enable_mlir_converter):
     root = self._getSimpleVariableModel()
     input_data = tf.constant(1., shape=[1])
     concrete_func = root.f.get_concrete_function(input_data)

     # Convert model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
-    converter.experimental_new_converter = enable_mlir
+    converter.experimental_new_converter = enable_mlir_converter
     tflite_model = converter.convert()

     # Check values from converted model.
@@ -142,7 +142,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertIn('can only convert a single ConcreteFunction',
                   str(error.exception))

-  def _getCalibrationQuantizeModel(self):
+  def _getIntegerQuantizeModel(self):
     np.random.seed(0)
     root = tracking.AutoTrackable()
@@ -167,23 +167,23 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
       ('EnableMlirQuantizer', True),  # enable mlir quantizer
       ('DisableMlirQuantizer', False))  # disable mlir quantizer
   def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert float model.
     float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    float_tflite = float_converter.convert()
-    self.assertTrue(float_tflite)
+    float_tflite_model = float_converter.convert()
+    self.assertIsNotNone(float_tflite_model)

     # Convert quantized model.
     quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_converter.representative_dataset = calibration_gen
     quantized_converter._experimental_new_quantizer = mlir_quantizer
-    quantized_tflite = quantized_converter.convert()
-    self.assertTrue(quantized_tflite)
+    quantized_tflite_model = quantized_converter.convert()
+    self.assertIsNotNone(quantized_tflite_model)

     # The default input and output types should be float.
-    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 1)
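A condensed sketch of the flow this test exercises, written against the public tf.lite API rather than the internal test helpers; the toy conv model stands in for _getIntegerQuantizeModel().

import numpy as np
import tensorflow as tf

@tf.function(
    input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
def model(x):
  conv_filter = tf.ones([3, 3, 3, 16], dtype=tf.float32)
  return tf.nn.conv2d(x, conv_filter, strides=[1, 1, 1, 1], padding='SAME')

def representative_dataset():
  for _ in range(5):
    yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.get_concrete_function()])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
quantized_tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
print(interpreter.get_input_details()[0]['dtype'])  # float32 by default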
@@ -193,7 +193,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertEqual(np.float32, output_details[0]['dtype'])

     # Ensure that the quantized weights tflite model is smaller.
-    self.assertLess(len(quantized_tflite), len(float_tflite))
+    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

   @parameterized.named_parameters(
       ('_INT8InputOutput', lite.constants.INT8),
@@ -202,7 +202,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
   @test_util.run_v2_only
   def testInvalidPostTrainingDynamicRangeQuantization(
       self, inference_input_output_type):
-    func, _ = self._getCalibrationQuantizeModel()
+    func, _ = self._getIntegerQuantizeModel()

     # Convert float model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -228,7 +228,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
       ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8))
   def testPostTrainingIntegerAllowFloatQuantization(
       self, inference_input_output_type):
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert float model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -242,7 +242,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
     quantized_tflite_model = quantized_converter.convert()
-    self.assertTrue(quantized_tflite_model)
+    self.assertIsNotNone(quantized_tflite_model)

     interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
@@ -259,7 +259,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertLess(len(quantized_tflite_model), len(tflite_model))

   def testPostTrainingIntegerAllowFloatQuantizationINT16InputOutput(self):
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert float model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -279,7 +279,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
     quantized_tflite_model = quantized_converter.convert()
-    self.assertTrue(quantized_tflite_model)
+    self.assertIsNotNone(quantized_tflite_model)

     interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
@@ -299,7 +299,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     # In this test we check that when we do 16x8 post-training
     # quantization and set inference_input(output)_type to
     # constants.INT8, we have an error.
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert quantized model.
     quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -330,7 +330,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
       inference_input_output_type,
       use_target_ops_flag,
       quantization_16x8):
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert float model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
@@ -357,7 +357,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
     quantized_tflite_model = quantized_converter.convert()
-    self.assertTrue(quantized_tflite_model)
+    self.assertIsNotNone(quantized_tflite_model)

     interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
@@ -374,12 +374,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertLess(len(quantized_tflite_model), len(tflite_model))

   def testCalibrateAndQuantizeBuiltinInt16(self):
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     # Convert float model.
     float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
-    float_tflite = float_converter.convert()
-    self.assertTrue(float_tflite)
+    float_tflite_model = float_converter.convert()
+    self.assertIsNotNone(float_tflite_model)

     converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     # TODO(b/156309549): We should add INT16 to the builtin types.
@@ -389,13 +389,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     converter.representative_dataset = calibration_gen
     converter._experimental_calibrate_only = True
     calibrated_tflite = converter.convert()
-    quantized_tflite = mlir_quantize(calibrated_tflite,
-                                     inference_type=_types_pb2.QUANTIZED_INT16)
+    quantized_tflite_model = mlir_quantize(
+        calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)

-    self.assertTrue(quantized_tflite)
+    self.assertIsNotNone(quantized_tflite_model)

     # The default input and output types should be float.
-    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 1)
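The two-step flow above, as a hedged sketch: convert with calibration only, then run the MLIR quantizer separately with an int16 inference type. Both _experimental_calibrate_only and mlir_quantize are internal interfaces, so this mirrors the test rather than a stable public API; `converter` continues from a setup like the earlier sketch.

from tensorflow.lite.python.convert import mlir_quantize
from tensorflow.lite.toco import types_pb2 as _types_pb2

converter._experimental_calibrate_only = True  # internal flag, as in the test
calibrated_tflite = converter.convert()        # calibrated, not yet quantized
quantized_tflite_model = mlir_quantize(
    calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)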
@@ -405,7 +405,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     self.assertEqual(np.float32, output_details[0]['dtype'])

     # Ensure that the quantized weights tflite model is smaller.
-    self.assertLess(len(quantized_tflite), len(float_tflite))
+    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

   def _getTrainingTimeQuantizedModel(self):
@@ -454,17 +454,17 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
     model = self._getTrainingTimeQuantizedModel()

     float_converter = lite.TFLiteConverterV2.from_keras_model(model)
-    float_tflite = float_converter.convert()
-    self.assertTrue(float_tflite)
+    float_tflite_model = float_converter.convert()
+    self.assertIsNotNone(float_tflite_model)

     quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
     quantized_converter.optimizations = [lite.Optimize.DEFAULT]
     quantized_converter.inference_input_type = inference_input_output_type
     quantized_converter.inference_output_type = inference_input_output_type
-    quantized_tflite = quantized_converter.convert()
-    self.assertTrue(quantized_tflite)
+    quantized_tflite_model = quantized_converter.convert()
+    self.assertIsNotNone(quantized_tflite_model)

-    interpreter = Interpreter(model_content=quantized_tflite)
+    interpreter = Interpreter(model_content=quantized_tflite_model)
     interpreter.allocate_tensors()
     input_details = interpreter.get_input_details()
     self.assertLen(input_details, 1)
@@ -476,12 +476,12 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
                      output_details[0]['dtype'])

     # Ensure that the quantized tflite model is smaller.
-    self.assertLess(len(quantized_tflite), len(float_tflite))
+    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

   @test_util.run_v2_only
   def testNewQuantizer(self):
     """Test the model quantized by the new converter."""
-    func, calibration_gen = self._getCalibrationQuantizeModel()
+    func, calibration_gen = self._getIntegerQuantizeModel()

     quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
     quantized_converter.target_spec.supported_ops = [
@@ -502,13 +502,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
         np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
     old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
     new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
-    np.testing.assert_almost_equal(old_value, new_value, 1)
+    self.assertAllClose(old_value, new_value, atol=1e-01)

   @parameterized.named_parameters(
       ('EnableMlirConverter', True),  # enable mlir
       ('DisableMlirConverter', False))  # disable mlir
   @test_util.run_v2_only
-  def testEmbeddings(self, enable_mlir):
+  def testEmbeddings(self, enable_mlir_converter):
     """Test model with embeddings."""
     input_data = tf.constant(
         np.array(np.random.random_sample((20)), dtype=np.int32))
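On the assertion swaps in these hunks, a self-contained note: np.testing.assert_almost_equal(a, b, decimal=d) checks abs(a - b) against 1.5 * 10**-d, so the closest assertAllClose analogue uses atol=10**-d, which is the pattern the diff follows (decimal=5 becomes atol=1e-05, and so on).

import numpy as np

a = np.array([1.0000, 2.0000])
b = np.array([1.0004, 2.0004])

np.testing.assert_almost_equal(a, b, decimal=3)       # passes: |diff| < 1.5e-03
np.testing.assert_allclose(a, b, rtol=0, atol=1e-03)  # the atol analogue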
@@ -534,13 +534,13 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):

     # Convert model.
     converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
-    converter.experimental_new_converter = enable_mlir
+    converter.experimental_new_converter = enable_mlir_converter
     tflite_model = converter.convert()

     # Check values from converted model.
     expected_value = root.func(input_data)
     actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
-    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0], 5)
+    self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)

   @test_util.run_v2_only
   def testGraphDebugInfo(self):
@@ -594,7 +594,7 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
     self.assertLen(input_details, 2)
     self.assertStartsWith(input_details[0]['name'], 'inputA')
     self.assertEqual(np.float32, input_details[0]['dtype'])
-    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
     self.assertEqual((0., 0.), input_details[0]['quantization'])

     self.assertStartsWith(
@@ -602,14 +602,14 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
         'inputB',
     )
     self.assertEqual(np.float32, input_details[1]['dtype'])
-    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
     self.assertEqual((0., 0.), input_details[1]['quantization'])

     output_details = interpreter.get_output_details()
     self.assertLen(output_details, 1)
     self.assertStartsWith(output_details[0]['name'], 'add')
     self.assertEqual(np.float32, output_details[0]['dtype'])
-    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
     self.assertEqual((0., 0.), output_details[0]['quantization'])

   @test_util.run_v2_only
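Using assertAllEqual here also sidesteps a classic pitfall, sketched below (my example, not from the commit): unittest's assertTrue takes one value plus an optional message, so a two-argument call like assertTrue([1, 16, 16, 3], shape) always passes because the list is truthy and the second argument is silently treated as the failure message, not a comparand.

import unittest

class ShapeAssertDemo(unittest.TestCase):

  def test_pitfall(self):
    shape = [2, 2]  # deliberately wrong shape
    self.assertTrue([1, 16, 16, 3], shape)  # passes: no comparison happens
    with self.assertRaises(AssertionError):
      self.assertEqual([1, 16, 16, 3], shape)  # a real comparison fails

if __name__ == "__main__":
  unittest.main()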
@@ -715,7 +715,6 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
   @test_util.run_v2_only
   def testNoConcreteFunctionModel(self):
     root = self._getMultiFunctionModel()
-    input_data = tf.constant(1., shape=[1])

     save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
     save(root, save_dir)
@@ -836,7 +835,7 @@ class FromKerasModelTest(lite_v2_test_util.ModelTest):
     expected_value = model.predict(input_data)
     actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
     for tf_result, tflite_result in zip(expected_value, actual_value):
-      np.testing.assert_almost_equal(tf_result, tflite_result, 5)
+      self.assertAllClose(tf_result, tflite_result, atol=1e-05)

   @test_util.run_v2_only
   def testGraphDebugInfo(self):
@@ -919,7 +918,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     expected_value = concrete_func(**input_data)
     actual_value = self._evaluateTFLiteModel(
         tflite_model, [input_data['x'], input_data['b']])[0]
-    np.testing.assert_almost_equal(expected_value.numpy(), actual_value)
+    self.assertAllClose(expected_value, actual_value)

   @test_util.run_v2_only
   def testStaticRnn(self):
@@ -945,7 +944,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     expected_value = concrete_func(input_data)[0]
     actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
     for expected, actual in zip(expected_value, actual_value):
-      np.testing.assert_almost_equal(expected.numpy(), actual)
+      self.assertAllClose(expected, actual)

   @test_util.run_v2_only
   def testWhileLoop(self):
@@ -973,7 +972,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = concrete_func(input_data)[0]
     actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
-    np.testing.assert_almost_equal(expected_value.numpy(), actual_value)
+    self.assertAllClose(expected_value, actual_value)

   @test_util.run_v2_only
   def testDynamicRnn(self):
@@ -997,11 +996,9 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     expected_value = concrete_func(input_data)
     actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
     for expected, actual in zip(expected_value, actual_value):
-      if isinstance(expected, ops.EagerTensor):
-        expected = expected.numpy()
-      else:
-        expected = expected.c.numpy()
-      np.testing.assert_almost_equal(expected, actual)
+      if not isinstance(expected, ops.EagerTensor):
+        expected = expected.c
+      self.assertAllClose(expected, actual)

   @parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
                                   ('SimpleRNN', recurrent.SimpleRNN),
@@ -1025,7 +1022,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = model.predict(input_data)
-    np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
+    self.assertAllClose(expected_value, actual_value, atol=1e-05)

   @parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
                                   ('SimpleRNN', recurrent.SimpleRNN),
@@ -1046,7 +1043,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = model.predict(input_data)
-    np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
+    self.assertAllClose(expected_value, actual_value, atol=1e-05)

   @test_util.run_v2_only
   def testKerasBidirectionalRNN(self):
@@ -1069,7 +1066,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = model.predict(input_data)
-    np.testing.assert_almost_equal(expected_value, actual_value, decimal=5)
+    self.assertAllClose(expected_value, actual_value, atol=1e-05)


 class GrapplerTest(lite_v2_test_util.ModelTest):
@@ -1096,14 +1093,14 @@ class GrapplerTest(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = root.f(input_data)
-    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
-    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value, actual_value)

     # Enable hybrid quantization, same result
     converter.optimizations = [lite.Optimize.DEFAULT]
-    hybrid_tflite_model = converter.convert()
-    actual_value = self._evaluateTFLiteModel(hybrid_tflite_model, [input_data])
-    np.testing.assert_almost_equal(expected_value.numpy(), actual_value[0])
+    tflite_model = converter.convert()
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value, actual_value)


 class UnknownShapes(lite_v2_test_util.ModelTest):
@@ -1128,15 +1125,16 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     # Check values from converted model.
     expected_value = concrete_func(input_data)
     actual_value = self._evaluateTFLiteModel(
-        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])
-    np.testing.assert_almost_equal(
-        expected_value.numpy(), actual_value[0], decimal=6)
+        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
+    self.assertAllClose(expected_value, actual_value, atol=1e-06)

-  def _getQuantizedModel(self):
-    # Returns a model with tf.MatMul and unknown dimensions.
+  def _getIntegerQuantizeModelWithUnknownShapes(self):
+    np.random.seed(0)

     @tf.function(
         input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
-    def model(in_tensor):
+    def model(input_tensor):
+      """Define a model with tf.MatMul and unknown shapes."""
       # We need the tensor to have more than 1024 elements for quantize_weights
       # to kick in. Thus, the [33, 33] shape.
       const_tensor = tf.constant(
@@ -1145,12 +1143,14 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
           dtype=tf.float32,
           name='inputB')

-      shape = tf.shape(in_tensor)
+      shape = tf.shape(input_tensor)
       fill = tf.transpose(tf.fill(shape, 1.))
-      mult = tf.matmul(fill, in_tensor)
+      mult = tf.matmul(fill, input_tensor)
       return tf.matmul(mult, const_tensor)

-    concrete_func = model.get_concrete_function()
+    root = tracking.AutoTrackable()
+    root.f = model
+    concrete_func = root.f.get_concrete_function()

     def calibration_gen():
       for batch in range(5, 20, 5):
@@ -1161,7 +1161,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):

   @test_util.run_v2_only
   def testMatMulQuantize(self):
-    concrete_func, _ = self._getQuantizedModel()
+    concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
     float_converter = lite.TFLiteConverterV2.from_concrete_functions(
         [concrete_func])
     float_tflite_model = float_converter.convert()
@@ -1177,14 +1177,15 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     input_details = quantized_interpreter.get_input_details()
     self.assertLen(input_details, 1)
     self.assertEqual(np.float32, input_details[0]['dtype'])
-    self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
+    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

   @test_util.run_v2_only
   def testMatMulCalibrateAndQuantize(self):
-    concrete_func, calibration_gen = self._getQuantizedModel()
+    concrete_func, calibration_gen = \
+        self._getIntegerQuantizeModelWithUnknownShapes()
     float_converter = lite.TFLiteConverterV2.from_concrete_functions(
         [concrete_func])
     float_tflite_model = float_converter.convert()
@@ -1201,7 +1202,7 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     input_details = quantized_interpreter.get_input_details()
     self.assertLen(input_details, 1)
     self.assertEqual(np.float32, input_details[0]['dtype'])
-    self.assertTrue((input_details[0]['shape_signature'] == [-1, 33]).all())
+    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])

     # Ensure that the quantized weights tflite model is smaller.
     self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
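A hedged sketch of the 'shape_signature' field these assertions read; it assumes a model converted from an input with a dynamic leading dimension, as in these tests. The interpreter reports the concrete default under 'shape' and the original signature, with -1 marking unknown dims, under 'shape_signature'.

import tensorflow as tf

# quantized_tflite_model stands in for the converted model bytes above.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
details = interpreter.get_input_details()[0]
print(details['shape'])            # e.g. [ 1 33], the concrete default
print(details['shape_signature'])  # e.g. [-1 33], -1 marks the dynamic dim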
@@ -1228,9 +1229,8 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
     expected_value = concrete_func(input_data_1, input_data_2)
     actual_value = self._evaluateTFLiteModel(
         tflite_model, [input_data_1, input_data_2],
-        input_shapes=[([-1, 256, 256], [1, 256, 256])])
-    np.testing.assert_almost_equal(
-        expected_value.numpy(), actual_value[0], decimal=4)
+        input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
+    self.assertAllClose(expected_value, actual_value, atol=1e-04)

   def testSizeInvalid(self):


@@ -77,6 +77,7 @@ class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
   def _getMultiFunctionModel(self):

     class BasicModel(tracking.AutoTrackable):
+      """Basic model with multiple functions."""

       def __init__(self):
         self.y = None


@@ -48,16 +48,16 @@ from tensorflow.python.training.saver import export_meta_graph as _export_meta_g
 _MAP_TF_TO_TFLITE_TYPES = {
     dtypes.float32: _types_pb2.FLOAT,
     dtypes.float16: _types_pb2.FLOAT16,
+    dtypes.float64: _types_pb2.FLOAT64,
     dtypes.int32: _types_pb2.INT32,
+    dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
     dtypes.int64: _types_pb2.INT64,
     dtypes.string: _types_pb2.STRING,
-    dtypes.uint8: _types_pb2.QUANTIZED_UINT8,
+    dtypes.int8: _types_pb2.INT8,
+    dtypes.bool: _types_pb2.BOOL,
     dtypes.int16: _types_pb2.QUANTIZED_INT16,
     dtypes.complex64: _types_pb2.COMPLEX64,
-    dtypes.int8: _types_pb2.INT8,
-    dtypes.float64: _types_pb2.FLOAT64,
-    dtypes.bool: _types_pb2.BOOL,
+    dtypes.complex128: _types_pb2.COMPLEX128,
 }
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
@ -72,6 +72,7 @@ _MAP_TFLITE_ENUM_TO_TF_TYPES = {
8: dtypes.complex64,
9: dtypes.int8,
10: dtypes.float64,
11: dtypes.complex128,
}
_TFLITE_FILE_IDENTIFIER = b"TFL3"
@@ -113,7 +114,7 @@ def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
   tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
   if tf_type is None:
     raise ValueError(
-        "Unsupported enum {}. The valid map of enum to tf.dtypes is : {}"
+        "Unsupported enum {}. The valid map of enum to tf types is : {}"
         .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
   return tf_type
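A self-contained sketch (mine, not TensorFlow's exact code) of the lookup pattern this function implements: a dict .get() followed by a ValueError that prints the entire valid map, which is what the updated test further below asserts against.

ENUM_TO_NAME = {0: "float32", 1: "float16", 9: "int8", 11: "complex128"}

def convert_enum(enum_value):
  name = ENUM_TO_NAME.get(enum_value)
  if name is None:
    raise ValueError(
        "Unsupported enum {}. The valid map of enum to tf types is : {}"
        .format(enum_value, ENUM_TO_NAME))
  return name

print(convert_enum(9))  # -> int8
try:
  convert_enum(20)
except ValueError as e:
  print(e)  # lists the whole map, mirroring the error message above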


@@ -42,27 +42,34 @@ from tensorflow.python.platform import test
 class UtilTest(test_util.TensorFlowTestCase):

   def testConvertDtype(self):
     self.assertEqual(
         util.convert_dtype_to_tflite_type(lite_constants.FLOAT),
         _types_pb2.FLOAT)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.float32), _types_pb2.FLOAT)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.float16), _types_pb2.FLOAT16)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.int32), _types_pb2.INT32)
+    self.assertEqual(
+        util.convert_dtype_to_tflite_type(dtypes.uint8),
+        _types_pb2.QUANTIZED_UINT8)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.int64), _types_pb2.INT64)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.string), _types_pb2.STRING)
     self.assertEqual(
-        util.convert_dtype_to_tflite_type(dtypes.uint8),
-        _types_pb2.QUANTIZED_UINT8)
+        util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.int16),
         _types_pb2.QUANTIZED_INT16)
     self.assertEqual(
         util.convert_dtype_to_tflite_type(dtypes.complex64),
         _types_pb2.COMPLEX64)
     self.assertEqual(
-        util.convert_dtype_to_tflite_type(dtypes.half), _types_pb2.FLOAT16)
+        util.convert_dtype_to_tflite_type(dtypes.int8), _types_pb2.INT8)
     self.assertEqual(
-        util.convert_dtype_to_tflite_type(dtypes.bool), _types_pb2.BOOL)
+        util.convert_dtype_to_tflite_type(dtypes.float64), _types_pb2.FLOAT64)
+    self.assertEqual(
+        util.convert_dtype_to_tflite_type(dtypes.complex128),
+        _types_pb2.COMPLEX128)

   def testConvertEnumToDtype(self):
     self.assertEqual(
@@ -81,17 +88,19 @@ class UtilTest(test_util.TensorFlowTestCase):
     self.assertEqual(util._convert_tflite_enum_type_to_tf_type(9), dtypes.int8)
     self.assertEqual(
         util._convert_tflite_enum_type_to_tf_type(10), dtypes.float64)
+    self.assertEqual(
+        util._convert_tflite_enum_type_to_tf_type(11), dtypes.complex128)
     with self.assertRaises(ValueError) as error:
-      util._convert_tflite_enum_type_to_tf_type(11)
+      util._convert_tflite_enum_type_to_tf_type(20)
     self.assertEqual(
-        "Unsupported enum 11. The valid map of enum to tf.dtypes is : "
+        "Unsupported enum 20. The valid map of enum to tf types is : "
         "{0: tf.float32, 1: tf.float16, 2: tf.int32, 3: tf.uint8, 4: tf.int64, "
         "5: tf.string, 6: tf.bool, 7: tf.int16, 8: tf.complex64, 9: tf.int8, "
-        "10: tf.float64}", str(error.exception))
+        "10: tf.float64, 11: tf.complex128}", str(error.exception))

   def testTensorName(self):
     with ops.Graph().as_default():
-      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
+      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[4])
       out_tensors = array_ops.split(
           value=in_tensor, num_or_size_splits=[1, 1, 1, 1], axis=0)
@@ -103,7 +112,7 @@ class UtilTest(test_util.TensorFlowTestCase):

   @test_util.enable_control_flow_v2
   def testRemoveLowerUsingSwitchMerge(self):
     with ops.Graph().as_default():
-      i = array_ops.placeholder(shape=(), dtype=dtypes.int32)
+      i = array_ops.placeholder(dtype=dtypes.int32, shape=())
       c = lambda i: math_ops.less(i, 10)
       b = lambda i: math_ops.add(i, 1)
       control_flow_ops.while_loop(c, b, [i])
@@ -116,7 +125,7 @@ class UtilTest(test_util.TensorFlowTestCase):
       if node.op == "While" or node.op == "StatelessWhile":
         if not node.attr["_lower_using_switch_merge"].b:
           lower_using_switch_merge_is_removed = True
-    self.assertEqual(lower_using_switch_merge_is_removed, True)
+    self.assertTrue(lower_using_switch_merge_is_removed)

   def testConvertBytes(self):
     source, header = util.convert_bytes_to_c_source(
@@ -154,7 +163,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
   def testGetTensorsValid(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+          dtype=dtypes.float32, shape=[1, 16, 16, 3])
       _ = in_tensor + in_tensor
       sess = session.Session()
@@ -164,7 +173,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
   def testGetTensorsInvalid(self):
     with ops.Graph().as_default():
       in_tensor = array_ops.placeholder(
-          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+          dtype=dtypes.float32, shape=[1, 16, 16, 3])
       _ = in_tensor + in_tensor
       sess = session.Session()
@@ -175,52 +184,51 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):

   def testSetTensorShapeValid(self):
     with ops.Graph().as_default():
-      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+      tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
+    self.assertAllEqual([None, 3, 5], tensor.shape)

     util.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
-    self.assertEqual([5, 3, 5], tensor.shape.as_list())
+    self.assertAllEqual([5, 3, 5], tensor.shape)

   def testSetTensorShapeNoneValid(self):
     with ops.Graph().as_default():
       tensor = array_ops.placeholder(dtype=dtypes.float32)
     self.assertEqual(None, tensor.shape)

     util.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
-    self.assertEqual([1, 3, 5], tensor.shape.as_list())
+    self.assertAllEqual([1, 3, 5], tensor.shape)

   def testSetTensorShapeArrayInvalid(self):
     # Tests set_tensor_shape where the tensor name passed in doesn't exist.
     with ops.Graph().as_default():
-      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+      tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
+    self.assertAllEqual([None, 3, 5], tensor.shape)

     with self.assertRaises(ValueError) as error:
       util.set_tensor_shapes([tensor], {"invalid-input": [5, 3, 5]})
     self.assertEqual(
         "Invalid tensor 'invalid-input' found in tensor shapes map.",
         str(error.exception))
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+    self.assertAllEqual([None, 3, 5], tensor.shape)

   def testSetTensorShapeDimensionInvalid(self):
     # Tests set_tensor_shape where the shape passed in is incompatible.
     with ops.Graph().as_default():
-      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+      tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
+    self.assertAllEqual([None, 3, 5], tensor.shape)

     with self.assertRaises(ValueError) as error:
       util.set_tensor_shapes([tensor], {"Placeholder": [1, 5, 5]})
     self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
                   str(error.exception))
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+    self.assertAllEqual([None, 3, 5], tensor.shape)

   def testSetTensorShapeEmpty(self):
     with ops.Graph().as_default():
-      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+      tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
+    self.assertAllEqual([None, 3, 5], tensor.shape)

     util.set_tensor_shapes([tensor], {})
-    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+    self.assertAllEqual([None, 3, 5], tensor.shape)


 def _generate_integer_tflite_model():
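What the shape tests above verify, as a hedged sketch using public TF1-style APIs only (util.set_tensor_shapes is internal; Tensor.set_shape is the underlying mechanism):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  tensor = tf.placeholder(dtype=tf.float32, shape=[None, 3, 5])
  print(tensor.shape.as_list())  # [None, 3, 5]
  tensor.set_shape([5, 3, 5])    # compatible refinement succeeds
  print(tensor.shape.as_list())  # [5, 3, 5]
  # An incompatible shape such as [1, 5, 5] raises ValueError, which is
  # what testSetTensorShapeDimensionInvalid asserts.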
@@ -355,7 +363,7 @@ class UtilModifyIntegerQuantizedModelIOTypeTest(
     output_io_data = _run_tflite_inference(model_io, in_tftype, out_tftype)

     # Validate that both the outputs are the same
-    self.assertTrue(np.allclose(output_data, output_io_data, atol=1.0))
+    self.assertAllClose(output_data, output_io_data, atol=1.0)


 if __name__ == "__main__":