Avoid reconstructing strings twice on decode
This commit is contained in:
parent
c1b1a59423
commit
0e6952c3a8
@ -192,27 +192,30 @@ DecoderState::decode() const
|
||||
std::bind(prefix_compare_external, _1, _2, scores));
|
||||
|
||||
//TODO: expose this as an API parameter
|
||||
const int top_paths = 1;
|
||||
const size_t top_paths = 1;
|
||||
size_t num_returned = std::min(num_prefixes, top_paths);
|
||||
|
||||
std::vector<Output> outputs;
|
||||
outputs.reserve(num_returned);
|
||||
|
||||
// compute aproximate ctc score as the return score, without affecting the
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
||||
for (size_t i = 0; i < num_returned; ++i) {
|
||||
Output output;
|
||||
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
||||
double approx_ctc = scores[prefixes_copy[i]];
|
||||
if (ext_scorer_ != nullptr) {
|
||||
std::vector<int> output;
|
||||
std::vector<int> timesteps;
|
||||
prefixes_copy[i]->get_path_vec(output, timesteps);
|
||||
auto prefix_length = output.size();
|
||||
auto words = ext_scorer_->split_labels(output);
|
||||
// remove word insert
|
||||
approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta;
|
||||
// remove language model weight:
|
||||
auto words = ext_scorer_->split_labels_into_scored_units(output.tokens);
|
||||
// remove term insertion weight
|
||||
approx_ctc -= words.size() * ext_scorer_->beta;
|
||||
// remove language model weight
|
||||
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
|
||||
}
|
||||
prefixes_copy[i]->approx_ctc = approx_ctc;
|
||||
output.confidence = -approx_ctc;
|
||||
outputs.push_back(output);
|
||||
}
|
||||
|
||||
return get_beam_search_result(prefixes_copy, top_paths);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::vector<Output> ctc_beam_search_decoder(
|
||||
|
@ -38,21 +38,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
return log_prob_idx;
|
||||
}
|
||||
|
||||
|
||||
std::vector<Output> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
size_t top_paths) {
|
||||
std::vector<Output> output_vecs;
|
||||
for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
|
||||
Output output;
|
||||
prefixes[i]->get_path_vec(output.tokens, output.timesteps);
|
||||
output.confidence = -prefixes[i]->approx_ctc;
|
||||
output_vecs.push_back(output);
|
||||
}
|
||||
|
||||
return output_vecs;
|
||||
}
|
||||
|
||||
size_t get_utf8_str_len(const std::string &str) {
|
||||
size_t str_len = 0;
|
||||
for (char c : str) {
|
||||
|
@ -59,11 +59,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n);
|
||||
|
||||
// Get beam search result from prefixes in trie tree
|
||||
std::vector<Output> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
size_t top_paths);
|
||||
|
||||
// Functor for prefix comparsion
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||
|
||||
|
@ -273,14 +273,14 @@ void Scorer::reset_params(float alpha, float beta)
|
||||
this->beta = beta;
|
||||
}
|
||||
|
||||
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels)
|
||||
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
|
||||
{
|
||||
if (labels.empty()) return {};
|
||||
|
||||
std::string s = alphabet_.LabelsToString(labels);
|
||||
std::vector<std::string> words;
|
||||
if (is_utf8_mode_) {
|
||||
words = split_into_bytes(s);
|
||||
words = split_into_codepoints(s);
|
||||
} else {
|
||||
words = split_str(s, " ");
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ public:
|
||||
|
||||
// trransform the labels in index to the vector of words (word based lm) or
|
||||
// the vector of characters (character based lm)
|
||||
std::vector<std::string> split_labels(const std::vector<int> &labels);
|
||||
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
||||
|
||||
// save dictionary in file
|
||||
void save_dictionary(const std::string &path);
|
||||
|
Loading…
x
Reference in New Issue
Block a user