diff --git a/DeepSpeech.py b/DeepSpeech.py index b847a136..941eeb81 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -812,7 +812,7 @@ def export(): checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) checkpoint_path = checkpoint.model_checkpoint_path - output_filename = 'output_graph.pb' + output_filename = FLAGS.export_name + '.pb' if FLAGS.remove_export: if os.path.isdir(FLAGS.export_dir): log_info('Removing old export') diff --git a/util/flags.py b/util/flags.py index dda3cd58..f2d8e75d 100644 --- a/util/flags.py +++ b/util/flags.py @@ -110,6 +110,7 @@ def create_flags(): f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.') f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') + f.DEFINE_string('export_name', 'output_graph', 'name for the export model') # Reporting