Support int8 in tflite_convert
PiperOrigin-RevId: 312105323
Change-Id: I161b9b324e37f42f2026592f7c5bec8ac568c3d6
This commit is contained in:
parent
0bf90cb2a8
commit
83b85568fb
@ -65,6 +65,8 @@ def _parse_inference_type(value, flag):
|
|||||||
return lite_constants.FLOAT
|
return lite_constants.FLOAT
|
||||||
if value == "QUANTIZED_UINT8":
|
if value == "QUANTIZED_UINT8":
|
||||||
return lite_constants.QUANTIZED_UINT8
|
return lite_constants.QUANTIZED_UINT8
|
||||||
|
if value == "INT8":
|
||||||
|
return lite_constants.INT8
|
||||||
raise ValueError("Unsupported value for --{0}. Only FLOAT and "
|
raise ValueError("Unsupported value for --{0}. Only FLOAT and "
|
||||||
"QUANTIZED_UINT8 are supported.".format(flag))
|
"QUANTIZED_UINT8 are supported.".format(flag))
|
||||||
|
|
||||||
@ -352,12 +354,12 @@ def _get_tf1_flags(parser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--inference_type",
|
"--inference_type",
|
||||||
type=str.upper,
|
type=str.upper,
|
||||||
choices=["FLOAT", "QUANTIZED_UINT8"],
|
choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
|
||||||
help="Target data type of real-number arrays in the output file.")
|
help="Target data type of real-number arrays in the output file.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--inference_input_type",
|
"--inference_input_type",
|
||||||
type=str.upper,
|
type=str.upper,
|
||||||
choices=["FLOAT", "QUANTIZED_UINT8"],
|
choices=["FLOAT", "QUANTIZED_UINT8", "INT8"],
|
||||||
help=("Target data type of real-number input arrays. Allows for a "
|
help=("Target data type of real-number input arrays. Allows for a "
|
||||||
"different type for input arrays in the case of quantization."))
|
"different type for input arrays in the case of quantization."))
|
||||||
|
|
||||||
|
@ -98,8 +98,8 @@ class TfLiteConvertV1Test(TestModels):
|
|||||||
sess.close()
|
sess.close()
|
||||||
|
|
||||||
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
||||||
'--output_arrays={2}'.format(graph_def_file,
|
'--output_arrays={2}'.format(graph_def_file, 'Placeholder',
|
||||||
'Placeholder', 'add'))
|
'add'))
|
||||||
self._run(flags_str, should_succeed=True)
|
self._run(flags_str, should_succeed=True)
|
||||||
os.remove(graph_def_file)
|
os.remove(graph_def_file)
|
||||||
|
|
||||||
@ -137,8 +137,31 @@ class TfLiteConvertV1Test(TestModels):
|
|||||||
sess.close()
|
sess.close()
|
||||||
|
|
||||||
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
flags_str = ('--graph_def_file={0} --input_arrays={1} '
|
||||||
'--output_arrays={2}'.format(graph_def_file,
|
'--output_arrays={2}'.format(graph_def_file, 'random', 'add'))
|
||||||
'random', 'add'))
|
self._run(flags_str, should_succeed=True)
|
||||||
|
os.remove(graph_def_file)
|
||||||
|
|
||||||
|
def testQATFrozenGraphDefInt8(self):
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
in_tensor_1 = array_ops.placeholder(
|
||||||
|
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
|
||||||
|
in_tensor_2 = array_ops.placeholder(
|
||||||
|
shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
|
||||||
|
_ = array_ops.fake_quant_with_min_max_args(
|
||||||
|
in_tensor_1 + in_tensor_2, min=0., max=1., name='output',
|
||||||
|
num_bits=16) # INT8 inference type works for 16 bits fake quant.
|
||||||
|
sess = session.Session()
|
||||||
|
|
||||||
|
# Write graph to file.
|
||||||
|
graph_def_file = self._getFilepath('model.pb')
|
||||||
|
write_graph(sess.graph_def, '', graph_def_file, False)
|
||||||
|
sess.close()
|
||||||
|
|
||||||
|
flags_str = ('--inference_type=INT8 --std_dev_values=128,128 '
|
||||||
|
'--mean_values=128,128 '
|
||||||
|
'--graph_def_file={0} --input_arrays={1},{2} '
|
||||||
|
'--output_arrays={3}'.format(graph_def_file, 'inputA',
|
||||||
|
'inputB', 'output'))
|
||||||
self._run(flags_str, should_succeed=True)
|
self._run(flags_str, should_succeed=True)
|
||||||
os.remove(graph_def_file)
|
os.remove(graph_def_file)
|
||||||
|
|
||||||
@ -166,8 +189,8 @@ class TfLiteConvertV1Test(TestModels):
|
|||||||
def testKerasFileMLIR(self):
|
def testKerasFileMLIR(self):
|
||||||
keras_file = self._getKerasModelFile()
|
keras_file = self._getKerasModelFile()
|
||||||
|
|
||||||
flags_str = ('--keras_model_file={} --experimental_new_converter'
|
flags_str = (
|
||||||
.format(keras_file))
|
'--keras_model_file={} --experimental_new_converter'.format(keras_file))
|
||||||
self._run(flags_str, should_succeed=True)
|
self._run(flags_str, should_succeed=True)
|
||||||
os.remove(keras_file)
|
os.remove(keras_file)
|
||||||
|
|
||||||
@ -299,8 +322,8 @@ class TfLiteConvertV2Test(TestModels):
|
|||||||
def testKerasFileMLIR(self):
|
def testKerasFileMLIR(self):
|
||||||
keras_file = self._getKerasModelFile()
|
keras_file = self._getKerasModelFile()
|
||||||
|
|
||||||
flags_str = ('--keras_model_file={} --experimental_new_converter'
|
flags_str = (
|
||||||
.format(keras_file))
|
'--keras_model_file={} --experimental_new_converter'.format(keras_file))
|
||||||
self._run(flags_str, should_succeed=True)
|
self._run(flags_str, should_succeed=True)
|
||||||
os.remove(keras_file)
|
os.remove(keras_file)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user