Merge pull request #2522 from lissyx/package_as_zip

Support packaging as Zip file
This commit is contained in:
lissyx 2019-11-13 13:18:45 +01:00 committed by GitHub
commit c5935644a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 2 deletions

View File

@ -787,7 +787,7 @@ def export():
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
@ -854,6 +854,26 @@ def export():
except RuntimeError as e:
log_error(str(e))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
with open(os.path.join(export_dir, 'info.json'), 'w') as f:
json.dump({
'name': FLAGS.export_language,
'parameters': {
'beamWidth': FLAGS.beam_width,
'lmAlpha': FLAGS.lm_alpha,
'lmBeta': FLAGS.lm_beta
}
}, f)
shutil.copy(FLAGS.lm_binary_path, export_dir)
shutil.copy(FLAGS.lm_trie_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
@ -918,10 +938,21 @@ def main(_):
tfv1.reset_default_graph()
test()
if FLAGS.export_dir:
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)

View File

@ -21,3 +21,14 @@ python -u DeepSpeech.py --noshow_progressbar \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--export_tflite
mkdir /tmp/train_tflite/en-us
python -u DeepSpeech.py --noshow_progressbar \
--n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \
--export_dir '/tmp/train_tflite/en-us' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--export_language 'Fake English (fk-FK)' \
--export_zip

View File

@ -102,6 +102,7 @@ def create_flags():
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
# Reporting