Add an integration tests for inference_input_type=uint8 and inference_type=float

PiperOrigin-RevId: 293862594
Change-Id: Ic4435ac836e34fec4e5c65bc0fe485020b3ba7ff
This commit is contained in:
Feng Liu 2020-02-07 11:50:37 -08:00 committed by TensorFlower Gardener
parent c2d8449859
commit 75e2a0a14f

View File

@ -257,6 +257,43 @@ class FromSessionTest(TestModels, parameterized.TestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
def testQuantizedInput(self):
with ops.Graph().as_default():
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
[out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
converter.inference_type = lite_constants.FLOAT
converter.quantized_input_stats = {
'Placeholder': (0., 1.)
} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
self.assertLen(input_details, 1)
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
output_details = interpreter.get_output_details()
self.assertLen(output_details, 1)
self.assertEqual('add', output_details[0]['name'])
self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization']) # float
def testQuantizationInvalid(self):
with ops.Graph().as_default():
in_tensor_1 = array_ops.placeholder(