Write model metadata to export folder unconditionally

This commit is contained in:
Reuben Morais 2020-02-03 15:44:21 +01:00
parent 4291db7309
commit 48178005a2
2 changed files with 46 additions and 8 deletions

View File

@ -727,6 +727,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
return inputs, outputs, layers
def file_relative_read(fname):
    """Return the text content of a file located relative to this module.

    Args:
        fname: Path of the file to read, relative to the directory that
            contains this source file.

    Returns:
        The whole file content as a string (read in text mode with the
        platform default encoding, as before).

    Raises:
        OSError: If the file cannot be opened or read.
    """
    path = os.path.join(os.path.dirname(__file__), fname)
    # Use a context manager so the handle is closed deterministically rather
    # than being left open until the garbage collector reclaims it.
    with open(path) as fin:
        return fin.read()
@ -768,7 +769,7 @@ def export():
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_name + '.pb'
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
@ -805,21 +806,42 @@ def export():
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
    """Bundle the export directory into a zip archive placed next to it.

    Writes an ``info.json`` holding the export language into the export
    directory, copies the scorer file there, then zips the directory and
    logs the path of the resulting archive.
    """
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    # Joining with '' forces a trailing separator, so dirname() below simply
    # strips it and yields the directory path used as the archive base name.
    export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')
    archive_base = os.path.dirname(export_dir)

    info_path = os.path.join(export_dir, 'info.json')
    info = {'name': FLAGS.export_language}
    with open(info_path, 'w') as info_file:
        json.dump(info, info_file)

    # The scorer is shipped inside the archive alongside the exported model.
    shutil.copy(FLAGS.scorer_path, export_dir)

    archive_path = shutil.make_archive(archive_base, 'zip', export_dir)
    log_info('Exported packaged model {}'.format(archive_path))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@ -895,6 +917,7 @@ def main(_):
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__':
    # Script entry point: register the flag definitions first, then hand
    # control to absl, which invokes main().
    create_flags()
    absl.app.run(main)

View File

@ -112,11 +112,26 @@ def create_flags():
f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file name')
f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')
# Model metadata
f.DEFINE_string('export_author_id', 'author', 'author of the exported model. This is a unique handle used for identifying and collating models per author. Suggestion: use your GitHub username or organization name')
f.DEFINE_string('export_model_name', 'model', 'name of the exported model')
f.DEFINE_string('export_model_version', '1', 'version of the exported model. This is fully controlled by you as author of the model and has no required connection with DeepSpeech versions')
def str_val_equals_help(name, val_desc):
    """Define a string flag whose default value is its own help text.

    The default is the description wrapped in angle brackets, so an
    unset flag shows up as a ``<description>`` placeholder.

    Args:
        name: Name of the flag to define.
        val_desc: Description used both as the help text and, wrapped in
            ``<...>``, as the default value.
    """
    placeholder_default = '<{}>'.format(val_desc)
    f.DEFINE_string(name, placeholder_default, val_desc)
str_val_equals_help('export_contact_info', 'public contact information of the author. Free form, could be an email address, a repository URL, a company website, etc.')
str_val_equals_help('export_license', 'license of the exported model')
str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag, e.g. "de-DE" or "zh-cmn-Hans-CN"')
str_val_equals_help('export_min_ds_version', 'minimum DeepSpeech version (inclusive) the exported model is compatible with')
str_val_equals_help('export_max_ds_version', 'maximum DeepSpeech version (inclusive) the exported model is compatible with')
str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.')
# Reporting
f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')