Add an integration tests for inference_input_type=uint8 and inference_type=float
PiperOrigin-RevId: 293862594 Change-Id: Ic4435ac836e34fec4e5c65bc0fe485020b3ba7ff
This commit is contained in:
parent
c2d8449859
commit
75e2a0a14f
@ -257,6 +257,43 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
|
||||
|
||||
def testQuantizedInput(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor = array_ops.placeholder(
|
||||
shape=[1, 16, 16, 3], dtype=dtypes.float32)
|
||||
out_tensor = in_tensor + in_tensor
|
||||
sess = session.Session()
|
||||
|
||||
# Convert model and ensure model is not None.
|
||||
converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
|
||||
[out_tensor])
|
||||
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
|
||||
converter.inference_type = lite_constants.FLOAT
|
||||
converter.quantized_input_stats = {
|
||||
'Placeholder': (0., 1.)
|
||||
} # mean, std_dev
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
self.assertLen(input_details, 1)
|
||||
self.assertEqual('Placeholder', input_details[0]['name'])
|
||||
self.assertEqual(np.uint8, input_details[0]['dtype'])
|
||||
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
|
||||
self.assertEqual((1., 0.),
|
||||
input_details[0]['quantization']) # scale, zero_point
|
||||
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertLen(output_details, 1)
|
||||
self.assertEqual('add', output_details[0]['name'])
|
||||
self.assertEqual(np.float32, output_details[0]['dtype'])
|
||||
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), output_details[0]['quantization']) # float
|
||||
|
||||
def testQuantizationInvalid(self):
|
||||
with ops.Graph().as_default():
|
||||
in_tensor_1 = array_ops.placeholder(
|
||||
|
Loading…
Reference in New Issue
Block a user