Add generate_package tool to create combined scorer package

parent: be2229ef29
commit: 214b50f490

data/lm/generate_package.py (new file, 120 lines added)
@@ -0,0 +1,120 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))

import argparse
import shutil

from util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet


def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
    words = set()
    vocab_looks_char_based = True
    with open(vocab_path) as fin:
        for line in fin:
            for word in line.split():
                words.add(word.encode('utf-8'))
                if len(word) > 1:
                    vocab_looks_char_based = False
    print("{} unique words read from vocabulary file.".format(len(words)))
    print(
        "{} like a character based model.".format(
            "Looks" if vocab_looks_char_based else "Doesn't look"
        )
    )

    if force_utf8 != None:
        use_utf8 = force_utf8.value
    else:
        use_utf8 = vocab_looks_char_based

    if use_utf8:
        serialized_alphabet = UTF8Alphabet().serialize()
    else:
        serialized_alphabet = Alphabet(alphabet_path).serialize()

    alphabet = NativeAlphabet()
    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
    if err != 0:
        print("Error loading alphabet: {}".format(err))
        sys.exit(1)

    scorer = Scorer()
    scorer.set_alphabet(alphabet)
    scorer.set_utf8_mode(use_utf8)
    scorer.load_lm(lm_path, "")
    scorer.fill_dictionary(list(words))
    shutil.copy(lm_path, package_path)
    scorer.save_dictionary(package_path, True)  # append, not overwrite
    print('Package created in {}'.format(package_path))


class Tristate(object):
    def __init__(self, value=None):
        if any(value is v for v in (True, False, None)):
            self.value = value
        else:
            raise ValueError("Tristate value must be True, False, or None")

    def __eq__(self, other):
        return (self.value is other.value if isinstance(other, Tristate)
                else self.value is other)

    def __ne__(self, other):
        return not self == other

    def __bool__(self):
        raise TypeError("Tristate object may not be used as a Boolean")

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return "Tristate(%s)" % self.value


def main():
    parser = argparse.ArgumentParser(
        description="Generate an external scorer package for DeepSpeech."
    )
    parser.add_argument(
        "--alphabet",
        help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
    )
    parser.add_argument(
        "--lm",
        required=True,
        help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
    )
    parser.add_argument(
        "--vocab",
        required=True,
        help="Path of vocabulary file. Must contain words separated by whitespace.",
    )
    parser.add_argument("--package", required=True, help="Path to save scorer package.")
    parser.add_argument(
        "--force_utf8",
        default="",
        help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary.",
    )
    args = parser.parse_args()

    if args.force_utf8 in ("True", "1", "true", "yes", "y"):
        force_utf8 = Tristate(True)
    elif args.force_utf8 in ("False", "0", "false", "no", "n"):
        force_utf8 = Tristate(False)
    else:
        force_utf8 = Tristate(None)

    create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)


if __name__ == "__main__":
    main()
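For reference, a sketch of how the new tool would be invoked; the file paths below are placeholders, only the flag names and functions come from the script above, and create_bundle() can equally be called directly:

# Hypothetical invocation (placeholder paths):
#   python data/lm/generate_package.py --lm lm.binary --vocab vocab.txt \
#       --alphabet alphabet.txt --package kenlm.scorer
# Equivalent direct call, assuming the module is importable:
from generate_package import create_bundle, Tristate

create_bundle(
    alphabet_path="alphabet.txt",   # only read when UTF-8 mode ends up disabled
    lm_path="lm.binary",            # KenLM binary built without its vocabulary
    vocab_path="vocab.txt",         # whitespace-separated words
    package_path="kenlm.scorer",    # output: copy of lm.binary with the dictionary appended
    force_utf8=Tristate(None),      # None = infer UTF-8 mode from the vocabulary
)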
@@ -17,8 +17,10 @@ class Scorer(swigwrapper.Scorer):
     :alphabet: Alphabet
     :type model_path: basestring
     """
-    def __init__(self, alpha, beta, model_path, trie_path, alphabet):
+    def __init__(self, alpha=None, beta=None, model_path=None, trie_path=None, alphabet=None):
         super(Scorer, self).__init__()
+        # Allow bare initialization
+        if alphabet:
             serialized = alphabet.serialize()
             native_alphabet = swigwrapper.Alphabet()
             err = native_alphabet.deserialize(serialized, len(serialized))
@@ -32,14 +34,11 @@ class Scorer(swigwrapper.Scorer):
             if err != 0:
                 raise ValueError("Scorer initialization failed with error code {}".format(err), err)
 
-    def __init__(self):
-        super(Scorer, self).__init__()
-
     def load_lm(self, lm_path, trie_path):
         super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
 
-    def save_dictionary(self, save_path):
-        super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
+    def save_dictionary(self, save_path, *args, **kwargs):
+        super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
 
 
 def ctc_beam_search_decoder(probs_seq,
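Together these wrapper changes allow a Scorer to be constructed bare and configured step by step, with extra save_dictionary() arguments forwarded to the native method. A minimal sketch of the sequence the new script relies on (variable names as in create_bundle() above):

scorer = Scorer()                 # bare initialization, no alpha/beta/model/trie required
scorer.set_alphabet(alphabet)     # NativeAlphabet deserialized as in create_bundle()
scorer.set_utf8_mode(use_utf8)
scorer.load_lm(lm_path, "")       # no separate trie file
scorer.fill_dictionary(list(words))
scorer.save_dictionary(package_path, True)  # True reaches the native append flag via *args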
@@ -128,9 +128,15 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
   max_order_ = language_model_->Order();
 }
 
-void Scorer::save_dictionary(const std::string& path)
+void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
 {
-  std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
+  std::ios::openmode om;
+  if (append_instead_of_overwrite) {
+    om = std::ios::in|std::ios::out|std::ios::binary|std::ios::ate;
+  } else {
+    om = std::ios::out|std::ios::binary;
+  }
+  std::fstream fout(path, om);
   fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
   fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
   fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
@@ -95,7 +95,7 @@ public:
   void set_alphabet(const Alphabet& alphabet);
 
   // save dictionary in file
-  void save_dictionary(const std::string &path);
+  void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
 
   // return whether this step represents a boundary where beam scoring should happen
   bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
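The behavioural difference introduced above is entirely in the stream open mode: with append_instead_of_overwrite set, the existing file (the LM binary copied by generate_package.py) is opened for read/write and positioned at its end, so the dictionary is written after the LM instead of replacing the file. A rough Python analogue of the two modes, for illustration only:

def open_for_dictionary(path, append_instead_of_overwrite):
    """Illustrative Python equivalent of the two C++ open modes above."""
    if append_instead_of_overwrite:
        # std::ios::in | std::ios::out | std::ios::binary | std::ios::ate
        f = open(path, "r+b")  # file must already exist (e.g. the copied LM binary)
        f.seek(0, 2)           # position at end of file, mirroring std::ios::ate
    else:
        # std::ios::out | std::ios::binary
        f = open(path, "wb")   # create or truncate
    return f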