Refactor Scorer so model/trie package can be created by an external tool
This commit is contained in:
parent
7c0354483e
commit
be2229ef29
@ -1,6 +1,7 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from . import swigwrapper # pylint: disable=import-self
|
||||
from .swigwrapper import Alphabet
|
||||
|
||||
__version__ = swigwrapper.__version__
|
||||
|
||||
@ -16,7 +17,6 @@ class Scorer(swigwrapper.Scorer):
|
||||
:alphabet: Alphabet
|
||||
:type model_path: basestring
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
||||
super(Scorer, self).__init__()
|
||||
serialized = alphabet.serialize()
|
||||
@ -32,6 +32,15 @@ class Scorer(swigwrapper.Scorer):
|
||||
if err != 0:
|
||||
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
||||
|
||||
def __init__(self):
|
||||
super(Scorer, self).__init__()
|
||||
|
||||
def load_lm(self, lm_path, trie_path):
|
||||
super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
|
||||
|
||||
def save_dictionary(self, save_path):
|
||||
super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
|
||||
alphabet,
|
||||
|
@ -36,7 +36,7 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
prefix_root_.reset(root);
|
||||
prefixes_.push_back(root);
|
||||
|
||||
if (ext_scorer != nullptr) {
|
||||
if (ext_scorer != nullptr && (bool)ext_scorer_->dictionary) {
|
||||
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
||||
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
||||
root->set_dictionary(dict_ptr);
|
||||
|
@ -38,7 +38,8 @@ Scorer::init(double alpha,
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
alphabet_ = alphabet;
|
||||
setup(lm_path, trie_path);
|
||||
setup_char_map();
|
||||
load_lm(lm_path, trie_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -54,11 +55,19 @@ Scorer::init(double alpha,
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
setup(lm_path, trie_path);
|
||||
setup_char_map();
|
||||
load_lm(lm_path, trie_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||
void
|
||||
Scorer::set_alphabet(const Alphabet& alphabet)
|
||||
{
|
||||
alphabet_ = alphabet;
|
||||
setup_char_map();
|
||||
}
|
||||
|
||||
void Scorer::setup_char_map()
|
||||
{
|
||||
// (Re-)Initialize character map
|
||||
char_map_.clear();
|
||||
@ -71,52 +80,57 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||
// state, otherwise wrong decoding results would be given.
|
||||
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
||||
}
|
||||
}
|
||||
|
||||
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
||||
{
|
||||
// load language model
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||
|
||||
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
||||
VALID_CHECK(has_trie, "Invalid trie path");
|
||||
// VALID_CHECK(has_trie, "Invalid trie path");
|
||||
|
||||
lm::ngram::Config config;
|
||||
config.load_method = util::LoadMethod::LAZY;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(trie_path, std::ios::binary);
|
||||
if (has_trie) {
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(trie_path, std::ios::binary);
|
||||
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = trie_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = trie_path;
|
||||
dictionary_.reset(FstType::Read(fin, opt));
|
||||
|
||||
max_order_ = language_model_->Order();
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path)
|
||||
{
|
||||
std::ofstream fout(path, std::ios::binary);
|
||||
std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
|
||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
@ -40,9 +40,9 @@ public:
|
||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||
*/
|
||||
class Scorer {
|
||||
public:
|
||||
using FstType = PathTrie::FstType;
|
||||
|
||||
public:
|
||||
Scorer() = default;
|
||||
~Scorer() = default;
|
||||
|
||||
@ -76,12 +76,15 @@ public:
|
||||
// return the max order
|
||||
size_t get_max_order() const { return max_order_; }
|
||||
|
||||
// retrun true if the language model is character based
|
||||
// return true if the language model is character based
|
||||
bool is_utf8_mode() const { return is_utf8_mode_; }
|
||||
|
||||
// reset params alpha & beta
|
||||
void reset_params(float alpha, float beta);
|
||||
|
||||
// force set UTF-8 mode, ignore value read from file
|
||||
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
|
||||
|
||||
// make ngram for a given prefix
|
||||
std::vector<std::string> make_ngram(PathTrie *prefix);
|
||||
|
||||
@ -89,12 +92,20 @@ public:
|
||||
// the vector of characters (character based lm)
|
||||
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
||||
|
||||
void set_alphabet(const Alphabet& alphabet);
|
||||
|
||||
// save dictionary in file
|
||||
void save_dictionary(const std::string &path);
|
||||
|
||||
// return weather this step represents a boundary where beam scoring should happen
|
||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||
|
||||
// fill dictionary FST from a vocabulary
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path, const std::string &trie_path);
|
||||
|
||||
// language model weight
|
||||
double alpha = 0.;
|
||||
// word insertion weight
|
||||
@ -104,14 +115,8 @@ public:
|
||||
std::unique_ptr<FstType> dictionary;
|
||||
|
||||
protected:
|
||||
// necessary setup: load language model, fill FST's dictionary
|
||||
void setup(const std::string &lm_path, const std::string &trie_path);
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path);
|
||||
|
||||
// fill dictionary for FST
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
// necessary setup after setting alphabet
|
||||
void setup_char_map();
|
||||
|
||||
private:
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
|
@ -16,6 +16,10 @@
|
||||
import_array();
|
||||
%}
|
||||
|
||||
namespace std {
|
||||
%template(StringVector) vector<string>;
|
||||
}
|
||||
|
||||
// Convert NumPy arrays to pointer+lengths
|
||||
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
||||
|
Loading…
Reference in New Issue
Block a user