From 16d5632d6f85aaec7b31d6e6b7b978bd55690576 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 16 Jan 2020 16:27:54 +0100 Subject: [PATCH] Write default values for alpha and beta into trie header --- data/lm/generate_package.py | 7 +++++-- native_client/ctcdecode/scorer.cpp | 19 ++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/data/lm/generate_package.py b/data/lm/generate_package.py index ee3c106b..4d064fdd 100644 --- a/data/lm/generate_package.py +++ b/data/lm/generate_package.py @@ -14,7 +14,7 @@ from util.text import Alphabet, UTF8Alphabet from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet -def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8): +def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta): words = set() vocab_looks_char_based = True with open(vocab_path) as fin: @@ -49,6 +49,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8): scorer = Scorer() scorer.set_alphabet(alphabet) scorer.set_utf8_mode(use_utf8) + scorer.reset_params(default_alpha, default_beta) scorer.load_lm(lm_path, "") scorer.fill_dictionary(list(words)) shutil.copy(lm_path, package_path) @@ -99,6 +100,8 @@ def main(): help="Path of vocabulary file. Must contain words separated by whitespace.", ) parser.add_argument("--package", required=True, help="Path to save scorer package.") + parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.") + parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.") parser.add_argument( "--force_utf8", default="", @@ -113,7 +116,7 @@ def main(): else: force_utf8 = Tristate(None) - create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8) + create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta) if __name__ == "__main__": diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index 5bd4da8e..c2bdc4c2 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -27,7 +27,7 @@ using namespace lm::ngram; static const int32_t MAGIC = 'TRIE'; -static const int32_t FILE_VERSION = 5; +static const int32_t FILE_VERSION = 6; int Scorer::init(double alpha, @@ -125,13 +125,24 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path) if (version != FILE_VERSION) { std::cerr << "Error: Trie file version mismatch (" << version << " instead of expected " << FILE_VERSION - << "). Update your trie file." - << std::endl; + << "). "; + if (version < FILE_VERSION) { + std::cerr << "Update your trie file."; + } else { + std::cerr << "Downgrade your trie file or update your version of DeepSpeech."; + } + std::cerr << std::endl; throw 1; } fin.read(reinterpret_cast(&is_utf8_mode_), sizeof(is_utf8_mode_)); + // Read hyperparameters from header + double alpha, beta; + fin.read(reinterpret_cast(&alpha), sizeof(alpha)); + fin.read(reinterpret_cast(&beta), sizeof(beta)); + reset_params(alpha, beta); + fst::FstReadOptions opt; opt.mode = fst::FstReadOptions::MAP; opt.source = file_path; @@ -150,6 +161,8 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove fout.write(reinterpret_cast(&MAGIC), sizeof(MAGIC)); fout.write(reinterpret_cast(&FILE_VERSION), sizeof(FILE_VERSION)); fout.write(reinterpret_cast(&is_utf8_mode_), sizeof(is_utf8_mode_)); + fout.write(reinterpret_cast(&alpha), sizeof(alpha)); + fout.write(reinterpret_cast(&beta), sizeof(beta)); fst::FstWriteOptions opt; opt.align = true; opt.source = path;