Avoid reconstructing strings twice on decode

This commit is contained in:
Reuben Morais 2019-11-09 14:35:24 +01:00
parent c1b1a59423
commit 0e6952c3a8
5 changed files with 18 additions and 35 deletions

View File

@ -192,27 +192,30 @@ DecoderState::decode() const
std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const int top_paths = 1;
const size_t top_paths = 1;
size_t num_returned = std::min(num_prefixes, top_paths);
std::vector<Output> outputs;
outputs.reserve(num_returned);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
for (size_t i = 0; i < num_returned; ++i) {
Output output;
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
double approx_ctc = scores[prefixes_copy[i]];
if (ext_scorer_ != nullptr) {
std::vector<int> output;
std::vector<int> timesteps;
prefixes_copy[i]->get_path_vec(output, timesteps);
auto prefix_length = output.size();
auto words = ext_scorer_->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta;
// remove language model weight:
auto words = ext_scorer_->split_labels_into_scored_units(output.tokens);
// remove term insertion weight
approx_ctc -= words.size() * ext_scorer_->beta;
// remove language model weight
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
}
prefixes_copy[i]->approx_ctc = approx_ctc;
output.confidence = -approx_ctc;
outputs.push_back(output);
}
return get_beam_search_result(prefixes_copy, top_paths);
return outputs;
}
std::vector<Output> ctc_beam_search_decoder(

View File

@ -38,21 +38,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
return log_prob_idx;
}
std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
size_t top_paths) {
std::vector<Output> output_vecs;
for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
Output output;
prefixes[i]->get_path_vec(output.tokens, output.timesteps);
output.confidence = -prefixes[i]->approx_ctc;
output_vecs.push_back(output);
}
return output_vecs;
}
size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0;
for (char c : str) {

View File

@ -59,11 +59,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<Output> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
size_t top_paths);
// Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y);

View File

@ -273,14 +273,14 @@ void Scorer::reset_params(float alpha, float beta)
this->beta = beta;
}
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels)
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
{
if (labels.empty()) return {};
std::string s = alphabet_.LabelsToString(labels);
std::vector<std::string> words;
if (is_utf8_mode_) {
words = split_into_bytes(s);
words = split_into_codepoints(s);
} else {
words = split_str(s, " ");
}

View File

@ -87,7 +87,7 @@ public:
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
std::vector<std::string> split_labels(const std::vector<int> &labels);
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
// save dictionary in file
void save_dictionary(const std::string &path);