Merge pull request #1939 from mozilla/tfliteconverter
Use tf.lite.TFLiteConverter to create tflite model
This commit is contained in:
commit
382707388a
@ -12,7 +12,6 @@ import evaluate
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
import traceback
|
||||
|
||||
@ -30,9 +29,9 @@ from util.text import Alphabet
|
||||
|
||||
#TODO: remove once fully switched to 1.13
|
||||
try:
|
||||
from tensorflow.contrib.lite.python import tflite_convert # 1.12
|
||||
import tensorflow.lite as lite # 1.13
|
||||
except ImportError:
|
||||
from tensorflow.lite.python import tflite_convert # 1.13
|
||||
import tensorflow.contrib.lite as lite # 1.12
|
||||
|
||||
|
||||
# Graph Creation
|
||||
@ -804,7 +803,7 @@ def export():
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
||||
freeze_graph.freeze_graph_with_def_protos(
|
||||
return freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=session.graph_def,
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
input_checkpoint=checkpoint_path,
|
||||
@ -819,39 +818,16 @@ def export():
|
||||
if not FLAGS.export_tflite:
|
||||
do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
||||
else:
|
||||
temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
|
||||
os.close(temp_fd)
|
||||
do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
class TFLiteFlags():
|
||||
def __init__(self):
|
||||
self.graph_def_file = temp_freeze
|
||||
self.inference_type = 'FLOAT'
|
||||
self.input_arrays = input_names
|
||||
self.input_shapes = input_shapes
|
||||
self.output_arrays = output_names
|
||||
self.output_file = output_tflite_path
|
||||
self.output_format = 'TFLITE'
|
||||
self.post_training_quantize = True
|
||||
|
||||
default_empty = [
|
||||
'inference_input_type',
|
||||
'mean_values',
|
||||
'default_ranges_min', 'default_ranges_max',
|
||||
'drop_control_dependency',
|
||||
'reorder_across_fake_quant',
|
||||
'change_concat_input_ranges',
|
||||
'allow_custom_ops',
|
||||
'converter_mode',
|
||||
'dump_graphviz_dir',
|
||||
'dump_graphviz_video'
|
||||
]
|
||||
for e in default_empty:
|
||||
self.__dict__[e] = None
|
||||
converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.post_training_quantize = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
flags = TFLiteFlags()
|
||||
tflite_convert._convert_model(flags)
|
||||
os.unlink(temp_freeze)
|
||||
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
Loading…
x
Reference in New Issue
Block a user