diff --git a/DeepSpeech.py b/DeepSpeech.py
index 8ebd1e25..c31f5515 100644
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -727,6 +727,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
 
     return inputs, outputs, layers
 
+
 def file_relative_read(fname):
     return open(os.path.join(os.path.dirname(__file__), fname)).read()
 
@@ -768,7 +769,7 @@ def export():
             method_order = [FLAGS.load]
         load_or_init_graph(session, method_order)
 
-        output_filename = FLAGS.export_name + '.pb'
+        output_filename = FLAGS.export_file_name + '.pb'
         if FLAGS.remove_export:
             if os.path.isdir(FLAGS.export_dir):
                 log_info('Removing old export')
@@ -805,21 +806,42 @@ def export():
 
     log_info('Models exported at %s' % (FLAGS.export_dir))
 
+    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
+        FLAGS.export_author_id,
+        FLAGS.export_model_name,
+        FLAGS.export_model_version))
+
+    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
+    with open(metadata_fname, 'w') as f:
+        f.write('---\n')
+        f.write('author: {}\n'.format(FLAGS.export_author_id))
+        f.write('model_name: {}\n'.format(FLAGS.export_model_name))
+        f.write('model_version: {}\n'.format(FLAGS.export_model_version))
+        f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
+        f.write('license: {}\n'.format(FLAGS.export_license))
+        f.write('language: {}\n'.format(FLAGS.export_language))
+        f.write('runtime: {}\n'.format(model_runtime))
+        f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
+        f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
+        f.write('acoustic_model_url: \n')
+        f.write('scorer_url: \n')
+        f.write('---\n')
+        f.write('{}\n'.format(FLAGS.export_description))
+
+    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing, make sure all information in the metadata file is correct and complete the URL fields.'.format(metadata_fname))
+
+
 def package_zip():
     # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
     export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
     zip_filename = os.path.dirname(export_dir)
 
-    with open(os.path.join(export_dir, 'info.json'), 'w') as f:
-        json.dump({
-            'name': FLAGS.export_language,
-        }, f)
-
     shutil.copy(FLAGS.scorer_path, export_dir)
 
     archive = shutil.make_archive(zip_filename, 'zip', export_dir)
     log_info('Exported packaged model {}'.format(archive))
 
+
 def do_single_file_inference(input_file_path):
     with tfv1.Session(config=Config.session_config) as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@@ -895,6 +917,7 @@ def main(_):
         tfv1.reset_default_graph()
         do_single_file_inference(FLAGS.one_shot_infer)
 
+
 if __name__ == '__main__':
     create_flags()
     absl.app.run(main)
diff --git a/util/flags.py b/util/flags.py
index f46fdc81..20a0c8ec 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -112,11 +112,26 @@ def create_flags():
     f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
     f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
     f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
-    f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
     f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
-    f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
+    f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file')
     f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')
 
+    # Model metadata
+
+    f.DEFINE_string('export_author_id', 'author', 'author of the exported model. This is a unique handle used for identifying and collating models per author. Suggestion: use your GitHub username or organization name')
+    f.DEFINE_string('export_model_name', 'model', 'name of the exported model')
+    f.DEFINE_string('export_model_version', '1', 'version of the exported model. This is fully controlled by you as author of the model and has no required connection with DeepSpeech versions')
+
+    def str_val_equals_help(name, val_desc):
+        f.DEFINE_string(name, '<{}>'.format(val_desc), val_desc)
+
+    str_val_equals_help('export_contact_info', 'public contact information of the author. Free form, could be an email address, a repository URL, a company website, etc.')
+    str_val_equals_help('export_license', 'license of the exported model')
+    str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag, e.g. "de-DE" or "zh-cmn-Hans-CN"')
+    str_val_equals_help('export_min_ds_version', 'minimum DeepSpeech version (inclusive) the exported model is compatible with')
+    str_val_equals_help('export_max_ds_version', 'maximum DeepSpeech version (inclusive) the exported model is compatible with')
+    str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.')
+
     # Reporting
 
     f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')
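
For illustration, a sketch of the metadata file this change generates, under hypothetical flag values (none of these values are defaults or come from the patch): assuming --export_author_id jane-doe, --export_model_name english-stt, --export_model_version 0.1, --export_contact_info jane@example.com, --export_license MPL-2.0, --export_language en-US, --export_min_ds_version 0.7.0, --export_max_ds_version 0.8.0, and --export_description left unset, a non-TFLite export would write <export_dir>/jane-doe_english-stt_0.1.md with contents along these lines:

    ---
    author: jane-doe
    model_name: english-stt
    model_version: 0.1
    contact_info: jane@example.com
    license: MPL-2.0
    language: en-US
    runtime: tensorflow
    min_ds_version: 0.7.0
    max_ds_version: 0.8.0
    acoustic_model_url:
    scorer_url:
    ---
    <Freeform description of the model being exported. Markdown accepted. ...>

Any flag left unset appears as its '<help text>' placeholder (courtesy of str_val_equals_help), which, together with the always-empty URL fields, makes unfinished metadata easy to spot before the model is submitted for publishing.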