Write default values for alpha and beta into trie header

This commit is contained in:
Reuben Morais 2020-01-16 16:27:54 +01:00
parent b33d90b7bd
commit 16d5632d6f
2 changed files with 21 additions and 5 deletions

View File

@ -14,7 +14,7 @@ from util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta):
words = set()
vocab_looks_char_based = True
with open(vocab_path) as fin:
@ -49,6 +49,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
scorer = Scorer()
scorer.set_alphabet(alphabet)
scorer.set_utf8_mode(use_utf8)
scorer.reset_params(default_alpha, default_beta)
scorer.load_lm(lm_path, "")
scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path)
@ -99,6 +100,8 @@ def main():
help="Path of vocabulary file. Must contain words separated by whitespace.",
)
parser.add_argument("--package", required=True, help="Path to save scorer package.")
parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.")
parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.")
parser.add_argument(
"--force_utf8",
default="",
@ -113,7 +116,7 @@ def main():
else:
force_utf8 = Tristate(None)
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta)
if __name__ == "__main__":

View File

@ -27,7 +27,7 @@
using namespace lm::ngram;
static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 5;
static const int32_t FILE_VERSION = 6;
int
Scorer::init(double alpha,
@ -125,13 +125,24 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
<< "). ";
if (version < FILE_VERSION) {
std::cerr << "Update your trie file.";
} else {
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
}
std::cerr << std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
// Read hyperparameters from header
double alpha, beta;
fin.read(reinterpret_cast<char*>(&alpha), sizeof(alpha));
fin.read(reinterpret_cast<char*>(&beta), sizeof(beta));
reset_params(alpha, beta);
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
@ -150,6 +161,8 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
fst::FstWriteOptions opt;
opt.align = true;
opt.source = path;