Merge pull request #1939 from mozilla/tfliteconverter

Use tf.lite.TFLiteConverter to create tflite model
Reuben Morais 2019-03-09 19:52:13 +00:00 committed by GitHub
commit 382707388a


@@ -12,7 +12,6 @@ import evaluate
 import numpy as np
 import progressbar
 import shutil
-import tempfile
 import tensorflow as tf
 import traceback
@@ -30,9 +29,9 @@ from util.text import Alphabet
 #TODO: remove once fully switched to 1.13
 try:
-    from tensorflow.contrib.lite.python import tflite_convert # 1.12
+    import tensorflow.lite as lite # 1.13
 except ImportError:
-    from tensorflow.lite.python import tflite_convert # 1.13
+    import tensorflow.contrib.lite as lite # 1.12

 # Graph Creation
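
The fallback order is flipped here: the 1.13 tensorflow.lite module is tried first and tensorflow.contrib.lite is kept only as the 1.12 fallback. Both are bound to the same name, lite, so the converter code further down runs unchanged on either version. A minimal sketch of the same guard, with a purely illustrative check at the end (it assumes, as the new export code does, that 1.12's contrib module already exposes TFLiteConverter):

try:
    import tensorflow.lite as lite          # TensorFlow 1.13+
except ImportError:
    import tensorflow.contrib.lite as lite  # TensorFlow 1.12 fallback

# Either binding is expected to expose the converter class used in export().
assert hasattr(lite, 'TFLiteConverter')
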
@@ -804,7 +803,7 @@ def export():
             os.makedirs(FLAGS.export_dir)

         def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
-            freeze_graph.freeze_graph_with_def_protos(
+            return freeze_graph.freeze_graph_with_def_protos(
                 input_graph_def=session.graph_def,
                 input_saver_def=saver.as_saver_def(),
                 input_checkpoint=checkpoint_path,
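
The added return is what enables the TFLite branch below: the new code relies on freeze_graph_with_def_protos handing back the frozen GraphDef it builds, so export() can capture it in memory and pass it straight to the converter instead of freezing to a temporary .pb on disk and re-reading it; that is also why the tempfile import is dropped above.
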
@@ -819,39 +818,16 @@ def export():
         if not FLAGS.export_tflite:
             do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
         else:
-            temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
-            os.close(temp_fd)
-            do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
+            frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
             output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
-            class TFLiteFlags():
-                def __init__(self):
-                    self.graph_def_file = temp_freeze
-                    self.inference_type = 'FLOAT'
-                    self.input_arrays = input_names
-                    self.input_shapes = input_shapes
-                    self.output_arrays = output_names
-                    self.output_file = output_tflite_path
-                    self.output_format = 'TFLITE'
-                    self.post_training_quantize = True
-                    default_empty = [
-                        'inference_input_type',
-                        'mean_values',
-                        'default_ranges_min', 'default_ranges_max',
-                        'drop_control_dependency',
-                        'reorder_across_fake_quant',
-                        'change_concat_input_ranges',
-                        'allow_custom_ops',
-                        'converter_mode',
-                        'dump_graphviz_dir',
-                        'dump_graphviz_video'
-                    ]
-                    for e in default_empty:
-                        self.__dict__[e] = None
-            flags = TFLiteFlags()
-            tflite_convert._convert_model(flags)
-            os.unlink(temp_freeze)
+
+            converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
+            converter.post_training_quantize = True
+            tflite_model = converter.convert()
+
+            with open(output_tflite_path, 'wb') as fout:
+                fout.write(tflite_model)
+
             log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

         log_info('Models exported at %s' % (FLAGS.export_dir))
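
For comparison, here is a minimal, self-contained sketch of the same conversion flow using the documented TF 1.13 API. The file path and array names are hypothetical placeholders; the PR itself constructs the converter directly from the in-memory GraphDef returned by do_graph_freeze and the inputs/outputs tensor collections already available in export().

import tensorflow as tf

# Hypothetical path and node names, for illustration only.
GRAPH_DEF_FILE = 'export_dir/output_graph.pb'
INPUT_ARRAYS = ['input_node', 'input_lengths']
OUTPUT_ARRAYS = ['logits']

# TF 1.13: build a converter from a frozen GraphDef stored on disk.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    GRAPH_DEF_FILE, INPUT_ARRAYS, OUTPUT_ARRAYS)
converter.post_training_quantize = True  # weight quantization, as enabled in the PR

tflite_model = converter.convert()       # serialized TFLite flatbuffer (bytes)
with open('export_dir/output_graph.tflite', 'wb') as fout:
    fout.write(tflite_model)

The from_frozen_graph classmethod takes node names and a .pb on disk, whereas the constructor used in the diff takes a GraphDef proto plus the actual input and output tensor objects; both paths end with convert() returning the flatbuffer bytes that get written to the .tflite file.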