Support int8 in tflite_convert
PiperOrigin-RevId: 312105323
Change-Id: I161b9b324e37f42f2026592f7c5bec8ac568c3d6
parent 0bf90cb2a8
commit 83b85568fb
@@ -65,6 +65,8 @@ def _parse_inference_type(value, flag):
     return lite_constants.FLOAT
   if value == "QUANTIZED_UINT8":
     return lite_constants.QUANTIZED_UINT8
+  if value == "INT8":
+    return lite_constants.INT8
   raise ValueError("Unsupported value for --{0}. Only FLOAT and "
                    "QUANTIZED_UINT8 are supported.".format(flag))

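Note (not part of the diff): the hunk above just adds a third accepted value to the inference-type parser. A minimal table-based sketch with the same behavior, assuming lite_constants is the tensorflow.lite.python.lite_constants module that tflite_convert.py already imports:

from tensorflow.lite.python import lite_constants

# Dispatch-table sketch; matches the patched _parse_inference_type for the
# three supported values ("INT8" is the one added by this commit).
_INFERENCE_TYPES = {
    "FLOAT": lite_constants.FLOAT,
    "QUANTIZED_UINT8": lite_constants.QUANTIZED_UINT8,
    "INT8": lite_constants.INT8,
}

def parse_inference_type(value, flag):
  if value not in _INFERENCE_TYPES:
    raise ValueError("Unsupported value for --{0}.".format(flag))
  return _INFERENCE_TYPES[value]

assert parse_inference_type("INT8", "inference_type") == lite_constants.INT8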
@@ -352,12 +354,12 @@ def _get_tf1_flags(parser):
   parser.add_argument(
       "--inference_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help="Target data type of real-number arrays in the output file.")
   parser.add_argument(
       "--inference_input_type",
       type=str.upper,
-      choices=["FLOAT", "QUANTIZED_UINT8"],
+      choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
       help=("Target data type of real-number input arrays. Allows for a "
             "different type for input arrays in the case of quantization."))

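Note (not part of the diff): because these flags are declared with type=str.upper, argparse uppercases the supplied value before validating it against choices, so --inference_type=int8 and --inference_type=INT8 are both accepted once "INT8" is in the list. A standalone sketch of that pattern (not the real tflite_convert parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--inference_type",
    type=str.upper,  # applied before the choices membership check
    choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
    help="Target data type of real-number arrays in the output file.")

args = parser.parse_args(["--inference_type=int8"])
print(args.inference_type)  # prints "INT8"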
@@ -98,8 +98,8 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()

     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'Placeholder', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'Placeholder',
+                                              'add'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)

@@ -137,8 +137,31 @@ class TfLiteConvertV1Test(TestModels):
     sess.close()

     flags_str = ('--graph_def_file={0} --input_arrays={1} '
-                 '--output_arrays={2}'.format(graph_def_file,
-                                              'random', 'add'))
+                 '--output_arrays={2}'.format(graph_def_file, 'random', 'add'))
     self._run(flags_str, should_succeed=True)
     os.remove(graph_def_file)

+  def testQATFrozenGraphDefInt8(self):
+    with ops.Graph().as_default():
+      in_tensor_1 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+      in_tensor_2 = array_ops.placeholder(
+          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+      _ = array_ops.fake_quant_with_min_max_args(
+          in_tensor_1 + in_tensor_2, min=0., max=1., name='output',
+          num_bits=16)  # INT8 inference type works for 16 bits fake quant.
+      sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = self._getFilepath('model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+    sess.close()
+
+    flags_str = ('--inference_type=INT8 --std_dev_values=128,128 '
+                 '--mean_values=128,128 '
+                 '--graph_def_file={0} --input_arrays={1},{2} '
+                 '--output_arrays={3}'.format(graph_def_file, 'inputA',
+                                              'inputB', 'output'))
+    self._run(flags_str, should_succeed=True)
+    os.remove(graph_def_file)
+
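Note (not part of the diff): the new test passes --mean_values/--std_dev_values alongside --inference_type=INT8. For the TF1 converter these flags describe the input quantization as real_value = (quantized_value - mean_value) / std_dev_value, i.e. scale = 1/std_dev and zero_point = mean. A small illustration using the values from the test above (128, 128); the dequantize helper is ours, not converter API:

mean_value, std_dev_value = 128.0, 128.0
scale = 1.0 / std_dev_value   # 0.0078125
zero_point = mean_value       # 128.0

def dequantize(q):
  # Map a quantized input value back to the real value the model expects.
  return (q - zero_point) * scale

print(dequantize(128))  # 0.0
print(dequantize(255))  # ~0.992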
@@ -166,8 +189,8 @@ class TfLiteConvertV1Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()

-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)

@@ -299,8 +322,8 @@ class TfLiteConvertV2Test(TestModels):
   def testKerasFileMLIR(self):
     keras_file = self._getKerasModelFile()

-    flags_str = ('--keras_model_file={} --experimental_new_converter'
-                 .format(keras_file))
+    flags_str = (
+        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
     self._run(flags_str, should_succeed=True)
     os.remove(keras_file)
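Note (not part of the diff): at the Python API level the same conversion can be expressed with the TF1 converter, which this CLI change mirrors. A hedged sketch; 'model.pb', 'inputA', 'inputB', and 'output' are the placeholder names from the test above, and it assumes tf.compat.v1.lite.TFLiteConverter already accepts tf.int8 as inference_type (this commit only surfaces INT8 in tflite_convert):

import tensorflow as tf

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='model.pb',
    input_arrays=['inputA', 'inputB'],
    output_arrays=['output'])
converter.inference_type = tf.int8
# (mean, std_dev) per input, mirroring --mean_values=128,128 --std_dev_values=128,128.
converter.quantized_input_stats = {'inputA': (128., 128.), 'inputB': (128., 128.)}
tflite_model = converter.convert()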