diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py
index d0dd7313df3..c7504a3a638 100644
--- a/tensorflow/lite/python/tflite_convert.py
+++ b/tensorflow/lite/python/tflite_convert.py
@@ -65,6 +65,8 @@ def _parse_inference_type(value, flag):
     return lite_constants.FLOAT
   if value == "QUANTIZED_UINT8":
     return lite_constants.QUANTIZED_UINT8
-  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
-                   "QUANTIZED_UINT8 are supported.".format(flag))
+  if value == "INT8":
+    return lite_constants.INT8
+  raise ValueError("Unsupported value for --{0}. Only FLOAT, "
+                   "QUANTIZED_UINT8, and INT8 are supported.".format(flag))
 
@@ -352,12 +354,12 @@ def _get_tf1_flags(parser):
   parser.add_argument(
       "--inference_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help="Target data type of real-number arrays in the output file.")
   parser.add_argument(
       "--inference_input_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help=("Target data type of real-number input arrays. Allows for a "
             "different type for input arrays in the case of quantization."))
 
diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py
index 1e80907edbd..d6a35ba9248 100644
--- a/tensorflow/lite/python/tflite_convert_test.py
+++ b/tensorflow/lite/python/tflite_convert_test.py
@@ -98,8 +98,8 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()
 
     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'Placeholder', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'Placeholder',
+                                              'add'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)
 
@@ -137,8 +137,31 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()
 
     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'random', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'random', 'add'))
+    self._run(flags_str, should_succeed=True)
+    os.remove(graph_def_file)
+
+  def testQATFrozenGraphDefInt8(self):
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output',
+          num_bits=16)  # INT8 inference type works for 16-bit fake quant.
+      sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = self._getFilepath('model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+    sess.close()
+
+    flags_str = ('--inference_type=INT8 --std_dev_values=128,128 '
+                 '--mean_values=128,128 '
+                 '--graph_def_file={0} --input_arrays={1},{2} '
+                 '--output_arrays={3}'.format(graph_def_file, 'inputA',
+                                              'inputB', 'output'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)
 
@@ -166,8 +189,8 @@ class TfLiteConvertV1Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()
 
-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)
 
@@ -299,8 +322,8 @@ class TfLiteConvertV2Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()
 
-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)
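
Reviewer note: with this change applied, the converter CLI accepts
--inference_type=INT8. A minimal sketch of an end-to-end invocation, mirroring
testQATFrozenGraphDefInt8 above; the file paths here are illustrative
assumptions, not part of the change:

    import subprocess

    # Assumption: /tmp/model.pb is a frozen graph containing fake-quant ops,
    # with float placeholders 'inputA'/'inputB' and an output named 'output',
    # as built by the test above.
    cmd = [
        'tflite_convert',
        '--graph_def_file=/tmp/model.pb',
        '--input_arrays=inputA,inputB',
        '--output_arrays=output',
        '--inference_type=INT8',  # value added by this change
        '--mean_values=128,128',  # one (mean, std_dev) pair per input array
        '--std_dev_values=128,128',
        '--output_file=/tmp/model_int8.tflite',
    ]
    subprocess.run(cmd, check=True)

As in the existing QUANTIZED_UINT8 path, --mean_values/--std_dev_values must be
supplied for each input array when a quantized inference type is requested,
since they define the real-number-to-quantized mapping for the inputs.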