Refactor Scorer so model/trie package can be created by an external tool
This commit is contained in:
parent
7c0354483e
commit
be2229ef29
@ -1,6 +1,7 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from . import swigwrapper # pylint: disable=import-self
|
from . import swigwrapper # pylint: disable=import-self
|
||||||
|
from .swigwrapper import Alphabet
|
||||||
|
|
||||||
__version__ = swigwrapper.__version__
|
__version__ = swigwrapper.__version__
|
||||||
|
|
||||||
@ -16,7 +17,6 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
:alphabet: Alphabet
|
:alphabet: Alphabet
|
||||||
:type model_path: basestring
|
:type model_path: basestring
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
||||||
super(Scorer, self).__init__()
|
super(Scorer, self).__init__()
|
||||||
serialized = alphabet.serialize()
|
serialized = alphabet.serialize()
|
||||||
@ -32,6 +32,15 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Scorer, self).__init__()
|
||||||
|
|
||||||
|
def load_lm(self, lm_path, trie_path):
|
||||||
|
super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
|
||||||
|
|
||||||
|
def save_dictionary(self, save_path):
|
||||||
|
super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
alphabet,
|
alphabet,
|
||||||
|
@ -36,7 +36,7 @@ DecoderState::init(const Alphabet& alphabet,
|
|||||||
prefix_root_.reset(root);
|
prefix_root_.reset(root);
|
||||||
prefixes_.push_back(root);
|
prefixes_.push_back(root);
|
||||||
|
|
||||||
if (ext_scorer != nullptr) {
|
if (ext_scorer != nullptr && (bool)ext_scorer_->dictionary) {
|
||||||
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
||||||
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
||||||
root->set_dictionary(dict_ptr);
|
root->set_dictionary(dict_ptr);
|
||||||
|
@ -38,7 +38,8 @@ Scorer::init(double alpha,
|
|||||||
{
|
{
|
||||||
reset_params(alpha, beta);
|
reset_params(alpha, beta);
|
||||||
alphabet_ = alphabet;
|
alphabet_ = alphabet;
|
||||||
setup(lm_path, trie_path);
|
setup_char_map();
|
||||||
|
load_lm(lm_path, trie_path);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,11 +55,19 @@ Scorer::init(double alpha,
|
|||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
setup(lm_path, trie_path);
|
setup_char_map();
|
||||||
|
load_lm(lm_path, trie_path);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
void
|
||||||
|
Scorer::set_alphabet(const Alphabet& alphabet)
|
||||||
|
{
|
||||||
|
alphabet_ = alphabet;
|
||||||
|
setup_char_map();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::setup_char_map()
|
||||||
{
|
{
|
||||||
// (Re-)Initialize character map
|
// (Re-)Initialize character map
|
||||||
char_map_.clear();
|
char_map_.clear();
|
||||||
@ -71,52 +80,57 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
|||||||
// state, otherwise wrong decoding results would be given.
|
// state, otherwise wrong decoding results would be given.
|
||||||
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
||||||
|
{
|
||||||
// load language model
|
// load language model
|
||||||
const char* filename = lm_path.c_str();
|
const char* filename = lm_path.c_str();
|
||||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||||
|
|
||||||
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
||||||
VALID_CHECK(has_trie, "Invalid trie path");
|
// VALID_CHECK(has_trie, "Invalid trie path");
|
||||||
|
|
||||||
lm::ngram::Config config;
|
lm::ngram::Config config;
|
||||||
config.load_method = util::LoadMethod::LAZY;
|
config.load_method = util::LoadMethod::LAZY;
|
||||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||||
|
|
||||||
// Read metadata and trie from file
|
if (has_trie) {
|
||||||
std::ifstream fin(trie_path, std::ios::binary);
|
// Read metadata and trie from file
|
||||||
|
std::ifstream fin(trie_path, std::ios::binary);
|
||||||
|
|
||||||
int magic;
|
int magic;
|
||||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||||
if (magic != MAGIC) {
|
if (magic != MAGIC) {
|
||||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||||
"your trie file." << std::endl;
|
"your trie file." << std::endl;
|
||||||
throw 1;
|
throw 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int version;
|
||||||
|
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||||
|
if (version != FILE_VERSION) {
|
||||||
|
std::cerr << "Error: Trie file version mismatch (" << version
|
||||||
|
<< " instead of expected " << FILE_VERSION
|
||||||
|
<< "). Update your trie file."
|
||||||
|
<< std::endl;
|
||||||
|
throw 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
|
||||||
|
fst::FstReadOptions opt;
|
||||||
|
opt.mode = fst::FstReadOptions::MAP;
|
||||||
|
opt.source = trie_path;
|
||||||
|
dictionary.reset(FstType::Read(fin, opt));
|
||||||
}
|
}
|
||||||
|
|
||||||
int version;
|
|
||||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
|
||||||
if (version != FILE_VERSION) {
|
|
||||||
std::cerr << "Error: Trie file version mismatch (" << version
|
|
||||||
<< " instead of expected " << FILE_VERSION
|
|
||||||
<< "). Update your trie file."
|
|
||||||
<< std::endl;
|
|
||||||
throw 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
|
||||||
|
|
||||||
fst::FstReadOptions opt;
|
|
||||||
opt.mode = fst::FstReadOptions::MAP;
|
|
||||||
opt.source = trie_path;
|
|
||||||
dictionary_.reset(FstType::Read(fin, opt));
|
|
||||||
|
|
||||||
max_order_ = language_model_->Order();
|
max_order_ = language_model_->Order();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::save_dictionary(const std::string& path)
|
void Scorer::save_dictionary(const std::string& path)
|
||||||
{
|
{
|
||||||
std::ofstream fout(path, std::ios::binary);
|
std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
|
||||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||||
|
@ -40,9 +40,9 @@ public:
|
|||||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||||
*/
|
*/
|
||||||
class Scorer {
|
class Scorer {
|
||||||
|
public:
|
||||||
using FstType = PathTrie::FstType;
|
using FstType = PathTrie::FstType;
|
||||||
|
|
||||||
public:
|
|
||||||
Scorer() = default;
|
Scorer() = default;
|
||||||
~Scorer() = default;
|
~Scorer() = default;
|
||||||
|
|
||||||
@ -76,12 +76,15 @@ public:
|
|||||||
// return the max order
|
// return the max order
|
||||||
size_t get_max_order() const { return max_order_; }
|
size_t get_max_order() const { return max_order_; }
|
||||||
|
|
||||||
// retrun true if the language model is character based
|
// return true if the language model is character based
|
||||||
bool is_utf8_mode() const { return is_utf8_mode_; }
|
bool is_utf8_mode() const { return is_utf8_mode_; }
|
||||||
|
|
||||||
// reset params alpha & beta
|
// reset params alpha & beta
|
||||||
void reset_params(float alpha, float beta);
|
void reset_params(float alpha, float beta);
|
||||||
|
|
||||||
|
// force set UTF-8 mode, ignore value read from file
|
||||||
|
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
|
||||||
|
|
||||||
// make ngram for a given prefix
|
// make ngram for a given prefix
|
||||||
std::vector<std::string> make_ngram(PathTrie *prefix);
|
std::vector<std::string> make_ngram(PathTrie *prefix);
|
||||||
|
|
||||||
@ -89,12 +92,20 @@ public:
|
|||||||
// the vector of characters (character based lm)
|
// the vector of characters (character based lm)
|
||||||
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
||||||
|
|
||||||
|
void set_alphabet(const Alphabet& alphabet);
|
||||||
|
|
||||||
// save dictionary in file
|
// save dictionary in file
|
||||||
void save_dictionary(const std::string &path);
|
void save_dictionary(const std::string &path);
|
||||||
|
|
||||||
// return weather this step represents a boundary where beam scoring should happen
|
// return weather this step represents a boundary where beam scoring should happen
|
||||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||||
|
|
||||||
|
// fill dictionary FST from a vocabulary
|
||||||
|
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||||
|
|
||||||
|
// load language model from given path
|
||||||
|
void load_lm(const std::string &lm_path, const std::string &trie_path);
|
||||||
|
|
||||||
// language model weight
|
// language model weight
|
||||||
double alpha = 0.;
|
double alpha = 0.;
|
||||||
// word insertion weight
|
// word insertion weight
|
||||||
@ -104,14 +115,8 @@ public:
|
|||||||
std::unique_ptr<FstType> dictionary;
|
std::unique_ptr<FstType> dictionary;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// necessary setup: load language model, fill FST's dictionary
|
// necessary setup after setting alphabet
|
||||||
void setup(const std::string &lm_path, const std::string &trie_path);
|
void setup_char_map();
|
||||||
|
|
||||||
// load language model from given path
|
|
||||||
void load_lm(const std::string &lm_path);
|
|
||||||
|
|
||||||
// fill dictionary for FST
|
|
||||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<lm::base::Model> language_model_;
|
std::unique_ptr<lm::base::Model> language_model_;
|
||||||
|
@ -16,6 +16,10 @@
|
|||||||
import_array();
|
import_array();
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
%template(StringVector) vector<string>;
|
||||||
|
}
|
||||||
|
|
||||||
// Convert NumPy arrays to pointer+lengths
|
// Convert NumPy arrays to pointer+lengths
|
||||||
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
||||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user