Use tf.lite.TFLiteConverter to create tflite model

This commit is contained in:
Reuben Morais 2019-03-09 12:02:18 -03:00
parent eb1c0f9853
commit d3aa5020a6

View File

@ -12,7 +12,6 @@ import evaluate
import numpy as np import numpy as np
import progressbar import progressbar
import shutil import shutil
import tempfile
import tensorflow as tf import tensorflow as tf
import traceback import traceback
@ -30,9 +29,9 @@ from util.text import Alphabet
#TODO: remove once fully switched to 1.13 #TODO: remove once fully switched to 1.13
try: try:
from tensorflow.contrib.lite.python import tflite_convert # 1.12 import tensorflow.lite as lite # 1.13
except ImportError: except ImportError:
from tensorflow.lite.python import tflite_convert # 1.13 import tensorflow.contrib.lite as lite # 1.12
# Graph Creation # Graph Creation
@ -804,7 +803,7 @@ def export():
os.makedirs(FLAGS.export_dir) os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None): def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
freeze_graph.freeze_graph_with_def_protos( return freeze_graph.freeze_graph_with_def_protos(
input_graph_def=session.graph_def, input_graph_def=session.graph_def,
input_saver_def=saver.as_saver_def(), input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path, input_checkpoint=checkpoint_path,
@ -819,39 +818,16 @@ def export():
if not FLAGS.export_tflite: if not FLAGS.export_tflite:
do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h') do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
else: else:
temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir) frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
os.close(temp_fd)
do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
class TFLiteFlags():
def __init__(self):
self.graph_def_file = temp_freeze
self.inference_type = 'FLOAT'
self.input_arrays = input_names
self.input_shapes = input_shapes
self.output_arrays = output_names
self.output_file = output_tflite_path
self.output_format = 'TFLITE'
self.post_training_quantize = True
default_empty = [ converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
'inference_input_type', converter.post_training_quantize = True
'mean_values', tflite_model = converter.convert()
'default_ranges_min', 'default_ranges_max',
'drop_control_dependency', with open(output_tflite_path, 'wb') as fout:
'reorder_across_fake_quant', fout.write(tflite_model)
'change_concat_input_ranges',
'allow_custom_ops',
'converter_mode',
'dump_graphviz_dir',
'dump_graphviz_video'
]
for e in default_empty:
self.__dict__[e] = None
flags = TFLiteFlags()
tflite_convert._convert_model(flags)
os.unlink(temp_freeze)
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path))) log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
log_info('Models exported at %s' % (FLAGS.export_dir)) log_info('Models exported at %s' % (FLAGS.export_dir))