Merge pull request #1939 from mozilla/tfliteconverter

Use tf.lite.TFLiteConverter to create tflite model
This commit is contained in:
Reuben Morais 2019-03-09 19:52:13 +00:00 committed by GitHub
commit 382707388a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,7 +12,6 @@ import evaluate
import numpy as np
import progressbar
import shutil
import tempfile
import tensorflow as tf
import traceback
@ -30,9 +29,9 @@ from util.text import Alphabet
#TODO: remove once fully switched to 1.13
try:
from tensorflow.contrib.lite.python import tflite_convert # 1.12
import tensorflow.lite as lite # 1.13
except ImportError:
from tensorflow.lite.python import tflite_convert # 1.13
import tensorflow.contrib.lite as lite # 1.12
# Graph Creation
@ -804,7 +803,7 @@ def export():
os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
freeze_graph.freeze_graph_with_def_protos(
return freeze_graph.freeze_graph_with_def_protos(
input_graph_def=session.graph_def,
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
@ -819,39 +818,16 @@ def export():
if not FLAGS.export_tflite:
do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
else:
temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
os.close(temp_fd)
do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
class TFLiteFlags():
def __init__(self):
self.graph_def_file = temp_freeze
self.inference_type = 'FLOAT'
self.input_arrays = input_names
self.input_shapes = input_shapes
self.output_arrays = output_names
self.output_file = output_tflite_path
self.output_format = 'TFLITE'
self.post_training_quantize = True
default_empty = [
'inference_input_type',
'mean_values',
'default_ranges_min', 'default_ranges_max',
'drop_control_dependency',
'reorder_across_fake_quant',
'change_concat_input_ranges',
'allow_custom_ops',
'converter_mode',
'dump_graphviz_dir',
'dump_graphviz_video'
]
for e in default_empty:
self.__dict__[e] = None
converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.post_training_quantize = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
flags = TFLiteFlags()
tflite_convert._convert_model(flags)
os.unlink(temp_freeze)
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
log_info('Models exported at %s' % (FLAGS.export_dir))