From 75e2a0a14f4a127892477c29044de1f3a6c4b242 Mon Sep 17 00:00:00 2001
From: Feng Liu
Date: Fri, 7 Feb 2020 11:50:37 -0800
Subject: [PATCH] Add an integration test for inference_input_type=uint8 and
 inference_type=float

PiperOrigin-RevId: 293862594
Change-Id: Ic4435ac836e34fec4e5c65bc0fe485020b3ba7ff
---
 tensorflow/lite/python/lite_test.py | 37 +++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 8c1f10af530..7977b30e7ae 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -257,6 +257,43 @@ class FromSessionTest(TestModels, parameterized.TestCase):
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
 
+  def testQuantizedInput(self):
+    with ops.Graph().as_default():
+      in_tensor = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32)
+      out_tensor = in_tensor + in_tensor
+      sess = session.Session()
+
+    # Convert model and ensure model is not None.
+    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+                                                  [out_tensor])
+    converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+    converter.inference_type = lite_constants.FLOAT
+    converter.quantized_input_stats = {
+        'Placeholder': (0., 1.)
+    }  # mean, std_dev
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertLen(input_details, 1)
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.uint8, input_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertEqual((1., 0.),
+                     input_details[0]['quantization'])  # scale, zero_point
+
+    output_details = interpreter.get_output_details()
+    self.assertLen(output_details, 1)
+    self.assertEqual('add', output_details[0]['name'])
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float
+
   def testQuantizationInvalid(self):
     with ops.Graph().as_default():
       in_tensor_1 = array_ops.placeholder(
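
Usage note (not part of the patch): with quantized_input_stats = {'Placeholder': (mean=0., std_dev=1.)}, the converted model's uint8 input tensor reports quantization parameters scale = 1/std_dev and zero_point = mean, which is why the test asserts (1., 0.). A caller therefore quantizes real-valued inputs as quantized = real / scale + zero_point before invoking. The sketch below illustrates that against a model assumed to have been converted exactly as in the test above; the file name 'quantized_input_add.tflite', the random input data, and the variable names are illustrative assumptions, not from the patch.

    # A minimal sketch, assuming a flatbuffer converted with
    # inference_input_type=QUANTIZED_UINT8, inference_type=FLOAT and
    # quantized_input_stats={'Placeholder': (0., 1.)} was written to disk.
    # The file name 'quantized_input_add.tflite' is hypothetical.
    import numpy as np
    from tensorflow.lite.python.interpreter import Interpreter

    interpreter = Interpreter(model_path='quantized_input_add.tflite')
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    scale, zero_point = input_details['quantization']  # (1., 0.) for mean=0, std_dev=1

    # Map real-valued inputs into the uint8 input tensor:
    # quantized = real / scale + zero_point  (equivalently real * std_dev + mean).
    real_input = np.random.uniform(0., 255., size=(1, 16, 16, 3)).astype(np.float32)
    quantized_input = np.clip(real_input / scale + zero_point, 0, 255).astype(np.uint8)

    interpreter.set_tensor(input_details['index'], quantized_input)
    interpreter.invoke()

    # The output stays float32 because inference_type is FLOAT; only the input
    # edge carries uint8 data.
    output_details = interpreter.get_output_details()[0]
    result = interpreter.get_tensor(output_details['index'])
    print(result.dtype, result.shape)  # float32 (1, 16, 16, 3)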