Refactor Scorer so model/trie package can be created by an external tool

This commit is contained in:
Reuben Morais 2020-01-16 11:34:33 +01:00
parent 7c0354483e
commit be2229ef29
5 changed files with 74 additions and 42 deletions

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import Alphabet
__version__ = swigwrapper.__version__
@ -16,7 +17,6 @@ class Scorer(swigwrapper.Scorer):
:alphabet: Alphabet
:type model_path: basestring
"""
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
super(Scorer, self).__init__()
serialized = alphabet.serialize()
@ -32,6 +32,15 @@ class Scorer(swigwrapper.Scorer):
if err != 0:
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
def __init__(self):
super(Scorer, self).__init__()
def load_lm(self, lm_path, trie_path):
super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
def save_dictionary(self, save_path):
super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
def ctc_beam_search_decoder(probs_seq,
alphabet,

View File

@ -36,7 +36,7 @@ DecoderState::init(const Alphabet& alphabet,
prefix_root_.reset(root);
prefixes_.push_back(root);
if (ext_scorer != nullptr) {
if (ext_scorer != nullptr && (bool)ext_scorer_->dictionary) {
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
root->set_dictionary(dict_ptr);

View File

@ -38,7 +38,8 @@ Scorer::init(double alpha,
{
reset_params(alpha, beta);
alphabet_ = alphabet;
setup(lm_path, trie_path);
setup_char_map();
load_lm(lm_path, trie_path);
return 0;
}
@ -54,11 +55,19 @@ Scorer::init(double alpha,
if (err != 0) {
return err;
}
setup(lm_path, trie_path);
setup_char_map();
load_lm(lm_path, trie_path);
return 0;
}
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
void
Scorer::set_alphabet(const Alphabet& alphabet)
{
alphabet_ = alphabet;
setup_char_map();
}
void Scorer::setup_char_map()
{
// (Re-)Initialize character map
char_map_.clear();
@ -71,52 +80,57 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
// state, otherwise wrong decoding results would be given.
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
}
}
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
{
// load language model
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
VALID_CHECK(has_trie, "Invalid trie path");
// VALID_CHECK(has_trie, "Invalid trie path");
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
// Read metadata and trie from file
std::ifstream fin(trie_path, std::ios::binary);
if (has_trie) {
// Read metadata and trie from file
std::ifstream fin(trie_path, std::ios::binary);
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary.reset(FstType::Read(fin, opt));
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary_.reset(FstType::Read(fin, opt));
max_order_ = language_model_->Order();
}
void Scorer::save_dictionary(const std::string& path)
{
std::ofstream fout(path, std::ios::binary);
std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));

View File

@ -40,9 +40,9 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
using FstType = PathTrie::FstType;
public:
Scorer() = default;
~Scorer() = default;
@ -76,12 +76,15 @@ public:
// return the max order
size_t get_max_order() const { return max_order_; }
// retrun true if the language model is character based
// return true if the language model is character based
bool is_utf8_mode() const { return is_utf8_mode_; }
// reset params alpha & beta
void reset_params(float alpha, float beta);
// force set UTF-8 mode, ignore value read from file
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
// make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix);
@ -89,12 +92,20 @@ public:
// the vector of characters (character based lm)
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
void set_alphabet(const Alphabet& alphabet);
// save dictionary in file
void save_dictionary(const std::string &path);
// return weather this step represents a boundary where beam scoring should happen
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
// fill dictionary FST from a vocabulary
void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path
void load_lm(const std::string &lm_path, const std::string &trie_path);
// language model weight
double alpha = 0.;
// word insertion weight
@ -104,14 +115,8 @@ public:
std::unique_ptr<FstType> dictionary;
protected:
// necessary setup: load language model, fill FST's dictionary
void setup(const std::string &lm_path, const std::string &trie_path);
// load language model from given path
void load_lm(const std::string &lm_path);
// fill dictionary for FST
void fill_dictionary(const std::vector<std::string> &vocabulary);
// necessary setup after setting alphabet
void setup_char_map();
private:
std::unique_ptr<lm::base::Model> language_model_;

View File

@ -16,6 +16,10 @@
import_array();
%}
namespace std {
%template(StringVector) vector<string>;
}
// Convert NumPy arrays to pointer+lengths
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};