Merge pull request #2779 from reuben/export-metadata

Write model metadata to export folder unconditionally
Reuben Morais 2020-03-17 12:08:35 -03:00 committed by GitHub
commit 29a2ac37f0
2 changed files with 46 additions and 8 deletions


@@ -739,6 +739,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    return inputs, outputs, layers


def file_relative_read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()
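For context, file_relative_read() resolves a file name relative to the script's own directory rather than the current working directory. A minimal usage sketch, assuming a sibling text file exists next to the script (the 'VERSION' file name below is only an illustration, not something this diff adds):

# Hypothetical usage of file_relative_read(); 'VERSION' is an assumed sibling file.
version_string = file_relative_read('VERSION').strip()
print(version_string)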
@@ -780,7 +781,7 @@ def export():
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        output_filename = FLAGS.export_name + '.pb'
        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
@@ -817,21 +818,42 @@ def export():
        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version))

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
    with open(metadata_fname, 'w') as f:
        f.write('---\n')
        f.write('author: {}\n'.format(FLAGS.export_author_id))
        f.write('model_name: {}\n'.format(FLAGS.export_model_name))
        f.write('model_version: {}\n'.format(FLAGS.export_model_version))
        f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
        f.write('license: {}\n'.format(FLAGS.export_license))
        f.write('language: {}\n'.format(FLAGS.export_language))
        f.write('runtime: {}\n'.format(model_runtime))
        f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
        f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
        f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
        f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
        f.write('---\n')
        f.write('{}\n'.format(FLAGS.export_description))

    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
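For illustration, a metadata file produced by this block, with hypothetical flag values filled in (the author, model name, versions, and contact URL below are placeholders, not values from this commit), would look roughly like:

---
author: jane-doe
model_name: english-demo
model_version: 0.0.1
contact_info: https://github.com/jane-doe/english-demo/issues
license: MPL-2.0
language: en-Latn-US
runtime: tensorflow
min_ds_version: 0.7.0
max_ds_version: 0.7.0
acoustic_model_url: <replace this with a publicly available URL of the acoustic model>
scorer_url: <replace this with a publicly available URL of the scorer, if present>
---
Freeform Markdown description of the model, taken from --export_description.

The file name follows the {author}_{model_name}_{version}.md pattern built above, e.g. jane-doe_english-demo_0.0.1.md in this hypothetical case.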
def package_zip():
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
    zip_filename = os.path.dirname(export_dir)

    with open(os.path.join(export_dir, 'info.json'), 'w') as f:
        json.dump({
            'name': FLAGS.export_language,
        }, f)

    shutil.copy(FLAGS.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, 'zip', export_dir)
    log_info('Exported packaged model {}'.format(archive))
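As a rough sketch of what package_zip() produces, assuming --export_dir points at path/to/export/en and --scorer_path at an existing kenlm.scorer file (both paths are placeholders):

path/to/export/en/              exported model files already written by export()
path/to/export/en/info.json     containing {"name": "<value of --export_language>"}
path/to/export/en/kenlm.scorer  copied from --scorer_path
path/to/export/en.zip           archive of the directory, created by shutil.make_archive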
def do_single_file_inference(input_file_path):
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)


@@ -907,6 +929,7 @@ def main(_):
        tfv1.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)


if __name__ == '__main__':
    create_flags()
    absl.app.run(main)


@@ -113,11 +113,26 @@ def create_flags():
    f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
    f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
    f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
    f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
    f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
    f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
    f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file name')
    f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')

    # Model metadata
    f.DEFINE_string('export_author_id', 'author', 'author of the exported model. GitHub user or organization name used to uniquely identify the author of this model')
    f.DEFINE_string('export_model_name', 'model', 'name of the exported model. Must not contain forward slashes.')
    f.DEFINE_string('export_model_version', '0.0.1', 'semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with DeepSpeech versions')

    def str_val_equals_help(name, val_desc):
        f.DEFINE_string(name, '<{}>'.format(val_desc), val_desc)

    str_val_equals_help('export_contact_info', 'public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors')
    str_val_equals_help('export_license', 'SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.')
    str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".')
    str_val_equals_help('export_min_ds_version', 'minimum DeepSpeech version (inclusive) the exported model is compatible with')
    str_val_equals_help('export_max_ds_version', 'maximum DeepSpeech version (inclusive) the exported model is compatible with')
    str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.')

    # Reporting
    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')
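As a usage sketch, an export run that fills in the new metadata flags could look like the following (the script name, paths, and values are placeholders; metadata flags that are omitted fall back to the defaults defined above):

python DeepSpeech.py \
  --checkpoint_dir /path/to/checkpoints \
  --export_dir /path/to/export/en \
  --export_author_id jane-doe \
  --export_model_name english-demo \
  --export_model_version 0.0.1 \
  --export_license MPL-2.0 \
  --export_language en-Latn-US \
  --export_min_ds_version 0.7.0 \
  --export_max_ds_version 0.7.0 \
  --export_contact_info https://github.com/jane-doe/english-demo/issues

The generated .md file can then be edited by hand to fill in the description and the acoustic_model_url and scorer_url fields before the model is submitted for publishing.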