Write default values for alpha and beta into trie header
This commit is contained in:
parent
b33d90b7bd
commit
16d5632d6f
@ -14,7 +14,7 @@ from util.text import Alphabet, UTF8Alphabet
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
|
||||
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta):
|
||||
words = set()
|
||||
vocab_looks_char_based = True
|
||||
with open(vocab_path) as fin:
|
||||
@ -49,6 +49,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
|
||||
scorer = Scorer()
|
||||
scorer.set_alphabet(alphabet)
|
||||
scorer.set_utf8_mode(use_utf8)
|
||||
scorer.reset_params(default_alpha, default_beta)
|
||||
scorer.load_lm(lm_path, "")
|
||||
scorer.fill_dictionary(list(words))
|
||||
shutil.copy(lm_path, package_path)
|
||||
@ -99,6 +100,8 @@ def main():
|
||||
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
||||
)
|
||||
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
||||
parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.")
|
||||
parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.")
|
||||
parser.add_argument(
|
||||
"--force_utf8",
|
||||
default="",
|
||||
@ -113,7 +116,7 @@ def main():
|
||||
else:
|
||||
force_utf8 = Tristate(None)
|
||||
|
||||
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)
|
||||
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -27,7 +27,7 @@
|
||||
using namespace lm::ngram;
|
||||
|
||||
static const int32_t MAGIC = 'TRIE';
|
||||
static const int32_t FILE_VERSION = 5;
|
||||
static const int32_t FILE_VERSION = 6;
|
||||
|
||||
int
|
||||
Scorer::init(double alpha,
|
||||
@ -125,13 +125,24 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
<< "). ";
|
||||
if (version < FILE_VERSION) {
|
||||
std::cerr << "Update your trie file.";
|
||||
} else {
|
||||
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
// Read hyperparameters from header
|
||||
double alpha, beta;
|
||||
fin.read(reinterpret_cast<char*>(&alpha), sizeof(alpha));
|
||||
fin.read(reinterpret_cast<char*>(&beta), sizeof(beta));
|
||||
reset_params(alpha, beta);
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = file_path;
|
||||
@ -150,6 +161,8 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
|
||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
||||
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
||||
fst::FstWriteOptions opt;
|
||||
opt.align = true;
|
||||
opt.source = path;
|
||||
|
Loading…
Reference in New Issue
Block a user