From c46d8396bc8fc4ffc820feefc8b67e7fa23cd203 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 17 Feb 2020 12:45:37 +0100 Subject: [PATCH] Respect --load when exporting --- DeepSpeech.py | 62 ++++++++++++++++++++------------------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index ba4ce998..27f801b9 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -21,7 +21,6 @@ from datetime import datetime from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from evaluate import evaluate from six.moves import zip, range -from tensorflow.python.tools import freeze_graph, strip_unused_lib from util.config import Config, initialize_globals from util.checkpoints import load_or_init_graph from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features @@ -733,49 +732,40 @@ def export(): if FLAGS.export_language: outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language') + # Prevent further graph changes + tfv1.get_default_graph().finalize() + output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)] output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)] - output_names = ",".join(output_names_tensors + output_names_ops) + output_names = output_names_tensors + output_names_ops - # Create a saver using variables from the above newly created graph - saver = tfv1.train.Saver() + with tf.Session() as session: + # Restore variables from checkpoint + if FLAGS.load == 'auto': + method_order = ['best', 'last'] + else: + method_order = [FLAGS.load] + load_or_init_graph(session, method_order) - # Restore variables from training checkpoint - checkpoint = tf.train.get_checkpoint_state(FLAGS.save_checkpoint_dir) - checkpoint_path = checkpoint.model_checkpoint_path + output_filename = FLAGS.export_name + '.pb' + if FLAGS.remove_export: + if os.path.isdir(FLAGS.export_dir): + log_info('Removing old export') + shutil.rmtree(FLAGS.export_dir) - output_filename = FLAGS.export_name + '.pb' - if FLAGS.remove_export: - if os.path.isdir(FLAGS.export_dir): - log_info('Removing old export') - shutil.rmtree(FLAGS.export_dir) - try: output_graph_path = os.path.join(FLAGS.export_dir, output_filename) if not os.path.isdir(FLAGS.export_dir): os.makedirs(FLAGS.export_dir) - def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''): - frozen = freeze_graph.freeze_graph_with_def_protos( - input_graph_def=tfv1.get_default_graph().as_graph_def(), - input_saver_def=saver.as_saver_def(), - input_checkpoint=checkpoint_path, - output_node_names=output_node_names, - restore_op_name=None, - filename_tensor_name=None, - output_graph=output_file, - clear_devices=False, - variable_names_blacklist=variables_blacklist, - initializer_nodes='') + frozen_graph = tfv1.graph_util.convert_variables_to_constants( + sess=session, + input_graph_def=tfv1.get_default_graph().as_graph_def(), + output_node_names=output_names) - input_node_names = [] - return strip_unused_lib.strip_unused( - input_graph_def=frozen, - input_node_names=input_node_names, - output_node_names=output_node_names.split(','), - placeholder_type_enum=tf.float32.as_datatype_enum) - - frozen_graph = do_graph_freeze(output_node_names=output_names) + frozen_graph = tfv1.graph_util.extract_sub_graph( + graph_def=frozen_graph, + dest_nodes=output_names) if not FLAGS.export_tflite: with open(output_graph_path, 'wb') as fout: @@ -784,7 +774,7 @@ def export(): output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) - converter.optimizations = [ tf.lite.Optimize.DEFAULT ] + converter.optimizations = [tf.lite.Optimize.DEFAULT] # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite converter.allow_custom_ops = True tflite_model = converter.convert() @@ -792,11 +782,7 @@ def export(): with open(output_tflite_path, 'wb') as fout: fout.write(tflite_model) - log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path))) - log_info('Models exported at %s' % (FLAGS.export_dir)) - except RuntimeError as e: - log_error(str(e)) def package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip