Avoid reconstructing strings twice on decode
This commit is contained in:
parent
c1b1a59423
commit
0e6952c3a8
@ -192,27 +192,30 @@ DecoderState::decode() const
|
|||||||
std::bind(prefix_compare_external, _1, _2, scores));
|
std::bind(prefix_compare_external, _1, _2, scores));
|
||||||
|
|
||||||
//TODO: expose this as an API parameter
|
//TODO: expose this as an API parameter
|
||||||
const int top_paths = 1;
|
const size_t top_paths = 1;
|
||||||
|
size_t num_returned = std::min(num_prefixes, top_paths);
|
||||||
|
|
||||||
|
std::vector<Output> outputs;
|
||||||
|
outputs.reserve(num_returned);
|
||||||
|
|
||||||
// compute aproximate ctc score as the return score, without affecting the
|
// compute aproximate ctc score as the return score, without affecting the
|
||||||
// return order of decoding result. To delete when decoder gets stable.
|
// return order of decoding result. To delete when decoder gets stable.
|
||||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
for (size_t i = 0; i < num_returned; ++i) {
|
||||||
|
Output output;
|
||||||
|
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
||||||
double approx_ctc = scores[prefixes_copy[i]];
|
double approx_ctc = scores[prefixes_copy[i]];
|
||||||
if (ext_scorer_ != nullptr) {
|
if (ext_scorer_ != nullptr) {
|
||||||
std::vector<int> output;
|
auto words = ext_scorer_->split_labels_into_scored_units(output.tokens);
|
||||||
std::vector<int> timesteps;
|
// remove term insertion weight
|
||||||
prefixes_copy[i]->get_path_vec(output, timesteps);
|
approx_ctc -= words.size() * ext_scorer_->beta;
|
||||||
auto prefix_length = output.size();
|
// remove language model weight
|
||||||
auto words = ext_scorer_->split_labels(output);
|
|
||||||
// remove word insert
|
|
||||||
approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta;
|
|
||||||
// remove language model weight:
|
|
||||||
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
|
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
|
||||||
}
|
}
|
||||||
prefixes_copy[i]->approx_ctc = approx_ctc;
|
output.confidence = -approx_ctc;
|
||||||
|
outputs.push_back(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
return get_beam_search_result(prefixes_copy, top_paths);
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Output> ctc_beam_search_decoder(
|
std::vector<Output> ctc_beam_search_decoder(
|
||||||
|
@ -38,21 +38,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
|||||||
return log_prob_idx;
|
return log_prob_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::vector<Output> get_beam_search_result(
|
|
||||||
const std::vector<PathTrie *> &prefixes,
|
|
||||||
size_t top_paths) {
|
|
||||||
std::vector<Output> output_vecs;
|
|
||||||
for (size_t i = 0; i < top_paths && i < prefixes.size(); ++i) {
|
|
||||||
Output output;
|
|
||||||
prefixes[i]->get_path_vec(output.tokens, output.timesteps);
|
|
||||||
output.confidence = -prefixes[i]->approx_ctc;
|
|
||||||
output_vecs.push_back(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
return output_vecs;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_utf8_str_len(const std::string &str) {
|
size_t get_utf8_str_len(const std::string &str) {
|
||||||
size_t str_len = 0;
|
size_t str_len = 0;
|
||||||
for (char c : str) {
|
for (char c : str) {
|
||||||
|
@ -59,11 +59,6 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
|||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n);
|
size_t cutoff_top_n);
|
||||||
|
|
||||||
// Get beam search result from prefixes in trie tree
|
|
||||||
std::vector<Output> get_beam_search_result(
|
|
||||||
const std::vector<PathTrie *> &prefixes,
|
|
||||||
size_t top_paths);
|
|
||||||
|
|
||||||
// Functor for prefix comparsion
|
// Functor for prefix comparsion
|
||||||
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||||
|
|
||||||
|
@ -273,14 +273,14 @@ void Scorer::reset_params(float alpha, float beta)
|
|||||||
this->beta = beta;
|
this->beta = beta;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels)
|
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
|
||||||
{
|
{
|
||||||
if (labels.empty()) return {};
|
if (labels.empty()) return {};
|
||||||
|
|
||||||
std::string s = alphabet_.LabelsToString(labels);
|
std::string s = alphabet_.LabelsToString(labels);
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
words = split_into_bytes(s);
|
words = split_into_codepoints(s);
|
||||||
} else {
|
} else {
|
||||||
words = split_str(s, " ");
|
words = split_str(s, " ");
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ public:
|
|||||||
|
|
||||||
// trransform the labels in index to the vector of words (word based lm) or
|
// trransform the labels in index to the vector of words (word based lm) or
|
||||||
// the vector of characters (character based lm)
|
// the vector of characters (character based lm)
|
||||||
std::vector<std::string> split_labels(const std::vector<int> &labels);
|
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
||||||
|
|
||||||
// save dictionary in file
|
// save dictionary in file
|
||||||
void save_dictionary(const std::string &path);
|
void save_dictionary(const std::string &path);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user