Merge pull request #2522 from lissyx/package_as_zip
Support packaging as Zip file
This commit is contained in:
commit
c5935644a1
@ -787,7 +787,7 @@ def export():
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
@ -854,6 +854,26 @@ def export():
|
||||
except RuntimeError as e:
|
||||
log_error(str(e))
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
with open(os.path.join(export_dir, 'info.json'), 'w') as f:
|
||||
json.dump({
|
||||
'name': FLAGS.export_language,
|
||||
'parameters': {
|
||||
'beamWidth': FLAGS.beam_width,
|
||||
'lmAlpha': FLAGS.lm_alpha,
|
||||
'lmBeta': FLAGS.lm_beta
|
||||
}
|
||||
}, f)
|
||||
|
||||
shutil.copy(FLAGS.lm_binary_path, export_dir)
|
||||
shutil.copy(FLAGS.lm_trie_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
|
||||
def do_single_file_inference(input_file_path):
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
@ -918,10 +938,21 @@ def main(_):
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_dir:
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
@ -21,3 +21,14 @@ python -u DeepSpeech.py --noshow_progressbar \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--export_tflite
|
||||
|
||||
mkdir /tmp/train_tflite/en-us
|
||||
|
||||
python -u DeepSpeech.py --noshow_progressbar \
|
||||
--n_hidden 100 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--export_dir '/tmp/train_tflite/en-us' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--export_language 'Fake English (fk-FK)' \
|
||||
--export_zip
|
||||
|
@ -102,6 +102,7 @@ def create_flags():
|
||||
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
|
||||
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
|
||||
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
|
||||
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
|
||||
|
||||
# Reporting
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user