Write default values for alpha and beta into trie header
This commit is contained in:
parent
b33d90b7bd
commit
16d5632d6f
@ -14,7 +14,7 @@ from util.text import Alphabet, UTF8Alphabet
|
|||||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||||
|
|
||||||
|
|
||||||
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
|
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta):
|
||||||
words = set()
|
words = set()
|
||||||
vocab_looks_char_based = True
|
vocab_looks_char_based = True
|
||||||
with open(vocab_path) as fin:
|
with open(vocab_path) as fin:
|
||||||
@ -49,6 +49,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
|
|||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
scorer.set_alphabet(alphabet)
|
scorer.set_alphabet(alphabet)
|
||||||
scorer.set_utf8_mode(use_utf8)
|
scorer.set_utf8_mode(use_utf8)
|
||||||
|
scorer.reset_params(default_alpha, default_beta)
|
||||||
scorer.load_lm(lm_path, "")
|
scorer.load_lm(lm_path, "")
|
||||||
scorer.fill_dictionary(list(words))
|
scorer.fill_dictionary(list(words))
|
||||||
shutil.copy(lm_path, package_path)
|
shutil.copy(lm_path, package_path)
|
||||||
@ -99,6 +100,8 @@ def main():
|
|||||||
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
||||||
|
parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.")
|
||||||
|
parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--force_utf8",
|
"--force_utf8",
|
||||||
default="",
|
default="",
|
||||||
@ -113,7 +116,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
force_utf8 = Tristate(None)
|
force_utf8 = Tristate(None)
|
||||||
|
|
||||||
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)
|
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -27,7 +27,7 @@
|
|||||||
using namespace lm::ngram;
|
using namespace lm::ngram;
|
||||||
|
|
||||||
static const int32_t MAGIC = 'TRIE';
|
static const int32_t MAGIC = 'TRIE';
|
||||||
static const int32_t FILE_VERSION = 5;
|
static const int32_t FILE_VERSION = 6;
|
||||||
|
|
||||||
int
|
int
|
||||||
Scorer::init(double alpha,
|
Scorer::init(double alpha,
|
||||||
@ -125,13 +125,24 @@ void Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
|||||||
if (version != FILE_VERSION) {
|
if (version != FILE_VERSION) {
|
||||||
std::cerr << "Error: Trie file version mismatch (" << version
|
std::cerr << "Error: Trie file version mismatch (" << version
|
||||||
<< " instead of expected " << FILE_VERSION
|
<< " instead of expected " << FILE_VERSION
|
||||||
<< "). Update your trie file."
|
<< "). ";
|
||||||
<< std::endl;
|
if (version < FILE_VERSION) {
|
||||||
|
std::cerr << "Update your trie file.";
|
||||||
|
} else {
|
||||||
|
std::cerr << "Downgrade your trie file or update your version of DeepSpeech.";
|
||||||
|
}
|
||||||
|
std::cerr << std::endl;
|
||||||
throw 1;
|
throw 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
|
||||||
|
// Read hyperparameters from header
|
||||||
|
double alpha, beta;
|
||||||
|
fin.read(reinterpret_cast<char*>(&alpha), sizeof(alpha));
|
||||||
|
fin.read(reinterpret_cast<char*>(&beta), sizeof(beta));
|
||||||
|
reset_params(alpha, beta);
|
||||||
|
|
||||||
fst::FstReadOptions opt;
|
fst::FstReadOptions opt;
|
||||||
opt.mode = fst::FstReadOptions::MAP;
|
opt.mode = fst::FstReadOptions::MAP;
|
||||||
opt.source = file_path;
|
opt.source = file_path;
|
||||||
@ -150,6 +161,8 @@ void Scorer::save_dictionary(const std::string& path, bool append_instead_of_ove
|
|||||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
||||||
|
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
||||||
fst::FstWriteOptions opt;
|
fst::FstWriteOptions opt;
|
||||||
opt.align = true;
|
opt.align = true;
|
||||||
opt.source = path;
|
opt.source = path;
|
||||||
|
Loading…
Reference in New Issue
Block a user