Add generate_package tool to create combined scorer package

This commit is contained in:
Reuben Morais 2020-01-16 15:45:57 +01:00
parent be2229ef29
commit 214b50f490
4 changed files with 145 additions and 20 deletions

120
data/lm/generate_package.py Normal file
View File

@ -0,0 +1,120 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
import argparse
import shutil
from util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
words = set()
vocab_looks_char_based = True
with open(vocab_path) as fin:
for line in fin:
for word in line.split():
words.add(word.encode('utf-8'))
if len(word) > 1:
vocab_looks_char_based = False
print("{} unique words read from vocabulary file.".format(len(words)))
print(
"{} like a character based model.".format(
"Looks" if vocab_looks_char_based else "Doesn't look"
)
)
if force_utf8 != None:
use_utf8 = force_utf8.value
else:
use_utf8 = vocab_looks_char_based
if use_utf8:
serialized_alphabet = UTF8Alphabet().serialize()
else:
serialized_alphabet = Alphabet(alphabet_path).serialize()
alphabet = NativeAlphabet()
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
if err != 0:
print("Error loading alphabet: {}".format(err))
sys.exit(1)
scorer = Scorer()
scorer.set_alphabet(alphabet)
scorer.set_utf8_mode(use_utf8)
scorer.load_lm(lm_path, "")
scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path)
scorer.save_dictionary(package_path, True) # append, not overwrite
print('Package created in {}'.format(package_path))
class Tristate(object):
def __init__(self, value=None):
if any(value is v for v in (True, False, None)):
self.value = value
else:
raise ValueError("Tristate value must be True, False, or None")
def __eq__(self, other):
return (self.value is other.value if isinstance(other, Tristate)
else self.value is other)
def __ne__(self, other):
return not self == other
def __bool__(self):
raise TypeError("Tristate object may not be used as a Boolean")
def __str__(self):
return str(self.value)
def __repr__(self):
return "Tristate(%s)" % self.value
def main():
parser = argparse.ArgumentParser(
description="Generate an external scorer package for DeepSpeech."
)
parser.add_argument(
"--alphabet",
help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
)
parser.add_argument(
"--lm",
required=True,
help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
)
parser.add_argument(
"--vocab",
required=True,
help="Path of vocabulary file. Must contain words separated by whitespace.",
)
parser.add_argument("--package", required=True, help="Path to save scorer package.")
parser.add_argument(
"--force_utf8",
default="",
help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary.",
)
args = parser.parse_args()
if args.force_utf8 in ("True", "1", "true", "yes", "y"):
force_utf8 = Tristate(True)
elif args.force_utf8 in ("False", "0", "false", "no", "n"):
force_utf8 = Tristate(False)
else:
force_utf8 = Tristate(None)
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)
if __name__ == "__main__":
main()

View File

@ -17,29 +17,28 @@ class Scorer(swigwrapper.Scorer):
:alphabet: Alphabet
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
def __init__(self, alpha=None, beta=None, model_path=None, trie_path=None, alphabet=None):
super(Scorer, self).__init__()
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
# Allow bare initialization
if alphabet:
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
err = self.init(alpha, beta,
model_path.encode('utf-8'),
trie_path.encode('utf-8'),
native_alphabet)
if err != 0:
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
def __init__(self):
super(Scorer, self).__init__()
err = self.init(alpha, beta,
model_path.encode('utf-8'),
trie_path.encode('utf-8'),
native_alphabet)
if err != 0:
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
def load_lm(self, lm_path, trie_path):
super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
def save_dictionary(self, save_path):
super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
def save_dictionary(self, save_path, *args, **kwargs):
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
def ctc_beam_search_decoder(probs_seq,

View File

@ -128,9 +128,15 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
max_order_ = language_model_->Order();
}
void Scorer::save_dictionary(const std::string& path)
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
{
std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
std::ios::openmode om;
if (append_instead_of_overwrite) {
om = std::ios::in|std::ios::out|std::ios::binary|std::ios::ate;
} else {
om = std::ios::out|std::ios::binary;
}
std::fstream fout(path, om);
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));

View File

@ -95,7 +95,7 @@ public:
void set_alphabet(const Alphabet& alphabet);
// save dictionary in file
void save_dictionary(const std::string &path);
void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
// return weather this step represents a boundary where beam scoring should happen
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);