diff --git a/DeepSpeech.py b/DeepSpeech.py index 9f9b5ce1..47809d54 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -12,7 +12,6 @@ import evaluate import numpy as np import progressbar import shutil -import tempfile import tensorflow as tf import traceback @@ -30,9 +29,9 @@ from util.text import Alphabet #TODO: remove once fully switched to 1.13 try: - from tensorflow.contrib.lite.python import tflite_convert # 1.12 + import tensorflow.lite as lite # 1.13 except ImportError: - from tensorflow.lite.python import tflite_convert # 1.13 + import tensorflow.contrib.lite as lite # 1.12 # Graph Creation @@ -804,7 +803,7 @@ def export(): os.makedirs(FLAGS.export_dir) def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None): - freeze_graph.freeze_graph_with_def_protos( + return freeze_graph.freeze_graph_with_def_protos( input_graph_def=session.graph_def, input_saver_def=saver.as_saver_def(), input_checkpoint=checkpoint_path, @@ -819,39 +818,16 @@ def export(): if not FLAGS.export_tflite: do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h') else: - temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir) - os.close(temp_fd) - do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='') + frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='') output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) - class TFLiteFlags(): - def __init__(self): - self.graph_def_file = temp_freeze - self.inference_type = 'FLOAT' - self.input_arrays = input_names - self.input_shapes = input_shapes - self.output_arrays = output_names - self.output_file = output_tflite_path - self.output_format = 'TFLITE' - self.post_training_quantize = True - default_empty = [ - 'inference_input_type', - 'mean_values', - 'default_ranges_min', 'default_ranges_max', - 'drop_control_dependency', - 'reorder_across_fake_quant', - 'change_concat_input_ranges', - 'allow_custom_ops', - 'converter_mode', - 'dump_graphviz_dir', - 'dump_graphviz_video' - ] - for e in default_empty: - self.__dict__[e] = None + converter = lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) + converter.post_training_quantize = True + tflite_model = converter.convert() + + with open(output_tflite_path, 'wb') as fout: + fout.write(tflite_model) - flags = TFLiteFlags() - tflite_convert._convert_model(flags) - os.unlink(temp_freeze) log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path))) log_info('Models exported at %s' % (FLAGS.export_dir))