135 lines
3.9 KiB
C++
135 lines
3.9 KiB
C++
#ifndef SCORER_H_
|
|
#define SCORER_H_
|
|
|
|
#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "flashlight/lib/text/decoder/lm/KenLM.h"

#include "path_trie.h"
#include "alphabet.h"
#include "coqui-stt.h"
|
|
|
|
// Score constant for OOV entries.
// NOTE(review): presumably the log-score given to out-of-vocabulary words;
// confirm against the scorer implementation.
const double OOV_SCORE = -1000.0;

// Special marker strings.
// NOTE(review): presumably the begin-of-sentence, unknown-word and
// end-of-sentence tokens of the language model vocabulary; confirm usage.
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
|
|
|
|
/* External scorer to query score for n-gram or sentence, including language
|
|
* model scoring and word insertion.
|
|
*
|
|
* Example:
|
|
* Scorer scorer(alpha, beta, "path_of_language_model");
|
|
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
|
|
*/
|
|
class Scorer : public fl::lib::text::LM {
|
|
public:
|
|
using FstType = PathTrie::FstType;
|
|
|
|
Scorer();
|
|
~Scorer();
|
|
|
|
// disallow copying
|
|
Scorer(const Scorer&) = delete;
|
|
Scorer& operator=(const Scorer&) = delete;
|
|
|
|
int init(const std::string &lm_path,
|
|
const Alphabet &alphabet);
|
|
|
|
int init(const std::string &lm_path,
|
|
const std::string &alphabet_config_path);
|
|
|
|
double get_log_cond_prob(const std::vector<std::string> &words,
|
|
bool bos = false,
|
|
bool eos = false);
|
|
|
|
double get_log_cond_prob(const std::vector<std::string>::const_iterator &begin,
|
|
const std::vector<std::string>::const_iterator &end,
|
|
bool bos = false,
|
|
bool eos = false);
|
|
|
|
// return the max order
|
|
size_t get_max_order() const { return max_order_; }
|
|
|
|
// return true if the language model is character based
|
|
bool is_utf8_mode() const { return is_utf8_mode_; }
|
|
|
|
// reset params alpha & beta
|
|
void reset_params(float alpha, float beta);
|
|
|
|
// force set UTF-8 mode, ignore value read from file
|
|
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
|
|
|
|
// make ngram for a given prefix
|
|
std::vector<std::string> make_ngram(PathTrie *prefix);
|
|
|
|
// trransform the labels in index to the vector of words (word based lm) or
|
|
// the vector of characters (character based lm)
|
|
std::vector<std::string> split_labels_into_scored_units(const std::vector<unsigned int> &labels);
|
|
|
|
void set_alphabet(const Alphabet& alphabet);
|
|
|
|
// save dictionary in file
|
|
bool save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
|
|
|
// return weather this step represents a boundary where beam scoring should happen
|
|
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
|
|
|
// fill dictionary FST from a vocabulary
|
|
void fill_dictionary(const std::unordered_set<std::string> &vocabulary);
|
|
|
|
// load language model from given path
|
|
int load_lm(const std::string &lm_path);
|
|
|
|
// language model weight
|
|
double alpha = 0.;
|
|
// word insertion weight
|
|
double beta = 0.;
|
|
|
|
// pointer to the dictionary of FST
|
|
std::unique_ptr<FstType> dictionary;
|
|
|
|
// ---------------
|
|
// fl::lib::text::LM methods
|
|
|
|
/* Initialize or reset language model state */
|
|
fl::lib::text::LMStatePtr start(bool startWithNothing);
|
|
|
|
/**
|
|
* Query the language model given input state and a specific token, return a
|
|
* new language model state and score.
|
|
*/
|
|
std::pair<fl::lib::text::LMStatePtr, float> score(
|
|
const fl::lib::text::LMStatePtr& state,
|
|
const int usrTokenIdx);
|
|
|
|
/* Query the language model and finish decoding. */
|
|
std::pair<fl::lib::text::LMStatePtr, float> finish(const fl::lib::text::LMStatePtr& state);
|
|
|
|
// ---------------
|
|
// fl::lib::text helper
|
|
|
|
// Must be called before use of this Scorer with Flashlight APIs.
|
|
void load_words(const fl::lib::text::Dictionary& word_dict);
|
|
|
|
protected:
|
|
// necessary setup after setting alphabet
|
|
void setup_char_map();
|
|
|
|
int load_trie(std::ifstream& fin, const std::string& file_path);
|
|
|
|
private:
|
|
std::unique_ptr<lm::base::Model> language_model_;
|
|
bool is_utf8_mode_ = true;
|
|
size_t max_order_ = 0;
|
|
|
|
int SPACE_ID_;
|
|
Alphabet alphabet_;
|
|
std::unordered_map<std::string, int> char_map_;
|
|
};
|
|
|
|
#endif // SCORER_H_
|