Support int8 in tflite_convert

PiperOrigin-RevId: 312105323
Change-Id: I161b9b324e37f42f2026592f7c5bec8ac568c3d6
Author: Feng Liu <2020-05-18 10:23:36 -07:00>, committed by TensorFlower Gardener
parent 0bf90cb2a8
commit 83b85568fb
2 changed files with 35 additions and 10 deletions

tensorflow/lite/python/tflite_convert.py

@@ -65,6 +65,8 @@ def _parse_inference_type(value, flag):
   if value == "FLOAT":
     return lite_constants.FLOAT
   if value == "QUANTIZED_UINT8":
     return lite_constants.QUANTIZED_UINT8
+  if value == "INT8":
+    return lite_constants.INT8
-  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
-                   "QUANTIZED_UINT8 are supported.".format(flag))
+  raise ValueError("Unsupported value for --{0}. Only FLOAT, QUANTIZED_UINT8 "
+                   "and INT8 are supported.".format(flag))
@@ -352,12 +354,12 @@ def _get_tf1_flags(parser):
   parser.add_argument(
       "--inference_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help="Target data type of real-number arrays in the output file.")
   parser.add_argument(
       "--inference_input_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help=("Target data type of real-number input arrays. Allows for a "
             "different type for input arrays in the case of quantization."))

tensorflow/lite/python/tflite_convert_test.py

@@ -98,8 +98,8 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()

     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'Placeholder', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'Placeholder',
+                                              'add'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)
@@ -137,8 +137,31 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()

     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'random', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'random', 'add'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)

+  def testQATFrozenGraphDefInt8(self):
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output',
+          num_bits=16)  # INT8 inference type works for 16 bits fake quant.
+      sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = self._getFilepath('model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+    sess.close()
+
+    flags_str = ('--inference_type=INT8 --std_dev_values=128,128 '
+                 '--mean_values=128,128 '
+                 '--graph_def_file={0} --input_arrays={1},{2} '
+                 '--output_arrays={3}'.format(graph_def_file, 'inputA',
+                                              'inputB', 'output'))
+    self._run(flags_str, should_succeed=True)
+    os.remove(graph_def_file)
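
For comparison, the conversion this new test drives through the CLI could be written with the TF 1.x Python API roughly as follows (a hedged sketch, not part of the commit; it assumes the converter accepts lite_constants.INT8 for inference_type and that quantized_input_stats takes (mean, std_dev) tuples keyed by input name, mirroring --mean_values/--std_dev_values above):

# Sketch of an equivalent programmatic conversion, under the
# assumptions stated above.
import tensorflow.compat.v1 as tf
from tensorflow.lite.python import lite_constants

graph_def_file = 'model.pb'  # frozen graph written as in the test
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays=['inputA', 'inputB'],
    output_arrays=['output'])
converter.inference_type = lite_constants.INT8
# (mean, std_dev) per input, mirroring --mean_values=128,128
# --std_dev_values=128,128 from the flags above.
converter.quantized_input_stats = {'inputA': (128., 128.),
                                   'inputB': (128., 128.)}
tflite_model = converter.convert()
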
@@ -166,8 +189,8 @@ class TfLiteConvertV1Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()
-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)
@@ -299,8 +322,8 @@ class TfLiteConvertV2Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()
-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)