Merge pull request #1939 from mozilla/tfliteconverter
Use tf.lite.TFLiteConverter to create tflite model
This commit is contained in:
commit
382707388a
@ -12,7 +12,6 @@ import evaluate
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import progressbar
|
import progressbar
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -30,9 +29,9 @@ from util.text import Alphabet
|
|||||||
|
|
||||||
#TODO: remove once fully switched to 1.13
|
#TODO: remove once fully switched to 1.13
|
||||||
try:
|
try:
|
||||||
from tensorflow.contrib.lite.python import tflite_convert # 1.12
|
import tensorflow.lite as lite # 1.13
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from tensorflow.lite.python import tflite_convert # 1.13
|
import tensorflow.contrib.lite as lite # 1.12
|
||||||
|
|
||||||
|
|
||||||
# Graph Creation
|
# Graph Creation
|
||||||
@ -804,7 +803,7 @@ def export():
|
|||||||
os.makedirs(FLAGS.export_dir)
|
os.makedirs(FLAGS.export_dir)
|
||||||
|
|
||||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
||||||
freeze_graph.freeze_graph_with_def_protos(
|
return freeze_graph.freeze_graph_with_def_protos(
|
||||||
input_graph_def=session.graph_def,
|
input_graph_def=session.graph_def,
|
||||||
input_saver_def=saver.as_saver_def(),
|
input_saver_def=saver.as_saver_def(),
|
||||||
input_checkpoint=checkpoint_path,
|
input_checkpoint=checkpoint_path,
|
||||||
@ -819,39 +818,16 @@ def export():
|
|||||||
if not FLAGS.export_tflite:
|
if not FLAGS.export_tflite:
|
||||||
do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
||||||
else:
|
else:
|
||||||
temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
|
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
|
||||||
os.close(temp_fd)
|
|
||||||
do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
|
|
||||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||||
class TFLiteFlags():
|
|
||||||
def __init__(self):
|
|
||||||
self.graph_def_file = temp_freeze
|
|
||||||
self.inference_type = 'FLOAT'
|
|
||||||
self.input_arrays = input_names
|
|
||||||
self.input_shapes = input_shapes
|
|
||||||
self.output_arrays = output_names
|
|
||||||
self.output_file = output_tflite_path
|
|
||||||
self.output_format = 'TFLITE'
|
|
||||||
self.post_training_quantize = True
|
|
||||||
|
|
||||||
default_empty = [
|
converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||||
'inference_input_type',
|
converter.post_training_quantize = True
|
||||||
'mean_values',
|
tflite_model = converter.convert()
|
||||||
'default_ranges_min', 'default_ranges_max',
|
|
||||||
'drop_control_dependency',
|
with open(output_tflite_path, 'wb') as fout:
|
||||||
'reorder_across_fake_quant',
|
fout.write(tflite_model)
|
||||||
'change_concat_input_ranges',
|
|
||||||
'allow_custom_ops',
|
|
||||||
'converter_mode',
|
|
||||||
'dump_graphviz_dir',
|
|
||||||
'dump_graphviz_video'
|
|
||||||
]
|
|
||||||
for e in default_empty:
|
|
||||||
self.__dict__[e] = None
|
|
||||||
|
|
||||||
flags = TFLiteFlags()
|
|
||||||
tflite_convert._convert_model(flags)
|
|
||||||
os.unlink(temp_freeze)
|
|
||||||
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
|
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
|
||||||
|
|
||||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user