Add more arguments. Rename file variables.

This commit is contained in:
Daniel 2020-03-03 16:48:43 +01:00
parent c505a4ec6c
commit c9a433486f
2 changed files with 40 additions and 26 deletions

View File

@ -1,7 +1,12 @@
| Generate vocab-500000.txt and lm.binary files
| Add '--download_librispeech' to download the librispeech text corpus (will be downloaded to '--input_txt')
| Optionally change the path of the kenlm binaries with '--kenlm_bins path/to/bins/'
| Optionally change the number of most frequent words with '--top_k 300000'
|
| Optional Parameters:
* '--download_librispeech': Download the librispeech text corpus (will be downloaded to '--input_txt')
* '--kenlm_bins path/to/bins/': Change the path of the kenlm binaries (defaults to directory in docker container)
* '--top_k 300000': Change the number of most frequent words
* '--arpa_order 3': Change order of n-grams in arpa-file generation
* '--max_arpa_memory 90%': Set maximum allowed memory usage in arpa-file generation
.. code-block:: bash

View File

@ -1,12 +1,10 @@
import argparse
import gzip
import io
import os
import subprocess
from collections import Counter
from urllib import request
import tqdm
import progressbar
# ======================================================================================================================
@ -15,16 +13,15 @@ def convert_and_filter_topk(args):
""" Convert to lowercase, count word occurrences and save top-k words to a file """
counter = Counter()
data_lower = os.path.join(args.output_dir, 'lower.txt.gz')
data_lower = os.path.join(args.output_dir, 'lower.txt')
# Convert and count words
print('\nConverting to lowercase and counting word occurrences ...')
with io.TextIOWrapper(io.BufferedWriter(gzip.open(data_lower, 'w+')), encoding='utf-8') as lower:
with open(args.input_txt, encoding='utf8') as upper:
for line in tqdm.tqdm(upper):
with open(data_lower, 'w+', encoding='utf8') as file_out:
with open(args.input_txt, encoding='utf8') as file_in:
for line in progressbar.progressbar(file_in):
line_lower = line.lower()
counter.update(line_lower.split())
lower.write(line_lower)
file_out.write(line_lower)
# Save top-k words
print('\nSaving top {} words'.format(args.top_k))
@ -40,32 +37,34 @@ def convert_and_filter_topk(args):
# ======================================================================================================================
def build_lm(args, data_lower, vocab_str):
""" Create the lm.binary file """
# Calculate n-grams for the lm.arpa file
print('\nCreating ARPA file ...')
lm_path = os.path.join(args.output_dir, 'lm.arpa')
subprocess.check_call([
args.kenlm_bins + 'lmplz', '--order', '5',
args.kenlm_bins + 'lmplz',
'--order', str(args.arpa_order),
'--temp_prefix', args.output_dir,
'--memory', '75%',
'--memory', args.max_arpa_memory,
'--text', data_lower,
'--arpa', lm_path,
'--prune', '0', '0', '1'
])
# Filter lm.arpa using vocabulary of top-k words
# Filter LM using vocabulary of top-k words
print('\nFiltering ARPA file using vocabulary of top-k words ...')
filtered_path = os.path.join(args.output_dir, 'lm_filtered.arpa')
subprocess.run([args.kenlm_bins + 'filter', 'single', 'model:{}'.format(lm_path), filtered_path],
input=vocab_str.encode('utf-8'),
check=True)
subprocess.run([
args.kenlm_bins + 'filter',
'single',
'model:{}'.format(lm_path),
filtered_path
], input=vocab_str.encode('utf-8'), check=True)
# Quantize, produce trie and save to lm.binary
# Quantize and produce trie binary.
print('\nBuilding lm.binary ...')
binary_path = os.path.join(args.output_dir, 'lm.binary')
subprocess.check_call([
args.kenlm_bins + 'build_binary', '-a', '255',
args.kenlm_bins + 'build_binary',
'-a', '255',
'-q', '8',
'-v',
'trie',
@ -108,6 +107,18 @@ def main():
type=str,
default='/DeepSpeech/native_client/kenlm/build/bin/'
)
parser.add_argument(
'--arpa_order',
help='Order of n-grams in arpa-file generation',
type=int,
default=5
)
parser.add_argument(
'--max_arpa_memory',
help='Maximum allowed memory usage in arpa-file generation',
type=str,
default='75%'
)
args = parser.parse_args()
if args.download_librispeech:
@ -122,9 +133,7 @@ def main():
build_lm(args, data_lower, vocab_str)
# Delete intermediate files
if args.download_librispeech:
os.remove(args.input_txt)
os.remove(os.path.join(args.output_dir, 'lower.txt.gz'))
os.remove(os.path.join(args.output_dir, 'lower.txt'))
os.remove(os.path.join(args.output_dir, 'lm.arpa'))
os.remove(os.path.join(args.output_dir, 'lm_filtered.arpa'))